Skip to content

Commit

Permalink
transforms.artiq_ir_generator: add support for user-defined context managers.
Browse files Browse the repository at this point in the history
whitequark committed Jan 5, 2016
1 parent d633c8e commit 6a6d7da
Showing 2 changed files with 173 additions and 104 deletions.
244 changes: 140 additions & 104 deletions artiq/compiler/transforms/artiq_ir_generator.py
Original file line number Diff line number Diff line change
@@ -555,7 +555,7 @@ def visit_Break(self, node):
def visit_Continue(self, node):
self.append(ir.Branch(self.continue_target))

def raise_exn(self, exn, loc=None):
def raise_exn(self, exn=None, loc=None):
if self.final_branch is not None:
raise_proxy = self.add_block("try.raise")
self.final_branch(raise_proxy, self.current_block)
@@ -718,19 +718,34 @@ def final_branch(target, block):
if not post_handler.is_terminated():
post_handler.append(ir.Branch(tail))

def visit_With(self, node):
if len(node.items) != 1:
diag = diagnostic.Diagnostic("fatal",
"only one expression per 'with' statement is supported",
{"type": types.TypePrinter().name(typ)},
node.context_expr.loc)
self.engine.process(diag)
def _try_finally(self, body_gen, finally_gen, name):
    """Emit the code from ``body_gen`` guarded by cleanup code from ``finally_gen``.

    The cleanup code is emitted twice:

    * inline, right after the body, when the body falls through;
    * in a ``<name>.cleanup`` block. A landing pad in the
      ``<name>.dispatch`` block leads there, and the block then
      re-raises through ``raise_exn()``.

    While the body is emitted, ``self.unwind_target`` points at the
    dispatch block. It is restored before either copy of the cleanup is
    emitted, so anything raised by the cleanup itself unwinds to the outer
    target.

    :param body_gen: callable that emits the protected code.
    :param finally_gen: callable that emits the cleanup code; it is
        called once per emitted copy.
    :param name: prefix for the names of the blocks created here.

    NOTE(review): if the body leaves the current block terminated (for
    instance with a branch emitted by ``continue``), the inline copy is
    not emitted, so the cleanup does not run on that path.
    """
    landing_block = self.add_block("{}.dispatch".format(name))

    saved_unwind_target = self.unwind_target
    self.unwind_target = landing_block
    try:
        body_gen()
    finally:
        self.unwind_target = saved_unwind_target

    # Fall-through path: run the cleanup inline.
    if not self.current_block.is_terminated():
        finally_gen()
    self.post_body = self.current_block

    # Exceptional path: landing pad -> cleanup -> re-raise.
    cleanup_block = self.add_block("{}.cleanup".format(name))
    self.current_block = cleanup_block
    landing_block.append(ir.LandingPad(cleanup_block))
    finally_gen()
    self.raise_exn()

    # Resume emitting code after the fall-through path.
    self.current_block = self.post_body

def visit_With(self, node):
context_expr_node = node.items[0].context_expr
optional_vars_node = node.items[0].optional_vars

if types.is_builtin(context_expr_node.type, "sequential"):
self.visit(node.body)
return
elif types.is_builtin(context_expr_node.type, "parallel"):
parallel = self.append(ir.Parallel([]))

@@ -748,32 +763,44 @@ def visit_With(self, node):
for tail in tails:
if not tail.is_terminated():
tail.append(ir.Branch(self.current_block))
elif isinstance(context_expr_node, asttyped.CallT) and \
types.is_builtin(context_expr_node.func.type, "watchdog"):
timeout = self.visit(context_expr_node.args[0])
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
ir.Constant(1000, builtins.TFloat())))
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
return

cleanup = []
for item_node in node.items:
context_expr_node = item_node.context_expr
optional_vars_node = item_node.optional_vars

if isinstance(context_expr_node, asttyped.CallT) and \
types.is_builtin(context_expr_node.func.type, "watchdog"):
timeout = self.visit(context_expr_node.args[0])
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
ir.Constant(1000, builtins.TFloat())))
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))

watchdog_id = self.append(ir.Builtin("watchdog_set", [timeout_ms_int],
builtins.TInt32()))
cleanup.append(lambda:
self.append(ir.Builtin("watchdog_clear", [watchdog_id], builtins.TNone())))
else: # user-defined context manager
context_mgr = self.visit(context_expr_node)
enter_fn = self._get_attribute(context_mgr, '__enter__')
exit_fn = self._get_attribute(context_mgr, '__exit__')

watchdog = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], builtins.TInt32()))
try:
self.current_assign = self._user_call(enter_fn, [], {})
if optional_vars_node is not None:
self.visit(optional_vars_node)
finally:
self.current_assign = None

dispatcher = self.add_block("watchdog.dispatch")
none = self.append(ir.Alloc([], builtins.TNone()))
cleanup.append(lambda:
self._user_call(exit_fn, [none, none, none], {}))

try:
old_unwind, self.unwind_target = self.unwind_target, dispatcher
self.visit(node.body)
finally:
self.unwind_target = old_unwind

cleanup = self.add_block('watchdog.cleanup')
landingpad = dispatcher.append(ir.LandingPad(cleanup))
cleanup.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
cleanup.append(ir.Reraise(self.unwind_target))

if not self.current_block.is_terminated():
self.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
else:
assert False
self._try_finally(
body_gen=lambda: self.visit(node.body),
finally_gen=lambda: [thunk() for thunk in cleanup],
name="with")

# Expression visitors
# These visitors return a node in addition to mutating
@@ -850,31 +877,35 @@ def visit_NameT(self, node):
else:
return self._set_local(node.id, self.current_assign)

def visit_AttributeT(self, node):
try:
old_assign, self.current_assign = self.current_assign, None
obj = self.visit(node.value)
finally:
self.current_assign = old_assign

if node.attr not in obj.type.find().attributes:
def _get_attribute(self, obj, attr_name):
if attr_name not in obj.type.find().attributes:
# A class attribute. Get the constructor (class object) and
# extract the attribute from it.
constr_type = obj.type.constructor
constr = self.append(ir.GetConstructor(self._env_for(constr_type.name),
constr_type.name, constr_type,
name="constructor." + constr_type.name))

if types.is_function(constr.type.attributes[node.attr]):
if types.is_function(constr.type.attributes[attr_name]):
# A method. Construct a method object instead.
func = self.append(ir.GetAttr(constr, node.attr))
return self.append(ir.Alloc([func, obj], node.type))
func = self.append(ir.GetAttr(constr, attr_name))
return self.append(ir.Alloc([func, obj],
types.TMethod(obj.type, func.type)))
else:
obj = constr

return self.append(ir.GetAttr(obj, attr_name,
name="{}.{}".format(_readable_name(obj), attr_name)))

def visit_AttributeT(self, node):
try:
old_assign, self.current_assign = self.current_assign, None
obj = self.visit(node.value)
finally:
self.current_assign = old_assign

if self.current_assign is None:
return self.append(ir.GetAttr(obj, node.attr,
name="{}.{}".format(_readable_name(obj), node.attr)))
return self._get_attribute(obj, node.attr)
elif types.is_rpc_function(self.current_assign.type):
# RPC functions are just type-level markers
return self.append(ir.Builtin("nop", [], builtins.TNone()))
@@ -1624,77 +1655,82 @@ def body_gen(index):
node.loc)
self.engine.process(diag)

def _user_call(self, callee, positional, keywords, arg_exprs=None):
    """Emit a call to a user-defined function or bound method ``callee``.

    The IR argument list holds the required arguments first, followed by
    one option-typed slot per optional argument. For a method, the bound
    ``__self__`` value fills slot 0.

    :param callee: IR value whose type is a function or a method.
    :param positional: IR values passed positionally, in order.
    :param keywords: dict mapping argument name to IR value.
    :param arg_exprs: passed unchanged to the emitted ``ir.Call`` /
        ``ir.Invoke``. Defaults to a fresh empty dict.
    :return: the emitted ``ir.Call`` or ``ir.Invoke`` instruction.
    """
    # A literal ``{}`` default would be a single dict object shared by
    # every instruction created through the default; allocate one per call.
    if arg_exprs is None:
        arg_exprs = {}

    if types.is_function(callee.type):
        func = callee
        self_arg = None
        fn_typ = callee.type
        offset = 0
    elif types.is_method(callee.type):
        func = self.append(ir.GetAttr(callee, "__func__"))
        self_arg = self.append(ir.GetAttr(callee, "__self__"))
        fn_typ = types.get_method_function(callee.type)
        offset = 1  # slot 0 is reserved for the bound self value
    else:
        assert False

    # Iterating fn_typ.args / fn_typ.optargs yields argument names, in
    # declaration order (as the original enumerate() loops relied on).
    arg_names = list(fn_typ.args)
    optarg_names = list(fn_typ.optargs)
    num_args = len(arg_names)
    args = [None] * (num_args + len(optarg_names))

    # Positional values past the required arguments land in optional
    # slots, which hold option-wrapped values.
    for index, arg in enumerate(positional):
        if index + offset < num_args:
            args[index + offset] = arg
        else:
            args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))

    for keyword, arg in keywords.items():
        if keyword in fn_typ.args:
            index = arg_names.index(keyword)
            assert args[index] is None
            args[index] = arg
        elif keyword in fn_typ.optargs:
            index = num_args + optarg_names.index(keyword)
            assert args[index] is None
            args[index] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))

    # Optional arguments that were not supplied get an empty option value.
    for index, optarg_name in enumerate(optarg_names):
        if args[num_args + index] is None:
            args[num_args + index] = \
                self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))

    if self_arg is not None:
        assert args[0] is None
        args[0] = self_arg

    assert None not in args

    # With an active unwind target, emit an Invoke so that an exception
    # raised by the callee unwinds there; otherwise a plain Call suffices.
    if self.unwind_target is None:
        insn = self.append(ir.Call(func, args, arg_exprs))
    else:
        after_invoke = self.add_block()
        insn = self.append(ir.Invoke(func, args, arg_exprs,
                                     after_invoke, self.unwind_target))
        self.current_block = after_invoke

    return insn

def visit_CallT(self, node):
typ = node.func.type.find()
if not types.is_builtin(node.func.type):
callee = self.visit(node.func)
args = [self.visit(arg_node) for arg_node in node.args]
keywords = {kw_node.arg: self.visit(kw_node.value) for kw_node in node.keywords}

if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
before_delay = self.current_block
during_delay = self.add_block()
before_delay.append(ir.Branch(during_delay))
self.current_block = during_delay

if types.is_builtin(typ):
if types.is_builtin(node.func.type):
insn = self.visit_builtin_call(node)
else:
if types.is_function(typ):
func = self.visit(node.func)
self_arg = None
fn_typ = typ
offset = 0
elif types.is_method(typ):
method = self.visit(node.func)
func = self.append(ir.GetAttr(method, "__func__"))
self_arg = self.append(ir.GetAttr(method, "__self__"))
fn_typ = types.get_method_function(typ)
offset = 1
else:
assert False

args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))

for index, arg_node in enumerate(node.args):
arg = self.visit(arg_node)
if index + offset < len(fn_typ.args):
args[index + offset] = arg
else:
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))

for keyword in node.keywords:
arg = self.visit(keyword.value)
if keyword.arg in fn_typ.args:
for index, arg_name in enumerate(fn_typ.args):
if keyword.arg == arg_name:
assert args[index] is None
args[index] = arg
break
elif keyword.arg in fn_typ.optargs:
for index, optarg_name in enumerate(fn_typ.optargs):
if keyword.arg == optarg_name:
assert args[len(fn_typ.args) + index] is None
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
break

for index, optarg_name in enumerate(fn_typ.optargs):
if args[len(fn_typ.args) + index] is None:
args[len(fn_typ.args) + index] = \
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))

if self_arg is not None:
assert args[0] is None
args[0] = self_arg

assert None not in args

if self.unwind_target is None:
insn = self.append(ir.Call(func, args, node.arg_exprs))
else:
after_invoke = self.add_block()
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
after_invoke, self.unwind_target))
self.current_block = after_invoke
insn = self._user_call(callee, args, keywords, node.arg_exprs)

method_key = None
if isinstance(node.func, asttyped.AttributeT):
attr_node = node.func
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)
33 changes: 33 additions & 0 deletions lit-test/test/integration/with.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# RUN: %python -m artiq.compiler.testbench.jit %s >%t.1
# RUN: OutputCheck %s --file-to-check=%t.1
# RUN: %python %s >%t.2
# RUN: OutputCheck %s --file-to-check=%t.2

# Exercises `with` over a user-defined context manager: once with a body
# that completes normally, and once with a body that raises. The file runs
# through the ARTIQ compiler testbench and through plain CPython. Each run's
# stdout is captured and matched against the expectations below, so both
# implementations must print the same sequence.
#
# NOTE(review): a multi-item `with a(), b():` is not covered. visit_With in
# the IR generator builds its cleanup lambdas inside a loop over the items,
# and Python closures bind late. It also runs the lambdas in acquisition
# order rather than in reverse. That case deserves its own test.

class contextmgr:
    # Prints 2 on entry, so the output shows where the hook ran.
    def __enter__(self):
        print(2)

    # Takes the three arguments of the __exit__ protocol and ignores them.
    # It prints 4 and returns None, so an exception from the body is not
    # suppressed.
    def __exit__(self, n1, n2, n3):
        print(4)

# Normal exit: __enter__ runs before the body and __exit__ right after it.
# CHECK-L: a 1
# CHECK-L: 2
# CHECK-L: a 3
# CHECK-L: 4
# CHECK-L: a 5
print("a", 1)
with contextmgr():
    print("a", 3)
print("a", 5)

# Exceptional exit: [0][1] raises IndexError inside the body, so "b 3" and
# "b 5" are skipped. __exit__ still prints 4 before the bare `except`
# swallows the exception.
# NOTE(review): these ordered expectations would still match if "b 3"
# appeared between 2 and 4; nothing asserts its absence.
# CHECK-L: b 1
# CHECK-L: 2
# CHECK-L: 4
# CHECK-L: b 6
try:
    print("b", 1)
    with contextmgr():
        [0][1]
        print("b", 3)
    print("b", 5)
except:
    pass
print("b", 6)

0 comments on commit 6a6d7da

Please sign in to comment.