Skip to content

Commit 6a6d7da

Browse files
author
whitequark
committed Jan 5, 2016
transforms.artiq_ir_generator: add support for user-defined context managers.
1 parent d633c8e commit 6a6d7da

File tree

2 files changed

+173
-104
lines changed

2 files changed

+173
-104
lines changed
 

Diff for: ‎artiq/compiler/transforms/artiq_ir_generator.py

+140-104
Original file line number | Diff line number | Diff line change
@@ -555,7 +555,7 @@ def visit_Break(self, node):
555555
def visit_Continue(self, node):
    """Lower ``continue`` to an unconditional branch to ``self.continue_target``."""
    target = self.continue_target
    self.append(ir.Branch(target))
557557

558-
def raise_exn(self, exn, loc=None):
558+
def raise_exn(self, exn=None, loc=None):
559559
if self.final_branch is not None:
560560
raise_proxy = self.add_block("try.raise")
561561
self.final_branch(raise_proxy, self.current_block)
@@ -718,19 +718,34 @@ def final_branch(target, block):
718718
if not post_handler.is_terminated():
719719
post_handler.append(ir.Branch(tail))
720720

721-
def visit_With(self, node):
722-
if len(node.items) != 1:
723-
diag = diagnostic.Diagnostic("fatal",
724-
"only one expression per 'with' statement is supported",
725-
{"type": types.TypePrinter().name(typ)},
726-
node.context_expr.loc)
727-
self.engine.process(diag)
721+
def _try_finally(self, body_gen, finally_gen, name):
    """Emit ``body_gen`` so that ``finally_gen`` is emitted on both exit paths.

    While ``body_gen`` runs, ``self.unwind_target`` temporarily points at a
    fresh "<name>.dispatch" block; the previous target is restored afterwards
    even if generation raises. On the fall-through path ``finally_gen`` is
    emitted right after the body, unless the body's final block is already
    terminated. A separate "<name>.cleanup" block, reached through a landing
    pad appended to the dispatcher, emits ``finally_gen`` again and then
    calls ``self.raise_exn()`` with no exception argument. Generation resumes
    in the block where the body fell through (recorded in ``self.post_body``).

    :param body_gen: thunk that emits the protected body.
    :param finally_gen: thunk that emits the cleanup code; invoked once on the
        normal path (if reachable) and once in the cleanup block.
    :param name: prefix used for the names of the generated blocks.
    """
    landing_block = self.add_block("{}.dispatch".format(name))

    try:
        saved_unwind = self.unwind_target
        self.unwind_target = landing_block
        body_gen()
    finally:
        self.unwind_target = saved_unwind

    # Normal exit: run the cleanup inline unless control already left the body.
    if not self.current_block.is_terminated():
        finally_gen()
    self.post_body = self.current_block

    # Exceptional exit: dispatcher -> landing pad -> cleanup -> re-raise.
    cleanup_block = self.add_block("{}.cleanup".format(name))
    self.current_block = cleanup_block
    landing_block.append(ir.LandingPad(cleanup_block))
    finally_gen()
    self.raise_exn()

    self.current_block = self.post_body
741+
742+
def visit_With(self, node):
729743
context_expr_node = node.items[0].context_expr
730744
optional_vars_node = node.items[0].optional_vars
731745

732746
if types.is_builtin(context_expr_node.type, "sequential"):
733747
self.visit(node.body)
748+
return
734749
elif types.is_builtin(context_expr_node.type, "parallel"):
735750
parallel = self.append(ir.Parallel([]))
736751

@@ -748,32 +763,44 @@ def visit_With(self, node):
748763
for tail in tails:
749764
if not tail.is_terminated():
750765
tail.append(ir.Branch(self.current_block))
751-
elif isinstance(context_expr_node, asttyped.CallT) and \
752-
types.is_builtin(context_expr_node.func.type, "watchdog"):
753-
timeout = self.visit(context_expr_node.args[0])
754-
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
755-
ir.Constant(1000, builtins.TFloat())))
756-
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
766+
return
767+
768+
cleanup = []
769+
for item_node in node.items:
770+
context_expr_node = item_node.context_expr
771+
optional_vars_node = item_node.optional_vars
772+
773+
if isinstance(context_expr_node, asttyped.CallT) and \
774+
types.is_builtin(context_expr_node.func.type, "watchdog"):
775+
timeout = self.visit(context_expr_node.args[0])
776+
timeout_ms = self.append(ir.Arith(ast.Mult(loc=None), timeout,
777+
ir.Constant(1000, builtins.TFloat())))
778+
timeout_ms_int = self.append(ir.Coerce(timeout_ms, builtins.TInt32()))
779+
780+
watchdog_id = self.append(ir.Builtin("watchdog_set", [timeout_ms_int],
781+
builtins.TInt32()))
782+
cleanup.append(lambda:
783+
self.append(ir.Builtin("watchdog_clear", [watchdog_id], builtins.TNone())))
784+
else: # user-defined context manager
785+
context_mgr = self.visit(context_expr_node)
786+
enter_fn = self._get_attribute(context_mgr, '__enter__')
787+
exit_fn = self._get_attribute(context_mgr, '__exit__')
757788

758-
watchdog = self.append(ir.Builtin("watchdog_set", [timeout_ms_int], builtins.TInt32()))
789+
try:
790+
self.current_assign = self._user_call(enter_fn, [], {})
791+
if optional_vars_node is not None:
792+
self.visit(optional_vars_node)
793+
finally:
794+
self.current_assign = None
759795

760-
dispatcher = self.add_block("watchdog.dispatch")
796+
none = self.append(ir.Alloc([], builtins.TNone()))
797+
cleanup.append(lambda:
798+
self._user_call(exit_fn, [none, none, none], {}))
761799

762-
try:
763-
old_unwind, self.unwind_target = self.unwind_target, dispatcher
764-
self.visit(node.body)
765-
finally:
766-
self.unwind_target = old_unwind
767-
768-
cleanup = self.add_block('watchdog.cleanup')
769-
landingpad = dispatcher.append(ir.LandingPad(cleanup))
770-
cleanup.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
771-
cleanup.append(ir.Reraise(self.unwind_target))
772-
773-
if not self.current_block.is_terminated():
774-
self.append(ir.Builtin("watchdog_clear", [watchdog], builtins.TNone()))
775-
else:
776-
assert False
800+
self._try_finally(
801+
body_gen=lambda: self.visit(node.body),
802+
finally_gen=lambda: [thunk() for thunk in cleanup],
803+
name="with")
777804

778805
# Expression visitors
779806
# These visitors return a node in addition to mutating
@@ -850,31 +877,35 @@ def visit_NameT(self, node):
850877
else:
851878
return self._set_local(node.id, self.current_assign)
852879

853-
def visit_AttributeT(self, node):
854-
try:
855-
old_assign, self.current_assign = self.current_assign, None
856-
obj = self.visit(node.value)
857-
finally:
858-
self.current_assign = old_assign
859-
860-
if node.attr not in obj.type.find().attributes:
880+
def _get_attribute(self, obj, attr_name):
    """Emit IR that reads attribute ``attr_name`` from the IR value ``obj``.

    When ``attr_name`` is one of the attributes of ``obj``'s (resolved) type,
    a plain ``GetAttr`` on ``obj`` is emitted. Otherwise the attribute lives
    on the class: the class object is fetched with ``GetConstructor``; a
    function-typed class attribute is turned into a method object (an
    ``Alloc`` of the function and ``obj`` with a ``TMethod`` type), and any
    other class attribute is read from the class object itself.

    :param obj: IR value whose attribute is read.
    :param attr_name: attribute name.
    :return: the emitted instruction producing the attribute value.
    """
    source = obj
    if attr_name not in obj.type.find().attributes:
        # Not an instance attribute: go through the class object.
        class_type = obj.type.constructor
        class_obj = self.append(ir.GetConstructor(
            self._env_for(class_type.name), class_type.name, class_type,
            name="constructor." + class_type.name))

        if types.is_function(class_obj.type.attributes[attr_name]):
            # A method: pair the function with ``obj`` as a method object.
            method_fn = self.append(ir.GetAttr(class_obj, attr_name))
            method_type = types.TMethod(obj.type, method_fn.type)
            return self.append(ir.Alloc([method_fn, obj], method_type))

        source = class_obj

    attr_label = "{}.{}".format(_readable_name(source), attr_name)
    return self.append(ir.GetAttr(source, attr_name, name=attr_label))
899+
900+
def visit_AttributeT(self, node):
901+
try:
902+
old_assign, self.current_assign = self.current_assign, None
903+
obj = self.visit(node.value)
904+
finally:
905+
self.current_assign = old_assign
906+
875907
if self.current_assign is None:
876-
return self.append(ir.GetAttr(obj, node.attr,
877-
name="{}.{}".format(_readable_name(obj), node.attr)))
908+
return self._get_attribute(obj, node.attr)
878909
elif types.is_rpc_function(self.current_assign.type):
879910
# RPC functions are just type-level markers
880911
return self.append(ir.Builtin("nop", [], builtins.TNone()))
@@ -1624,77 +1655,82 @@ def body_gen(index):
16241655
node.loc)
16251656
self.engine.process(diag)
16261657

1658+
def _user_call(self, callee, positional, keywords, arg_exprs=None):
    """Emit a call to a user-defined function or method.

    :param callee: IR value of function type or method type. For a method,
        its ``__func__`` is called and its ``__self__`` is passed in the first
        argument slot.
    :param positional: list of IR values for the positional arguments.
    :param keywords: dict mapping a parameter name to its IR value.
    :param arg_exprs: forwarded unchanged to ``ir.Call`` / ``ir.Invoke``.
        ``None`` (the default) means a fresh empty dict.
    :return: the emitted ``Call`` instruction, or an ``Invoke`` when
        ``self.unwind_target`` is set; in that case code generation continues
        in a new block that the invoke falls through to.
    """
    if arg_exprs is None:
        # A fresh dict per call; a `{}` default would be a single object
        # shared by every Call/Invoke built without explicit arg_exprs.
        arg_exprs = {}

    if types.is_function(callee.type):
        func = callee
        self_arg = None
        fn_typ = callee.type
        offset = 0
    elif types.is_method(callee.type):
        func = self.append(ir.GetAttr(callee, "__func__"))
        self_arg = self.append(ir.GetAttr(callee, "__self__"))
        fn_typ = types.get_method_function(callee.type)
        offset = 1  # slot 0 is reserved for __self__
    else:
        assert False

    n_args = len(fn_typ.args)
    args = [None] * (n_args + len(fn_typ.optargs))

    # Positional arguments fill the required slots first; any beyond those
    # go into optional slots, which hold Option-wrapped values.
    for index, arg in enumerate(positional):
        if index + offset < n_args:
            args[index + offset] = arg
        else:
            args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))

    for keyword, arg in keywords.items():
        if keyword in fn_typ.args:
            index = list(fn_typ.args).index(keyword)
            assert args[index] is None
            args[index] = arg
        elif keyword in fn_typ.optargs:
            index = n_args + list(fn_typ.optargs).index(keyword)
            assert args[index] is None
            args[index] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))

    # Optional parameters that were not supplied get an empty Option
    # (presumably meaning "use the default" -- confirm against the callee ABI).
    for index, optarg_name in enumerate(fn_typ.optargs):
        if args[n_args + index] is None:
            args[n_args + index] = \
                self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))

    if self_arg is not None:
        assert args[0] is None
        args[0] = self_arg

    assert None not in args

    if self.unwind_target is None:
        insn = self.append(ir.Call(func, args, arg_exprs))
    else:
        after_invoke = self.add_block()
        insn = self.append(ir.Invoke(func, args, arg_exprs,
                                     after_invoke, self.unwind_target))
        self.current_block = after_invoke

    return insn
1716+
16271717
def visit_CallT(self, node):
1628-
typ = node.func.type.find()
1718+
if not types.is_builtin(node.func.type):
1719+
callee = self.visit(node.func)
1720+
args = [self.visit(arg_node) for arg_node in node.args]
1721+
keywords = {kw_node.arg: self.visit(kw_node.value) for kw_node in node.keywords}
16291722

16301723
if node.iodelay is not None and not iodelay.is_const(node.iodelay, 0):
16311724
before_delay = self.current_block
16321725
during_delay = self.add_block()
16331726
before_delay.append(ir.Branch(during_delay))
16341727
self.current_block = during_delay
16351728

1636-
if types.is_builtin(typ):
1729+
if types.is_builtin(node.func.type):
16371730
insn = self.visit_builtin_call(node)
16381731
else:
1639-
if types.is_function(typ):
1640-
func = self.visit(node.func)
1641-
self_arg = None
1642-
fn_typ = typ
1643-
offset = 0
1644-
elif types.is_method(typ):
1645-
method = self.visit(node.func)
1646-
func = self.append(ir.GetAttr(method, "__func__"))
1647-
self_arg = self.append(ir.GetAttr(method, "__self__"))
1648-
fn_typ = types.get_method_function(typ)
1649-
offset = 1
1650-
else:
1651-
assert False
1652-
1653-
args = [None] * (len(fn_typ.args) + len(fn_typ.optargs))
1654-
1655-
for index, arg_node in enumerate(node.args):
1656-
arg = self.visit(arg_node)
1657-
if index + offset < len(fn_typ.args):
1658-
args[index + offset] = arg
1659-
else:
1660-
args[index + offset] = self.append(ir.Alloc([arg], ir.TOption(arg.type)))
1661-
1662-
for keyword in node.keywords:
1663-
arg = self.visit(keyword.value)
1664-
if keyword.arg in fn_typ.args:
1665-
for index, arg_name in enumerate(fn_typ.args):
1666-
if keyword.arg == arg_name:
1667-
assert args[index] is None
1668-
args[index] = arg
1669-
break
1670-
elif keyword.arg in fn_typ.optargs:
1671-
for index, optarg_name in enumerate(fn_typ.optargs):
1672-
if keyword.arg == optarg_name:
1673-
assert args[len(fn_typ.args) + index] is None
1674-
args[len(fn_typ.args) + index] = \
1675-
self.append(ir.Alloc([arg], ir.TOption(arg.type)))
1676-
break
1677-
1678-
for index, optarg_name in enumerate(fn_typ.optargs):
1679-
if args[len(fn_typ.args) + index] is None:
1680-
args[len(fn_typ.args) + index] = \
1681-
self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name])))
1682-
1683-
if self_arg is not None:
1684-
assert args[0] is None
1685-
args[0] = self_arg
1686-
1687-
assert None not in args
1688-
1689-
if self.unwind_target is None:
1690-
insn = self.append(ir.Call(func, args, node.arg_exprs))
1691-
else:
1692-
after_invoke = self.add_block()
1693-
insn = self.append(ir.Invoke(func, args, node.arg_exprs,
1694-
after_invoke, self.unwind_target))
1695-
self.current_block = after_invoke
1732+
insn = self._user_call(callee, args, keywords, node.arg_exprs)
16961733

1697-
method_key = None
16981734
if isinstance(node.func, asttyped.AttributeT):
16991735
attr_node = node.func
17001736
self.method_map[(attr_node.value.type.find(), attr_node.attr)].append(insn)

Diff for: ‎lit-test/test/integration/with.py

+33
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,33 @@
1+
# RUN: %python -m artiq.compiler.testbench.jit %s
# RUN: %python %s

# Test for user-defined context managers in `with` statements. The same file
# is executed by the ARTIQ JIT testbench and by plain CPython, and the printed
# lines are matched against the expectations listed before each section.


class contextmgr:
    # Prints 2 on entry.
    def __enter__(self):
        print(2)

    # Prints 4 on exit. Under CPython the three parameters receive the
    # exception type, value and traceback (all None on a normal exit); the
    # implicit None return means an in-flight exception is not suppressed.
    def __exit__(self, n1, n2, n3):
        print(4)

# Normal exit: expected order is "a 1", __enter__, body, __exit__, "a 5".
# CHECK-L: a 1
# CHECK-L: 2
# CHECK-L: a 3
# CHECK-L: 4
# CHECK-L: a 5
print("a", 1)
with contextmgr():
    print("a", 3)
print("a", 5)

# Exceptional exit: `[0][1]` raises IndexError inside the body, so __exit__
# must still run (printing 4) before the exception reaches the enclosing
# handler; "b 3" and "b 5" are never printed, and "b 6" follows the handler.
# CHECK-L: b 1
# CHECK-L: 2
# CHECK-L: 4
# CHECK-L: b 6
try:
    print("b", 1)
    with contextmgr():
        [0][1]
        print("b", 3)
    print("b", 5)
except:
    pass
print("b", 6)

0 commit comments

Comments
 (0)
Please sign in to comment.