# Copyright 2016-2017 Tobias Grosser
#
# Use of this software is governed by the MIT license
#
# Written by Tobias Grosser, Weststrasse 47, CH-8003, Zurich

import sys
import isl


# Check that every kind of isl object constructor can be invoked:
#  - parsing from a string (isl.val("0"))
#  - building from a Python integer (isl.val(0))
#  - a parameterless static constructor (isl.val.zero())
#  - converting an isl.basic_set into an isl.set
#  - obtaining an empty isl.union_set
#
# Building a val from both an int and a str overlaps with the parameter
# tests; it is repeated here because isl.val has several overloaded
# constructors, so this also exercises overload resolution.
#
def test_constructors():
    # Every route must yield zero.  The constructors are wrapped in callables
    # so that each one runs immediately before its own check.
    zero_makers = (
        lambda: isl.val("0"),
        lambda: isl.val(0),
        isl.val.zero,
    )
    for make_zero in zero_makers:
        assert make_zero().is_zero()

    basic = isl.basic_set("{ [1] }")
    expected = isl.set("{ [1] }")
    assert isl.set(basic).is_equal(expected)

    original = isl.union_set("{ A[1]; B[2, 3] }")
    # Uniting with the empty union set must leave the set unchanged.
    assert original.is_equal(original.union(isl.union_set.empty()))


# Check one integer value: passing it directly and passing its decimal
# string representation must produce equal isl.val objects.
#
def test_int(i):
    assert isl.val(i).eq(isl.val(str(i)))


# Exercise integer parameters at the boundaries: the extreme values
# sys.maxsize and -sys.maxsize - 1, plus zero.
#
def test_parameters_int():
    for value in (sys.maxsize, -sys.maxsize - 1, 0):
        test_int(value)


# Test isl objects parameters.
#
# Verify that isl objects can be passed as lvalue and rvalue parameters.
# Also verify that isl object parameters are automatically type converted if
# there is an inheritance relation. Finally, test function calls without
# any additional parameters, apart from the isl object on which
# the method is called.
#
def test_parameters_obj():
    a = isl.set("{ [0] }")
    b = isl.set("{ [1] }")
    c = isl.set("{ [2] }")
    expected = isl.set("{ [i] : 0 <= i <= 2 }")

    # Argument is a named intermediate result ("lvalue").
    tmp = a.union(b)
    res_lvalue_param = tmp.union(c)
    assert res_lvalue_param.is_equal(expected)

    # Argument is a temporary produced by a chained call ("rvalue").
    res_rvalue_param = a.union(b).union(c)
    assert res_rvalue_param.is_equal(expected)

    # An isl.basic_set argument must be accepted where an isl.set is expected.
    a2 = isl.basic_set("{ [0] }")
    assert a.is_equal(a2)

    # A method that takes no argument besides the object it is called on.
    two = isl.val(2)
    half = isl.val("1/2")
    res_only_this_param = two.inv()
    assert res_only_this_param.eq(half)


# Test different kinds of parameters to be passed to functions.
#
# This includes integer and isl object parameters.
#
def test_parameters():
    test_parameters_int()
    test_parameters_obj()


# Test that isl objects are returned correctly.
#
# This only tests that after combining two objects, the result is successfully
# returned.
#
def test_return_obj():
    one = isl.val("1")
    two = isl.val("2")
    three = isl.val("3")

    res = one.add(two)

    assert res.eq(three)


# Test that integer values are returned correctly.
#
# isl.val.sgn() must report positive, negative and zero values.
#
def test_return_int():
    one = isl.val("1")
    neg_one = isl.val("-1")
    zero = isl.val("0")

    assert one.sgn() > 0
    assert neg_one.sgn() < 0
    assert zero.sgn() == 0


# Test that isl_bool values are returned correctly.
#
# In particular, check the conversion to bool in case of true and false.
#
def test_return_bool():
    empty = isl.set("{ : false }")
    univ = isl.set("{ : }")

    b_true = empty.is_empty()
    b_false = univ.is_empty()

    assert b_true
    assert not b_false


# Test that strings are returned correctly.
#
# Do so by calling the overloaded isl.ast_build.expr_from method, once with
# an isl.pw_aff and once with an isl.set, and checking the C rendering of
# each resulting AST expression.
#
def test_return_string():
    context = isl.set("[n] -> { : }")
    build = isl.ast_build.from_context(context)
    pw_aff = isl.pw_aff("[n] -> { [n] }")
    # Named so as not to shadow the builtin "set".
    nonneg = isl.set("[n] -> { : n >= 0 }")

    expr = build.expr_from(pw_aff)
    expected_string = "n"
    assert expected_string == expr.to_C_str()

    expr = build.expr_from(nonneg)
    expected_string = "n >= 0"
    assert expected_string == expr.to_C_str()


# Test that return values are handled correctly.
#
# Test that isl objects, integers, boolean values, and strings are
# returned correctly.
#
def test_return():
    test_return_obj()
    test_return_int()
    test_return_bool()
    test_return_string()


# A class that is used to test isl.id.user.
#
class S:
    def __init__(self):
        self.value = 42


# Test isl.id.user.
#
# In particular, check that the object attached to an identifier
# can be retrieved again, and that an identifier created without a
# user object reports None.
#
def test_user():
    # Local names avoid shadowing the builtin "id".
    id_int = isl.id("test", 5)
    id_none = isl.id("test2")
    id_obj = isl.id("S", S())
    assert id_int.user() == 5, f"unexpected user object {id_int.user()}"
    assert id_none.user() is None, f"unexpected user object {id_none.user()}"
    s = id_obj.user()
    assert isinstance(s, S), f"unexpected user object {s}"
    assert s.value == 42, f"unexpected user object {s}"


# Test that foreach functions are modeled correctly.
#
# Verify that closures are correctly called as callback of a 'foreach'
# function and that variables captured by the closure work correctly. Also
# check that the foreach function handles exceptions thrown from
# the closure and that it propagates the exception.
#
def test_foreach():
    s = isl.set("{ [0]; [1]; [2] }")

    # Named so as not to shadow the builtin "list".
    collected = []

    def add(bs):
        collected.append(bs)

    s.foreach_basic_set(add)

    assert len(collected) == 3
    for bs in collected:
        assert bs.is_subset(s)
    first, second, third = collected
    assert not first.is_equal(second)
    assert not first.is_equal(third)
    assert not second.is_equal(third)

    fail_calls = [0]

    def fail(bs):
        fail_calls[0] += 1
        raise Exception("fail")

    caught = False
    try:
        s.foreach_basic_set(fail)
    except Exception:
        caught = True
    assert caught
    # The callback must actually have run.  Otherwise "caught" could come
    # from an unrelated error (e.g. a misspelled method name) rather than
    # from the callback's exception being propagated.
    assert fail_calls[0] > 0


# Test the functionality of "foreach_scc" functions.
#
# In particular, test it on a list of elements that can be completely sorted
# but where two of the elements ("a" and "b") are incomparable.
#
def test_foreach_scc():
    # Local names avoid shadowing the builtins list, sorted, id and map.
    ids = isl.id_list(3)
    result = [isl.id_list(3)]
    data = {
        'a': isl.map("{ [0] -> [1] }"),
        'b': isl.map("{ [1] -> [0] }"),
        'c': isl.map("{ [i = 0:1] -> [i] }"),
    }
    # Dict iteration follows insertion order, so the list is a, b, c.
    for name in data:
        ids = ids.add(name)
    identity = data['a'].space().domain().identity_multi_pw_aff_on_domain()

    # Ordering predicate handed to foreach_scc; it is derived from the maps
    # attached to the two identifiers' names.
    def follows(a, b):
        rel = data[b.name()].apply_domain(data[a.name()])
        return not rel.lex_ge_at(identity).is_empty()

    # Every SCC is expected to be a singleton; append it to the result.
    def add_single(scc):
        assert scc.size() == 1
        result[0] = result[0].concat(scc)

    ids.foreach_scc(follows, add_single)
    assert result[0].size() == 3
    assert result[0].at(0).name() == "b"
    assert result[0].at(1).name() == "c"
    assert result[0].at(2).name() == "a"


# Test the functionality of "every" functions.
#
# In particular, test the generic functionality and
# test that exceptions are properly propagated.
#
def test_every():
    us = isl.union_set("{ A[i]; B[j] }")

    def is_empty(s):
        return s.is_empty()
    assert not us.every_set(is_empty)

    def is_non_empty(s):
        return not s.is_empty()
    assert us.every_set(is_non_empty)

    def in_A(s):
        return s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(in_A)

    def not_in_A(s):
        return not s.is_subset(isl.set("{ A[x] }"))
    assert not us.every_set(not_in_A)

    fail_calls = [0]

    def fail(s):
        fail_calls[0] += 1
        raise Exception("fail")

    caught = False
    try:
        # Previously misspelled as "ever_set"; the resulting AttributeError
        # was swallowed by a bare except, so propagation was never tested.
        us.every_set(fail)
    except Exception:
        caught = True
    assert caught
    # Make sure the exception really came from the callback.
    assert fail_calls[0] > 0


# Check basic construction of spaces.
#
def test_space():
    unit = isl.space.unit()
    set_space = unit.add_named_tuple("A", 3)
    map_space = set_space.add_named_tuple("B", 2)

    # Local names avoid shadowing the builtins set and map.
    universe_set = isl.set.universe(set_space)
    universe_map = isl.map.universe(map_space)
    assert universe_set.is_equal(isl.set("{ A[*,*,*] }"))
    assert universe_map.is_equal(isl.map("{ A[*,*,*] -> B[*,*] }"))


# Construct a simple schedule tree with an outer sequence node and
# a single-dimensional band node in each branch, with one of them
# marked coincident.
#
def construct_schedule_tree():
    A = isl.union_set("{ A[i] : 0 <= i < 10 }")
    B = isl.union_set("{ B[i] : 0 <= i < 20 }")

    node = isl.schedule_node.from_domain(A.union(B))
    node = node.child(0)

    filters = isl.union_set_list(A).add(B)
    node = node.insert_sequence(filters)

    # Band for the A branch, with its single member marked coincident.
    f_A = isl.multi_union_pw_aff("[ { A[i] -> [i] } ]")
    node = node.child(0)
    node = node.child(0)
    node = node.insert_partial_schedule(f_A)
    node = node.member_set_coincident(0, True)
    node = node.ancestor(2)

    # Band for the B branch.
    f_B = isl.multi_union_pw_aff("[ { B[i] -> [i] } ]")
    node = node.child(1)
    node = node.child(0)
    node = node.insert_partial_schedule(f_B)
    node = node.ancestor(2)

    return node.schedule()


# Test basic schedule tree functionality.
#
# In particular, create a simple schedule tree and
# - check that the root node is a domain node
# - test map_descendant_bottom_up
# - test foreach_descendant_top_down
# - test every_descendant
#
def test_schedule_tree():
    schedule = construct_schedule_tree()
    root = schedule.root()

    assert type(root) is isl.schedule_node_domain

    count = [0]

    def count_and_keep(node):
        count[0] += 1
        return node
    root = root.map_descendant_bottom_up(count_and_keep)
    assert count[0] == 8

    fail_map_calls = [0]

    def fail_map(node):
        fail_map_calls[0] += 1
        raise Exception("fail")

    caught = False
    try:
        root.map_descendant_bottom_up(fail_map)
    except Exception:
        caught = True
    assert caught
    # The exception must originate from the callback, not some other error.
    assert fail_map_calls[0] > 0

    count = [0]

    def count_and_descend(node):
        count[0] += 1
        return True
    root.foreach_descendant_top_down(count_and_descend)
    assert count[0] == 8

    # Returning False stops the descent, so only the root is visited.
    count = [0]

    def count_and_stop(node):
        count[0] += 1
        return False
    root.foreach_descendant_top_down(count_and_stop)
    assert count[0] == 1

    def is_not_domain(node):
        return type(node) is not isl.schedule_node_domain
    assert root.child(0).every_descendant(is_not_domain)
    assert not root.every_descendant(is_not_domain)

    fail_calls = [0]

    def fail(node):
        fail_calls[0] += 1
        raise Exception("fail")

    caught = False
    try:
        root.every_descendant(fail)
    except Exception:
        caught = True
    assert caught
    assert fail_calls[0] > 0

    # The union of all filters must cover the whole domain.
    domain = root.domain()
    filters = [isl.union_set("{}")]

    def collect_filters(node):
        if type(node) is isl.schedule_node_filter:
            filters[0] = filters[0].union(node.filter())
        return True
    root.every_descendant(collect_filters)
    assert domain.is_equal(filters[0])


# Test marking band members for unrolling.
# "schedule" is the schedule created by construct_schedule_tree.
# It schedules two statements, with 10 and 20 instances, respectively.
# Unrolling all band members therefore results in 30 at-domain calls
# by the AST generator.
#
def test_ast_build_unroll(schedule):
    root = schedule.root()

    def mark_unroll(node):
        if type(node) is isl.schedule_node_band:
            node = node.member_set_ast_loop_unroll(0)
        return node
    root = root.map_descendant_bottom_up(mark_unroll)
    schedule = root.schedule()

    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert count_ast[0] == 30


# Test basic AST generation from a schedule tree.
#
# In particular, create a simple schedule tree and
# - generate an AST from the schedule tree
# - test at_each_domain
# - test unrolling
#
def test_ast_build():
    schedule = construct_schedule_tree()

    count_ast = [0]

    def inc_count_ast(node, build):
        count_ast[0] += 1
        return node

    # Setting a callback returns a new build; the original must stay unaffected.
    build = isl.ast_build()
    build_copy = build.set_at_each_domain(inc_count_ast)
    build.node_from(schedule)
    assert count_ast[0] == 0
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    build = build_copy
    count_ast[0] = 0
    build.node_from(schedule)
    assert count_ast[0] == 2

    # The closure reads do_fail at call time, so rebinding it below
    # switches the callback from failing to succeeding.
    do_fail = True
    count_ast_fail = [0]

    def fail_inc_count_ast(node, build):
        count_ast_fail[0] += 1
        if do_fail:
            raise Exception("fail")
        return node

    build = isl.ast_build()
    build = build.set_at_each_domain(fail_inc_count_ast)
    caught = False
    try:
        build.node_from(schedule)
    except Exception:
        caught = True
    assert caught
    assert count_ast_fail[0] > 0

    # Replacing the callback on a derived build must not affect "build".
    build_copy = build.set_at_each_domain(inc_count_ast)
    count_ast[0] = 0
    build_copy.node_from(schedule)
    assert count_ast[0] == 2
    count_ast_fail[0] = 0
    do_fail = False
    build.node_from(schedule)
    assert count_ast_fail[0] == 2

    test_ast_build_unroll(schedule)


# Test basic AST expression generation from an affine expression.
#
def test_ast_build_expr():
    pa = isl.pw_aff("[n] -> { [n + 1] }")
    build = isl.ast_build.from_context(pa.domain())

    op = build.expr_from(pa)
    assert type(op) is isl.ast_expr_op_add
    assert op.n_arg() == 2


# Test the isl Python interface
#
# This includes:
#  - Object construction
#  - Different parameter types
#  - Different return types
#  - isl.id.user
#  - Foreach functions
#  - Foreach SCC function
#  - Every functions
#  - Spaces
#  - Schedule trees
#  - AST generation
#  - AST expression generation
#
test_constructors()
test_parameters()
test_return()
test_user()
test_foreach()
test_foreach_scc()
test_every()
test_space()
test_schedule_tree()
test_ast_build()
test_ast_build_expr()