Skip to content

Commit 236d5b8

Browse files
author
whitequark
committed Jul 21, 2015
Add support for Assert.
·
8.0 … 1.0rc1
1 parent 5d518dc commit 236d5b8

File tree

7 files changed

+199
-21
lines changed

7 files changed

+199
-21
lines changed
 

‎artiq/compiler/ir.py

Lines changed: 28 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -106,14 +106,13 @@ class User(NamedValue):
106106
def __init__(self, operands, typ, name):
107107
super().__init__(typ, name)
108108
self.operands = []
109-
if operands is not None:
110-
self.set_operands(operands)
109+
self.set_operands(operands)
111110

112111
def set_operands(self, new_operands):
113-
for operand in self.operands:
112+
for operand in set(self.operands):
114113
operand.uses.remove(self)
115114
self.operands = new_operands
116-
for operand in self.operands:
115+
for operand in set(self.operands):
117116
operand.uses.add(self)
118117

119118
def drop_references(self):
@@ -162,6 +161,9 @@ def remove_from_parent(self):
162161
def erase(self):
163162
self.remove_from_parent()
164163
self.drop_references()
164+
# Check this after drop_references in case this
165+
# is a self-referencing phi.
166+
assert not any(self.uses)
165167

166168
def replace_with(self, value):
167169
self.replace_all_uses_with(value)
@@ -220,7 +222,21 @@ def incoming_value_for_block(self, target_block):
220222
def add_incoming(self, value, block):
221223
assert value.type == self.type
222224
self.operands.append(value)
225+
value.uses.add(self)
223226
self.operands.append(block)
227+
block.uses.add(self)
228+
229+
def remove_incoming_value(self, value):
230+
index = self.operands.index(value)
231+
self.operands[index].uses.remove(self)
232+
self.operands[index + 1].uses.remove(self)
233+
del self.operands[index:index + 2]
234+
235+
def remove_incoming_block(self, block):
236+
index = self.operands.index(block)
237+
self.operands[index - 1].uses.remove(self)
238+
self.operands[index].uses.remove(self)
239+
del self.operands[index - 1:index + 1]
224240

225241
def __str__(self):
226242
if builtins.is_none(self.type):
@@ -268,9 +284,13 @@ def remove_from_parent(self):
268284
self.function.remove(self)
269285

270286
def erase(self):
271-
for insn in self.instructions:
287+
# self.instructions is updated while iterating
288+
for insn in list(self.instructions):
272289
insn.erase()
273290
self.remove_from_parent()
291+
# Check this after erasing instructions in case the block
292+
# loops into itself.
293+
assert not any(self.uses)
274294

275295
def prepend(self, insn):
276296
assert isinstance(insn, Instruction)
@@ -817,6 +837,7 @@ class Select(Instruction):
817837
"""
818838
def __init__(self, cond, if_true, if_false, name=""):
819839
assert isinstance(cond, Value)
840+
assert builtins.is_bool(cond.type)
820841
assert isinstance(if_true, Value)
821842
assert isinstance(if_false, Value)
822843
assert if_true.type == if_false.type
@@ -864,8 +885,10 @@ class BranchIf(Terminator):
864885
"""
865886
def __init__(self, cond, if_true, if_false, name=""):
866887
assert isinstance(cond, Value)
888+
assert builtins.is_bool(cond.type)
867889
assert isinstance(if_true, BasicBlock)
868890
assert isinstance(if_false, BasicBlock)
891+
assert if_true != if_false # use Branch instead
869892
super().__init__([cond, if_true, if_false], builtins.TNone(), name)
870893

871894
def opcode(self):

‎artiq/compiler/transforms/artiq_ir_generator.py

Lines changed: 106 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -39,16 +39,22 @@ class ARTIQIRGenerator(algorithm.Visitor):
3939
set of variables that will be resolved in global scope
4040
:ivar current_block: (:class:`ir.BasicBlock`)
4141
basic block to which any new instruction will be appended
42-
:ivar current_env: (:class:`ir.Environment`)
42+
:ivar current_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
4343
the chained function environment, containing variables that
4444
can become upvalues
45-
:ivar current_private_env: (:class:`ir.Environment`)
45+
:ivar current_private_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
4646
the private function environment, containing internal state
4747
:ivar current_assign: (:class:`ir.Value` or None)
4848
the right-hand side of current assignment statement, or
4949
a component of a composite right-hand side when visiting
5050
a composite left-hand side, such as, in ``x, y = z``,
5151
the 2nd tuple element when visting ``y``
52+
:ivar current_assert_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
53+
the environment where the individual components of current assert
54+
statement are stored until display
55+
:ivar current_assert_subexprs: (list of (:class:`ast.AST`, string))
56+
the mapping from components of current assert statement to the names
57+
their values have in :ivar:`current_assert_env`
5258
:ivar break_target: (:class:`ir.BasicBlock` or None)
5359
the basic block to which ``break`` will transfer control
5460
:ivar continue_target: (:class:`ir.BasicBlock` or None)
@@ -72,6 +78,8 @@ def __init__(self, module_name, engine):
7278
self.current_env = None
7379
self.current_private_env = None
7480
self.current_assign = None
81+
self.current_assert_env = None
82+
self.current_assert_subexprs = None
7583
self.break_target = None
7684
self.continue_target = None
7785
self.return_target = None
@@ -203,7 +211,7 @@ def visit_function(self, node, is_lambda):
203211
self.append(ir.SetLocal(env, arg_name, args[index]))
204212
for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)):
205213
default = self.append(ir.GetLocal(self.current_env, env_default_name))
206-
value = self.append(ir.Builtin("unwrap", [optargs[index], default],
214+
value = self.append(ir.Builtin("unwrap_or", [optargs[index], default],
207215
typ.optargs[arg_name]))
208216
self.append(ir.SetLocal(env, arg_name, value))
209217

@@ -736,7 +744,7 @@ def visit_TupleT(self, node):
736744
for index, elt_node in enumerate(node.elts):
737745
self.current_assign = \
738746
self.append(ir.GetAttr(old_assign, index,
739-
name="{}.{}".format(old_assign.name, index)),
747+
name="{}.e{}".format(old_assign.name, index)),
740748
loc=elt_node.loc)
741749
self.visit(elt_node)
742750
finally:
@@ -805,18 +813,26 @@ def body_gen(index):
805813
def visit_BoolOpT(self, node):
806814
blocks = []
807815
for value_node in node.values:
816+
value_head = self.current_block
808817
value = self.visit(value_node)
809-
blocks.append((value, self.current_block))
818+
self.instrument_assert(value_node, value)
819+
value_tail = self.current_block
820+
821+
blocks.append((value, value_head, value_tail))
810822
self.current_block = self.add_block()
811823

812824
tail = self.current_block
813825
phi = self.append(ir.Phi(node.type))
814-
for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]):
815-
phi.add_incoming(value, block)
816-
if isinstance(node.op, ast.And):
817-
block.append(ir.BranchIf(value, next_block, tail))
826+
for ((value, value_head, value_tail), (next_value_head, next_value_tail)) in \
827+
zip(blocks, [(h,t) for (v,h,t) in blocks[1:]] + [(tail, tail)]):
828+
phi.add_incoming(value, value_tail)
829+
if next_value_head != tail:
830+
if isinstance(node.op, ast.And):
831+
value_tail.append(ir.BranchIf(value, next_value_head, tail))
832+
else:
833+
value_tail.append(ir.BranchIf(value, tail, next_value_head))
818834
else:
819-
block.append(ir.BranchIf(value, tail, next_block))
835+
value_tail.append(ir.Branch(tail))
820836
return phi
821837

822838
def visit_UnaryOpT(self, node):
@@ -1005,7 +1021,7 @@ def polymorphic_compare_pair_inclusion(self, op, needle, haystack):
10051021
ir.Constant(False, builtins.TBool())))
10061022
result = self.append(ir.Select(result, on_step,
10071023
ir.Constant(False, builtins.TBool())))
1008-
elif builtins.isiterable(haystack.type):
1024+
elif builtins.is_iterable(haystack.type):
10091025
length = self.iterable_len(haystack)
10101026

10111027
cmp_result = loop_body2 = None
@@ -1068,8 +1084,10 @@ def visit_CompareT(self, node):
10681084
# of comparisons.
10691085
blocks = []
10701086
lhs = self.visit(node.left)
1087+
self.instrument_assert(node.left, lhs)
10711088
for op, rhs_node in zip(node.ops, node.comparators):
10721089
rhs = self.visit(rhs_node)
1090+
self.instrument_assert(rhs_node, rhs)
10731091
result = self.polymorphic_compare_pair(op, lhs, rhs)
10741092
blocks.append((result, self.current_block))
10751093
self.current_block = self.add_block()
@@ -1079,7 +1097,10 @@ def visit_CompareT(self, node):
10791097
phi = self.append(ir.Phi(node.type))
10801098
for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]):
10811099
phi.add_incoming(value, block)
1082-
block.append(ir.BranchIf(value, next_block, tail))
1100+
if next_block != tail:
1101+
block.append(ir.BranchIf(value, next_block, tail))
1102+
else:
1103+
block.append(ir.Branch(tail))
10831104
return phi
10841105

10851106
def visit_builtin_call(self, node):
@@ -1211,6 +1232,79 @@ def visit_CallT(self, node):
12111232
self.current_block = after_invoke
12121233
return invoke
12131234

1235+
def instrument_assert(self, node, value):
1236+
if self.current_assert_env is not None:
1237+
if isinstance(value, ir.Constant):
1238+
return # don't display the values of constants
1239+
1240+
if any([algorithm.compare(node, subexpr)
1241+
for (subexpr, name) in self.current_assert_subexprs]):
1242+
return # don't display the same subexpression twice
1243+
1244+
name = self.current_assert_env.type.add("subexpr", ir.TOption(node.type))
1245+
value_opt = self.append(ir.Alloc([value], ir.TOption(node.type)),
1246+
loc=node.loc)
1247+
self.append(ir.SetLocal(self.current_assert_env, name, value_opt),
1248+
loc=node.loc)
1249+
self.current_assert_subexprs.append((node, name))
1250+
1251+
def visit_Assert(self, node):
1252+
try:
1253+
assert_env = self.current_assert_env = \
1254+
self.append(ir.Alloc([], ir.TEnvironment({}), name="assertenv"))
1255+
assert_subexprs = self.current_assert_subexprs = []
1256+
init = self.current_block
1257+
1258+
prehead = self.current_block = self.add_block()
1259+
cond = self.visit(node.test)
1260+
head = self.current_block
1261+
finally:
1262+
self.current_assert_env = None
1263+
self.current_assert_subexprs = None
1264+
1265+
for subexpr_node, subexpr_name in assert_subexprs:
1266+
empty = init.append(ir.Alloc([], ir.TOption(subexpr_node.type)))
1267+
init.append(ir.SetLocal(assert_env, subexpr_name, empty))
1268+
init.append(ir.Branch(prehead))
1269+
1270+
if_failed = self.current_block = self.add_block()
1271+
1272+
if node.msg:
1273+
explanation = node.msg.s
1274+
else:
1275+
explanation = node.loc.source()
1276+
self.append(ir.Builtin("printf", [
1277+
ir.Constant("assertion failed at %s: %s\n", builtins.TStr()),
1278+
ir.Constant(str(node.loc.begin()), builtins.TStr()),
1279+
ir.Constant(str(explanation), builtins.TStr()),
1280+
], builtins.TNone()))
1281+
1282+
for subexpr_node, subexpr_name in assert_subexprs:
1283+
subexpr_head = self.current_block
1284+
subexpr_value_opt = self.append(ir.GetLocal(assert_env, subexpr_name))
1285+
subexpr_cond = self.append(ir.Builtin("is_some", [subexpr_value_opt],
1286+
builtins.TBool()))
1287+
1288+
subexpr_body = self.current_block = self.add_block()
1289+
self.append(ir.Builtin("printf", [
1290+
ir.Constant(" (%s) = ", builtins.TStr()),
1291+
ir.Constant(subexpr_node.loc.source(), builtins.TStr())
1292+
], builtins.TNone()))
1293+
subexpr_value = self.append(ir.Builtin("unwrap", [subexpr_value_opt],
1294+
subexpr_node.type))
1295+
self.polymorphic_print([subexpr_value], separator="", suffix="\n")
1296+
subexpr_postbody = self.current_block
1297+
1298+
subexpr_tail = self.current_block = self.add_block()
1299+
self.append(ir.Branch(subexpr_tail), block=subexpr_postbody)
1300+
self.append(ir.BranchIf(subexpr_cond, subexpr_body, subexpr_tail), block=subexpr_head)
1301+
1302+
self.append(ir.Builtin("abort", [], builtins.TNone()))
1303+
self.append(ir.Unreachable())
1304+
1305+
tail = self.current_block = self.add_block()
1306+
self.append(ir.BranchIf(cond, tail, if_failed), block=head)
1307+
12141308
def polymorphic_print(self, values, separator, suffix=""):
12151309
format_string = ""
12161310
args = []

‎artiq/compiler/transforms/asttyped_rewriter.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -421,7 +421,6 @@ def visit_unsupported(self, node):
421421
visit_YieldFrom = visit_unsupported
422422

423423
# stmt
424-
visit_Assert = visit_unsupported
425424
visit_ClassDef = visit_unsupported
426425
visit_Delete = visit_unsupported
427426
visit_Import = visit_unsupported

‎artiq/compiler/transforms/dead_code_eliminator.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,4 +16,25 @@ def process(self, functions):
1616
def process_function(self, func):
1717
for block in func.basic_blocks:
1818
if not any(block.predecessors()) and block != func.entry():
19-
block.erase()
19+
self.remove_block(block)
20+
21+
def remove_block(self, block):
22+
# block.uses are updated while iterating
23+
for use in set(block.uses):
24+
if isinstance(use, ir.Phi):
25+
use.remove_incoming_block(block)
26+
if not any(use.operands):
27+
self.remove_instruction(use)
28+
else:
29+
assert False
30+
31+
block.erase()
32+
33+
def remove_instruction(self, insn):
34+
for use in set(insn.uses):
35+
if isinstance(use, ir.Phi):
36+
use.remove_incoming_value(insn)
37+
if not any(use.operands):
38+
self.remove_instruction(use)
39+
40+
insn.erase()

‎artiq/compiler/transforms/inferencer.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -947,3 +947,14 @@ def makenotes(printer, typea, typeb, loca, locb):
947947
else:
948948
self._unify(self.function.return_type, node.value.type,
949949
self.function.name_loc, node.value.loc, makenotes)
950+
951+
def visit_Assert(self, node):
952+
self.generic_visit(node)
953+
self._unify(node.test.type, builtins.TBool(),
954+
node.test.loc, None)
955+
if node.msg is not None:
956+
if not isinstance(node.msg, asttyped.StrT):
957+
diag = diagnostic.Diagnostic("error",
958+
"assertion message must be a string literal", {},
959+
node.msg.loc)
960+
self.engine.process(diag)

‎artiq/compiler/transforms/llvm_ir_generator.py

Lines changed: 26 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,9 @@ def llbuiltin(self, name):
9595
if llfun is not None:
9696
return llfun
9797

98-
if name in ("llvm.abort", "llvm.donothing"):
98+
if name in "llvm.donothing":
99+
llty = ll.FunctionType(ll.VoidType(), [])
100+
elif name in "llvm.trap":
99101
llty = ll.FunctionType(ll.VoidType(), [])
100102
elif name == "llvm.round.f64":
101103
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
@@ -181,6 +183,18 @@ def process_Alloc(self, insn):
181183
if ir.is_environment(insn.type):
182184
return self.llbuilder.alloca(self.llty_of_type(insn.type, bare=True),
183185
name=insn.name)
186+
elif ir.is_option(insn.type):
187+
if len(insn.operands) == 0: # empty
188+
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
189+
return self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), False), 0,
190+
name=insn.name)
191+
elif len(insn.operands) == 1: # full
192+
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
193+
llvalue = self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), True), 0)
194+
return self.llbuilder.insert_value(llvalue, self.map(insn.operands[0]), 1,
195+
name=insn.name)
196+
else:
197+
assert False
184198
elif builtins.is_list(insn.type):
185199
llsize = self.map(insn.operands[0])
186200
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
@@ -382,7 +396,17 @@ def process_Compare(self, insn):
382396
def process_Builtin(self, insn):
383397
if insn.op == "nop":
384398
return self.llbuilder.call(self.llbuiltin("llvm.donothing"), [])
399+
if insn.op == "abort":
400+
return self.llbuilder.call(self.llbuiltin("llvm.trap"), [])
401+
elif insn.op == "is_some":
402+
optarg = self.map(insn.operands[0])
403+
return self.llbuilder.extract_value(optarg, 0,
404+
name=insn.name)
385405
elif insn.op == "unwrap":
406+
optarg = self.map(insn.operands[0])
407+
return self.llbuilder.extract_value(optarg, 1,
408+
name=insn.name)
409+
elif insn.op == "unwrap_or":
386410
optarg, default = map(self.map, insn.operands)
387411
has_arg = self.llbuilder.extract_value(optarg, 0)
388412
arg = self.llbuilder.extract_value(optarg, 1)
@@ -455,7 +479,7 @@ def process_Unreachable(self, insn):
455479

456480
def process_Raise(self, insn):
457481
# TODO: hack before EH is working
458-
llinsn = self.llbuilder.call(self.llbuiltin("llvm.abort"), [],
482+
llinsn = self.llbuilder.call(self.llbuiltin("llvm.trap"), [],
459483
name=insn.name)
460484
self.llbuilder.unreachable()
461485
return llinsn
Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
2+
# RUN: OutputCheck %s --file-to-check=%t
3+
4+
x = "A"
5+
# CHECK-L: ${LINE:+1}: error: assertion message must be a string literal
6+
assert True, x

0 commit comments

Comments
 (0)
Please sign in to comment.