Skip to content

Commit 50448ef

Browse files
author
whitequark
committed Aug 7, 2015
Add support for referring to host values in embedded functions.
1 parent 353f454 commit 50448ef

File tree

4 files changed

+110 -27 lines changed

4 files changed

+110 -27 lines changed
 

Diff for: artiq/compiler/embedding.py

+103 -23
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55
annotated as ``@kernel`` when they are referenced.
66
"""
77

8-
import inspect
8+
import inspect, os
99
from pythonparser import ast, source, diagnostic, parse_buffer
1010
from . import types, builtins, asttyped, prelude
1111
from .transforms import ASTTypedRewriter, Inferencer
@@ -28,11 +28,12 @@ def _add(self, fragment):
2828

2929
def quote(self, value):
3030
"""Construct an AST fragment equal to `value`."""
31-
if value in (None, True, False):
32-
if node.value is True or node.value is False:
33-
typ = builtins.TBool()
34-
elif node.value is None:
35-
typ = builtins.TNone()
31+
if value is None:
32+
typ = builtins.TNone()
33+
return asttyped.NameConstantT(value=value, type=typ,
34+
loc=self._add(repr(value)))
35+
elif value is True or value is False:
36+
typ = builtins.TBool()
3637
return asttyped.NameConstantT(value=value, type=typ,
3738
loc=self._add(repr(value)))
3839
elif isinstance(value, (int, float)):
@@ -45,12 +46,12 @@ def quote(self, value):
4546
elif isinstance(value, list):
4647
begin_loc = self._add("[")
4748
elts = []
48-
for index, elt in value:
49+
for index, elt in enumerate(value):
4950
elts.append(self.quote(elt))
5051
if index < len(value) - 1:
5152
self._add(", ")
5253
end_loc = self._add("]")
53-
return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(),
54+
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
5455
begin_loc=begin_loc, end_loc=end_loc,
5556
loc=begin_loc.join(end_loc))
5657
else:
@@ -99,7 +100,43 @@ def call(self, function_node, args, kwargs):
99100
loc=name_loc.join(end_loc))
100101

101102
class StitchingASTTypedRewriter(ASTTypedRewriter):
102-
pass
103+
def __init__(self, engine, prelude, globals, host_environment, quote_function):
104+
super().__init__(engine, prelude)
105+
self.globals = globals
106+
self.env_stack.append(self.globals)
107+
108+
self.host_environment = host_environment
109+
self.quote_function = quote_function
110+
111+
def visit_Name(self, node):
112+
typ = super()._try_find_name(node.id)
113+
if typ is not None:
114+
# Value from device environment.
115+
return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx,
116+
loc=node.loc)
117+
else:
118+
# Try to find this value in the host environment and quote it.
119+
if node.id in self.host_environment:
120+
value = self.host_environment[node.id]
121+
if inspect.isfunction(value):
122+
# It's a function. We need to translate the function and insert
123+
# a reference to it.
124+
function_name = self.quote_function(value)
125+
return asttyped.NameT(id=function_name, ctx=None,
126+
type=self.globals[function_name],
127+
loc=node.loc)
128+
129+
else:
130+
# It's just a value. Quote it.
131+
synthesizer = ASTSynthesizer()
132+
node = synthesizer.quote(value)
133+
synthesizer.finalize()
134+
return node
135+
else:
136+
diag = diagnostic.Diagnostic("fatal",
137+
"name '{name}' is not bound to anything", {"name":node.id},
138+
node.loc)
139+
self.engine.process(diag)
103140

104141
class Stitcher:
105142
def __init__(self, engine=None):
@@ -108,50 +145,93 @@ def __init__(self, engine=None):
108145
else:
109146
self.engine = engine
110147

111-
self.asttyped_rewriter = StitchingASTTypedRewriter(
112-
engine=self.engine, globals=prelude.globals())
113-
self.inferencer = Inferencer(engine=self.engine)
148+
self.name = ""
149+
self.typedtree = []
150+
self.prelude = prelude.globals()
151+
self.globals = {}
114152

115-
self.name = "stitched"
116-
self.typedtree = None
117-
self.globals = self.asttyped_rewriter.globals
153+
self.functions = {}
118154

119155
self.rpc_map = {}
120156

121157
def _iterate(self):
158+
inferencer = Inferencer(engine=self.engine)
159+
122160
# Iterate inference to fixed point.
123161
self.inference_finished = False
124162
while not self.inference_finished:
125163
self.inference_finished = True
126-
self.inferencer.visit(self.typedtree)
164+
inferencer.visit(self.typedtree)
165+
166+
# After we have found all functions, synthesize a module to hold them.
167+
self.typedtree = asttyped.ModuleT(
168+
typing_env=self.globals, globals_in_scope=set(),
169+
body=self.typedtree, loc=None)
127170

128-
def _parse_embedded_function(self, function):
171+
def _quote_embedded_function(self, function):
129172
if not hasattr(function, "artiq_embedded"):
130173
raise ValueError("{} is not an embedded function".format(repr(function)))
131174

132175
# Extract function source.
133176
embedded_function = function.artiq_embedded.function
134177
source_code = inspect.getsource(embedded_function)
135178
filename = embedded_function.__code__.co_filename
179+
module_name, _ = os.path.splitext(os.path.basename(filename))
136180
first_line = embedded_function.__code__.co_firstlineno
137181

182+
# Extract function environment.
183+
host_environment = dict()
184+
host_environment.update(embedded_function.__globals__)
185+
cells = embedded_function.__closure__
186+
cell_names = embedded_function.__code__.co_freevars
187+
host_environment.update({var: cells[index] for index, var in enumerate(cell_names)})
188+
138189
# Parse.
139190
source_buffer = source.Buffer(source_code, filename, first_line)
140191
parsetree, comments = parse_buffer(source_buffer, engine=self.engine)
192+
function_node = parsetree.body[0]
141193

142-
# Rewrite into typed form.
143-
typedtree = self.asttyped_rewriter.visit(parsetree)
194+
# Mangle the name, since we put everything into a single module.
195+
function_node.name = "{}.{}".format(module_name, function_node.name)
196+
197+
# Normally, LocalExtractor would populate the typing environment
198+
# of the module with the function name. However, since we run
199+
# ASTTypedRewriter on the function node directly, we need to do it
200+
# explicitly.
201+
self.globals[function_node.name] = types.TVar()
202+
203+
# Memoize the function before typing it to handle recursive
204+
# invocations.
205+
self.functions[function] = function_node.name
144206

145-
return typedtree, typedtree.body[0]
207+
# Rewrite into typed form.
208+
asttyped_rewriter = StitchingASTTypedRewriter(
209+
engine=self.engine, prelude=self.prelude,
210+
globals=self.globals, host_environment=host_environment,
211+
quote_function=self._quote_function)
212+
return asttyped_rewriter.visit(function_node)
213+
214+
def _quote_function(self, function):
215+
if function in self.functions:
216+
return self.functions[function]
217+
218+
# Insert the typed AST for the new function and restart inference.
219+
# It doesn't really matter where we insert as long as it is before
220+
# the final call.
221+
function_node = self._quote_embedded_function(function)
222+
self.typedtree.insert(0, function_node)
223+
self.inference_finished = False
224+
return function_node.name
146225

147226
def stitch_call(self, function, args, kwargs):
148-
self.typedtree, function_node = self._parse_embedded_function(function)
227+
function_node = self._quote_embedded_function(function)
228+
self.typedtree.append(function_node)
149229

150-
# We synthesize fake source code for the initial call so that
230+
# We synthesize source code for the initial call so that
151231
# diagnostics would have something meaningful to display to the user.
152232
synthesizer = ASTSynthesizer()
153233
call_node = synthesizer.call(function_node, args, kwargs)
154234
synthesizer.finalize()
155-
self.typedtree.body.append(call_node)
235+
self.typedtree.append(call_node)
156236

157237
self._iterate()

Diff for: artiq/compiler/module.py

+4 -1
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,10 @@ def build_llvm_ir(self, target):
6767

6868
def entry_point(self):
6969
"""Return the name of the function that is the entry point of this module."""
70-
return self.name + ".__modinit__"
70+
if self.name != "":
71+
return self.name + ".__modinit__"
72+
else:
73+
return "__modinit__"
7174

7275
def __repr__(self):
7376
printer = types.TypePrinter()

Diff for: artiq/compiler/transforms/artiq_ir_generator.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
7070
def __init__(self, module_name, engine):
7171
self.engine = engine
7272
self.functions = []
73-
self.name = [module_name]
73+
self.name = [module_name] if module_name != "" else []
7474
self.current_loc = None
7575
self.current_function = None
7676
self.current_globals = set()

Diff for: artiq/compiler/transforms/asttyped_rewriter.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -185,10 +185,10 @@ class ASTTypedRewriter(algorithm.Transformer):
185185
via :class:`LocalExtractor`.
186186
"""
187187

188-
def __init__(self, engine, globals):
188+
def __init__(self, engine, prelude):
189189
self.engine = engine
190190
self.globals = None
191-
self.env_stack = [globals]
191+
self.env_stack = [prelude]
192192

193193
def _try_find_name(self, name):
194194
for typing_env in reversed(self.env_stack):

0 commit comments

Comments (0)