Skip to content

Commit

Permalink
Add support for Assert.
Browse files Browse the repository at this point in the history
whitequark committed Jul 21, 2015
1 parent 5d518dc commit 236d5b8
Showing 7 changed files with 199 additions and 21 deletions.
33 changes: 28 additions & 5 deletions artiq/compiler/ir.py
Original file line number Diff line number Diff line change
@@ -106,14 +106,13 @@ class User(NamedValue):
def __init__(self, operands, typ, name):
super().__init__(typ, name)
self.operands = []
if operands is not None:
self.set_operands(operands)
self.set_operands(operands)

def set_operands(self, new_operands):
for operand in self.operands:
for operand in set(self.operands):
operand.uses.remove(self)
self.operands = new_operands
for operand in self.operands:
for operand in set(self.operands):
operand.uses.add(self)

def drop_references(self):
@@ -162,6 +161,9 @@ def remove_from_parent(self):
# Detach this instruction from its parent (remove_from_parent) and release
# every operand reference (drop_references); afterwards nothing may still
# use this instruction, which is asserted below.
def erase(self):
self.remove_from_parent()
self.drop_references()
# Check this after drop_references in case this
# is a self-referencing phi.
assert not any(self.uses)

def replace_with(self, value):
self.replace_all_uses_with(value)
@@ -220,7 +222,21 @@ def incoming_value_for_block(self, target_block):
# Append an incoming (value, block) pair to this phi. The value is appended
# first and the block immediately after it, and the phi registers itself in
# the `uses` of both so the back-references stay in sync with `operands`.
# The incoming value must have exactly the phi's type.
def add_incoming(self, value, block):
assert value.type == self.type
self.operands.append(value)
value.uses.add(self)
self.operands.append(block)
block.uses.add(self)

# Remove the first incoming pair whose value is `value`, together with the
# block stored right after it (add_incoming appends value, then block), and
# unregister this phi from the `uses` of both.
# NOTE(review): if the same value is incoming from more than one block,
# removing this phi from value.uses here may leave the phi missing from
# `uses` while another pair still references the value (assuming `uses` is a
# set, as the .add/.remove calls suggest) - confirm whether duplicates occur.
def remove_incoming_value(self, value):
index = self.operands.index(value)
self.operands[index].uses.remove(self)
self.operands[index + 1].uses.remove(self)
del self.operands[index:index + 2]

# Remove the incoming pair whose block is `block`: the value stored right
# before it (add_incoming appends value, then block) and the block itself,
# unregistering this phi from the `uses` of both.
# NOTE(review): if the paired value is also incoming from another block,
# removing this phi from its `uses` may desynchronize `uses` from the
# remaining operands (assuming `uses` is a set) - confirm.
def remove_incoming_block(self, block):
index = self.operands.index(block)
self.operands[index - 1].uses.remove(self)
self.operands[index].uses.remove(self)
del self.operands[index - 1:index + 1]

def __str__(self):
if builtins.is_none(self.type):
@@ -268,9 +284,13 @@ def remove_from_parent(self):
self.function.remove(self)

def erase(self):
for insn in self.instructions:
# self.instructions is updated while iterating
for insn in list(self.instructions):
insn.erase()
self.remove_from_parent()
# Check this after erasing instructions in case the block
# loops into itself.
assert not any(self.uses)

def prepend(self, insn):
assert isinstance(insn, Instruction)
@@ -817,6 +837,7 @@ class Select(Instruction):
"""
def __init__(self, cond, if_true, if_false, name=""):
assert isinstance(cond, Value)
assert builtins.is_bool(cond.type)
assert isinstance(if_true, Value)
assert isinstance(if_false, Value)
assert if_true.type == if_false.type
@@ -864,8 +885,10 @@ class BranchIf(Terminator):
"""
# Conditional branch terminator with operands [cond, if_true, if_false].
# `cond` must be a bool-typed value and both targets must be basic blocks;
# the targets must differ (an unconditional Branch is used otherwise).
# The instruction itself produces no value (type TNone).
def __init__(self, cond, if_true, if_false, name=""):
assert isinstance(cond, Value)
assert builtins.is_bool(cond.type)
assert isinstance(if_true, BasicBlock)
assert isinstance(if_false, BasicBlock)
assert if_true != if_false # use Branch instead
super().__init__([cond, if_true, if_false], builtins.TNone(), name)

def opcode(self):
118 changes: 106 additions & 12 deletions artiq/compiler/transforms/artiq_ir_generator.py
Original file line number Diff line number Diff line change
@@ -39,16 +39,22 @@ class ARTIQIRGenerator(algorithm.Visitor):
set of variables that will be resolved in global scope
:ivar current_block: (:class:`ir.BasicBlock`)
basic block to which any new instruction will be appended
:ivar current_env: (:class:`ir.Environment`)
:ivar current_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
the chained function environment, containing variables that
can become upvalues
:ivar current_private_env: (:class:`ir.Environment`)
:ivar current_private_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
the private function environment, containing internal state
:ivar current_assign: (:class:`ir.Value` or None)
the right-hand side of current assignment statement, or
a component of a composite right-hand side when visiting
a composite left-hand side, such as, in ``x, y = z``,
the 2nd tuple element when visting ``y``
:ivar current_assert_env: (:class:`ir.Alloc` of type :class:`ir.TEnvironment`)
the environment where the individual components of current assert
statement are stored until display
:ivar current_assert_subexprs: (list of (:class:`ast.AST`, string))
the mapping from components of current assert statement to the names
their values have in :ivar:`current_assert_env`
:ivar break_target: (:class:`ir.BasicBlock` or None)
the basic block to which ``break`` will transfer control
:ivar continue_target: (:class:`ir.BasicBlock` or None)
@@ -72,6 +78,8 @@ def __init__(self, module_name, engine):
self.current_env = None
self.current_private_env = None
self.current_assign = None
self.current_assert_env = None
self.current_assert_subexprs = None
self.break_target = None
self.continue_target = None
self.return_target = None
@@ -203,7 +211,7 @@ def visit_function(self, node, is_lambda):
self.append(ir.SetLocal(env, arg_name, args[index]))
for index, (arg_name, env_default_name) in enumerate(zip(typ.optargs, defaults)):
default = self.append(ir.GetLocal(self.current_env, env_default_name))
value = self.append(ir.Builtin("unwrap", [optargs[index], default],
value = self.append(ir.Builtin("unwrap_or", [optargs[index], default],
typ.optargs[arg_name]))
self.append(ir.SetLocal(env, arg_name, value))

@@ -736,7 +744,7 @@ def visit_TupleT(self, node):
for index, elt_node in enumerate(node.elts):
self.current_assign = \
self.append(ir.GetAttr(old_assign, index,
name="{}.{}".format(old_assign.name, index)),
name="{}.e{}".format(old_assign.name, index)),
loc=elt_node.loc)
self.visit(elt_node)
finally:
@@ -805,18 +813,26 @@ def body_gen(index):
def visit_BoolOpT(self, node):
blocks = []
for value_node in node.values:
value_head = self.current_block
value = self.visit(value_node)
blocks.append((value, self.current_block))
self.instrument_assert(value_node, value)
value_tail = self.current_block

blocks.append((value, value_head, value_tail))
self.current_block = self.add_block()

tail = self.current_block
phi = self.append(ir.Phi(node.type))
for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]):
phi.add_incoming(value, block)
if isinstance(node.op, ast.And):
block.append(ir.BranchIf(value, next_block, tail))
for ((value, value_head, value_tail), (next_value_head, next_value_tail)) in \
zip(blocks, [(h,t) for (v,h,t) in blocks[1:]] + [(tail, tail)]):
phi.add_incoming(value, value_tail)
if next_value_head != tail:
if isinstance(node.op, ast.And):
value_tail.append(ir.BranchIf(value, next_value_head, tail))
else:
value_tail.append(ir.BranchIf(value, tail, next_value_head))
else:
block.append(ir.BranchIf(value, tail, next_block))
value_tail.append(ir.Branch(tail))
return phi

def visit_UnaryOpT(self, node):
@@ -1005,7 +1021,7 @@ def polymorphic_compare_pair_inclusion(self, op, needle, haystack):
ir.Constant(False, builtins.TBool())))
result = self.append(ir.Select(result, on_step,
ir.Constant(False, builtins.TBool())))
elif builtins.isiterable(haystack.type):
elif builtins.is_iterable(haystack.type):
length = self.iterable_len(haystack)

cmp_result = loop_body2 = None
@@ -1068,8 +1084,10 @@ def visit_CompareT(self, node):
# of comparisons.
blocks = []
lhs = self.visit(node.left)
self.instrument_assert(node.left, lhs)
for op, rhs_node in zip(node.ops, node.comparators):
rhs = self.visit(rhs_node)
self.instrument_assert(rhs_node, rhs)
result = self.polymorphic_compare_pair(op, lhs, rhs)
blocks.append((result, self.current_block))
self.current_block = self.add_block()
@@ -1079,7 +1097,10 @@ def visit_CompareT(self, node):
phi = self.append(ir.Phi(node.type))
for ((value, block), next_block) in zip(blocks, [b for (v,b) in blocks[1:]] + [tail]):
phi.add_incoming(value, block)
block.append(ir.BranchIf(value, next_block, tail))
if next_block != tail:
block.append(ir.BranchIf(value, next_block, tail))
else:
block.append(ir.Branch(tail))
return phi

def visit_builtin_call(self, node):
@@ -1211,6 +1232,79 @@ def visit_CallT(self, node):
self.current_block = after_invoke
return invoke

# While generating code for an assert statement (current_assert_env is set),
# record the runtime value of subexpression `node` so it can be printed if the
# assertion fails. The value is wrapped in an option and stored into a new
# "subexpr" slot of the assert environment; (node, slot name) is remembered in
# current_assert_subexprs. Outside an assert this is a no-op.
def instrument_assert(self, node, value):
if self.current_assert_env is not None:
if isinstance(value, ir.Constant):
return # don't display the values of constants

if any([algorithm.compare(node, subexpr)
for (subexpr, name) in self.current_assert_subexprs]):
return # don't display the same subexpression twice

# NOTE(review): TEnvironment.add presumably returns a fresh, unique slot
# name derived from "subexpr" - confirm.
name = self.current_assert_env.type.add("subexpr", ir.TOption(node.type))
value_opt = self.append(ir.Alloc([value], ir.TOption(node.type)),
loc=node.loc)
self.append(ir.SetLocal(self.current_assert_env, name, value_opt),
loc=node.loc)
self.current_assert_subexprs.append((node, name))

# Lower `assert test[, msg]` into IR:
#  - init: allocates the assert environment; after the test has been
#    generated, every recorded subexpression slot is set to an empty option
#    and control branches to prehead.
#  - prehead..head: evaluates the test; instrument_assert stores the values of
#    subexpressions as they are computed.
#  - if_failed: prints the location and the message (or the assert's source
#    text), then prints "(source) = value" for each recorded subexpression
#    whose slot was filled (is_some), then aborts.
#  - tail: continuation when the test is true.
def visit_Assert(self, node):
try:
assert_env = self.current_assert_env = \
self.append(ir.Alloc([], ir.TEnvironment({}), name="assertenv"))
assert_subexprs = self.current_assert_subexprs = []
init = self.current_block

prehead = self.current_block = self.add_block()
cond = self.visit(node.test)
head = self.current_block
finally:
# Reset so instrument_assert stops recording outside this assert.
self.current_assert_env = None
self.current_assert_subexprs = None

# The set of slots is only known after visiting the test, so the
# initialization in `init` is emitted retroactively. A slot stays empty if
# the code that stores it never ran (e.g. a short-circuited operand).
for subexpr_node, subexpr_name in assert_subexprs:
empty = init.append(ir.Alloc([], ir.TOption(subexpr_node.type)))
init.append(ir.SetLocal(assert_env, subexpr_name, empty))
init.append(ir.Branch(prehead))

if_failed = self.current_block = self.add_block()

if node.msg:
explanation = node.msg.s
else:
explanation = node.loc.source()
self.append(ir.Builtin("printf", [
ir.Constant("assertion failed at %s: %s\n", builtins.TStr()),
ir.Constant(str(node.loc.begin()), builtins.TStr()),
ir.Constant(str(explanation), builtins.TStr()),
], builtins.TNone()))

# For each slot: subexpr_head tests is_some, subexpr_body prints the value,
# subexpr_tail rejoins; the branches are appended once both blocks exist.
for subexpr_node, subexpr_name in assert_subexprs:
subexpr_head = self.current_block
subexpr_value_opt = self.append(ir.GetLocal(assert_env, subexpr_name))
subexpr_cond = self.append(ir.Builtin("is_some", [subexpr_value_opt],
builtins.TBool()))

subexpr_body = self.current_block = self.add_block()
self.append(ir.Builtin("printf", [
ir.Constant(" (%s) = ", builtins.TStr()),
ir.Constant(subexpr_node.loc.source(), builtins.TStr())
], builtins.TNone()))
subexpr_value = self.append(ir.Builtin("unwrap", [subexpr_value_opt],
subexpr_node.type))
self.polymorphic_print([subexpr_value], separator="", suffix="\n")
subexpr_postbody = self.current_block

subexpr_tail = self.current_block = self.add_block()
self.append(ir.Branch(subexpr_tail), block=subexpr_postbody)
self.append(ir.BranchIf(subexpr_cond, subexpr_body, subexpr_tail), block=subexpr_head)

self.append(ir.Builtin("abort", [], builtins.TNone()))
self.append(ir.Unreachable())

tail = self.current_block = self.add_block()
self.append(ir.BranchIf(cond, tail, if_failed), block=head)

def polymorphic_print(self, values, separator, suffix=""):
format_string = ""
args = []
1 change: 0 additions & 1 deletion artiq/compiler/transforms/asttyped_rewriter.py
Original file line number Diff line number Diff line change
@@ -421,7 +421,6 @@ def visit_unsupported(self, node):
visit_YieldFrom = visit_unsupported

# stmt
visit_Assert = visit_unsupported
visit_ClassDef = visit_unsupported
visit_Delete = visit_unsupported
visit_Import = visit_unsupported
23 changes: 22 additions & 1 deletion artiq/compiler/transforms/dead_code_eliminator.py
Original file line number Diff line number Diff line change
@@ -16,4 +16,25 @@ def process(self, functions):
# Remove every block (other than the entry block) that has no predecessors.
# Diff residue: in this commit the `block.erase()` line below is the deleted
# line and `self.remove_block(block)` is the line that replaces it.
# NOTE(review): blocks are erased while iterating func.basic_blocks;
# BasicBlock.erase calls remove_from_parent -> function.remove(self). If that
# mutates func.basic_blocks, entries may be skipped - consider iterating over
# list(func.basic_blocks). Confirm against Function.remove.
def process_function(self, func):
for block in func.basic_blocks:
if not any(block.predecessors()) and block != func.entry():
block.erase()
self.remove_block(block)

# Erase an unreachable block. Every phi that lists the block as an incoming
# block loses that incoming pair, and a phi left with no operands is removed
# via remove_instruction. Any other kind of user is unexpected (assert False).
def remove_block(self, block):
# block.uses are updated while iterating
for use in set(block.uses):
if isinstance(use, ir.Phi):
use.remove_incoming_block(block)
if not any(use.operands):
self.remove_instruction(use)
else:
assert False

block.erase()

# Erase `insn`. Any phi using it as an incoming value first drops that
# incoming pair, and phis left with no operands are removed recursively.
# Iterates over a copy of insn.uses because the phis update it.
# NOTE(review): non-phi users are not detached here; insn.erase() is
# presumably Instruction.erase, which asserts that no uses remain.
def remove_instruction(self, insn):
for use in set(insn.uses):
if isinstance(use, ir.Phi):
use.remove_incoming_value(insn)
if not any(use.operands):
self.remove_instruction(use)

insn.erase()
11 changes: 11 additions & 0 deletions artiq/compiler/transforms/inferencer.py
Original file line number Diff line number Diff line change
@@ -947,3 +947,14 @@ def makenotes(printer, typea, typeb, loca, locb):
else:
self._unify(self.function.return_type, node.value.type,
self.function.name_loc, node.value.loc, makenotes)

# Type-check an assert statement: visit the children, require the test to
# unify with bool, and report an error diagnostic if a message is given but
# is not a string literal (asserts only accept literal messages).
def visit_Assert(self, node):
self.generic_visit(node)
self._unify(node.test.type, builtins.TBool(),
node.test.loc, None)
if node.msg is not None:
if not isinstance(node.msg, asttyped.StrT):
diag = diagnostic.Diagnostic("error",
"assertion message must be a string literal", {},
node.msg.loc)
self.engine.process(diag)
28 changes: 26 additions & 2 deletions artiq/compiler/transforms/llvm_ir_generator.py
Original file line number Diff line number Diff line change
@@ -95,7 +95,9 @@ def llbuiltin(self, name):
if llfun is not None:
return llfun

if name in ("llvm.abort", "llvm.donothing"):
if name in "llvm.donothing":
llty = ll.FunctionType(ll.VoidType(), [])
elif name in "llvm.trap":
llty = ll.FunctionType(ll.VoidType(), [])
elif name == "llvm.round.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType()])
@@ -181,6 +183,18 @@ def process_Alloc(self, insn):
if ir.is_environment(insn.type):
return self.llbuilder.alloca(self.llty_of_type(insn.type, bare=True),
name=insn.name)
elif ir.is_option(insn.type):
if len(insn.operands) == 0: # empty
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
return self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), False), 0,
name=insn.name)
elif len(insn.operands) == 1: # full
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
llvalue = self.llbuilder.insert_value(llvalue, ll.Constant(ll.IntType(1), True), 0)
return self.llbuilder.insert_value(llvalue, self.map(insn.operands[0]), 1,
name=insn.name)
else:
assert False
elif builtins.is_list(insn.type):
llsize = self.map(insn.operands[0])
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
@@ -382,7 +396,17 @@ def process_Compare(self, insn):
def process_Builtin(self, insn):
if insn.op == "nop":
return self.llbuilder.call(self.llbuiltin("llvm.donothing"), [])
if insn.op == "abort":
return self.llbuilder.call(self.llbuiltin("llvm.trap"), [])
elif insn.op == "is_some":
optarg = self.map(insn.operands[0])
return self.llbuilder.extract_value(optarg, 0,
name=insn.name)
elif insn.op == "unwrap":
optarg = self.map(insn.operands[0])
return self.llbuilder.extract_value(optarg, 1,
name=insn.name)
elif insn.op == "unwrap_or":
optarg, default = map(self.map, insn.operands)
has_arg = self.llbuilder.extract_value(optarg, 0)
arg = self.llbuilder.extract_value(optarg, 1)
@@ -455,7 +479,7 @@ def process_Unreachable(self, insn):

def process_Raise(self, insn):
# TODO: hack before EH is working
llinsn = self.llbuilder.call(self.llbuiltin("llvm.abort"), [],
llinsn = self.llbuilder.call(self.llbuiltin("llvm.trap"), [],
name=insn.name)
self.llbuilder.unreachable()
return llinsn
6 changes: 6 additions & 0 deletions lit-test/compiler/inferencer/error_assert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t
# RUN: OutputCheck %s --file-to-check=%t

# The assert message below is a variable rather than a string literal,
# so the inferencer is expected to reject it.
x = "A"
# CHECK-L: ${LINE:+1}: error: assertion message must be a string literal
assert True, x

0 comments on commit 236d5b8

Please sign in to comment.