5
5
annotated as ``@kernel`` when they are referenced.
6
6
"""
7
7
8
- import inspect
8
+ import inspect , os
9
9
from pythonparser import ast , source , diagnostic , parse_buffer
10
10
from . import types , builtins , asttyped , prelude
11
11
from .transforms import ASTTypedRewriter , Inferencer
@@ -28,11 +28,12 @@ def _add(self, fragment):
28
28
29
29
def quote (self , value ):
30
30
"""Construct an AST fragment equal to `value`."""
31
- if value in (None , True , False ):
32
- if node .value is True or node .value is False :
33
- typ = builtins .TBool ()
34
- elif node .value is None :
35
- typ = builtins .TNone ()
31
+ if value is None :
32
+ typ = builtins .TNone ()
33
+ return asttyped .NameConstantT (value = value , type = typ ,
34
+ loc = self ._add (repr (value )))
35
+ elif value is True or value is False :
36
+ typ = builtins .TBool ()
36
37
return asttyped .NameConstantT (value = value , type = typ ,
37
38
loc = self ._add (repr (value )))
38
39
elif isinstance (value , (int , float )):
@@ -45,12 +46,12 @@ def quote(self, value):
45
46
elif isinstance (value , list ):
46
47
begin_loc = self ._add ("[" )
47
48
elts = []
48
- for index , elt in value :
49
+ for index , elt in enumerate ( value ) :
49
50
elts .append (self .quote (elt ))
50
51
if index < len (value ) - 1 :
51
52
self ._add (", " )
52
53
end_loc = self ._add ("]" )
53
- return asttyped .ListT (elts = elts , ctx = None , type = types . TVar (),
54
+ return asttyped .ListT (elts = elts , ctx = None , type = builtins . TList (),
54
55
begin_loc = begin_loc , end_loc = end_loc ,
55
56
loc = begin_loc .join (end_loc ))
56
57
else :
@@ -99,7 +100,43 @@ def call(self, function_node, args, kwargs):
99
100
loc = name_loc .join (end_loc ))
100
101
101
102
class StitchingASTTypedRewriter(ASTTypedRewriter):
    """ASTTypedRewriter that falls back to the host Python environment.

    A name is first resolved through the device-side environment (via the
    base class's ``_try_find_name``). If it is not found there, it is looked
    up in ``host_environment``:

    * a host function is handed to ``quote_function``, and the name that
      callback returns is emitted as a reference whose type is taken from
      ``globals``;
    * any other host value is turned into an equivalent AST fragment with
      ``ASTSynthesizer.quote``;
    * a name bound nowhere produces a fatal diagnostic.
    """

    def __init__(self, engine, prelude, globals, host_environment, quote_function):
        super().__init__(engine, prelude)
        # Typing environment shared by every stitched function. It is pushed
        # onto the base class's env_stack, presumably so that name lookups in
        # the base rewriter can see it.
        self.globals = globals
        self.env_stack.append(self.globals)

        # Host-side name -> Python value.
        self.host_environment = host_environment
        # Callback: host function -> name of its quoted counterpart in globals.
        self.quote_function = quote_function

    def visit_Name(self, node):
        """Rewrite a name, quoting it from the host environment if necessary."""
        device_type = super()._try_find_name(node.id)
        if device_type is not None:
            # Value from the device environment: keep it as a plain name.
            return asttyped.NameT(type=device_type, id=node.id, ctx=node.ctx,
                                  loc=node.loc)

        if node.id not in self.host_environment:
            self.engine.process(
                diagnostic.Diagnostic("fatal",
                    "name '{name}' is not bound to anything", {"name": node.id},
                    node.loc))
            return None

        host_value = self.host_environment[node.id]
        if not inspect.isfunction(host_value):
            # Ordinary host value: splice in an AST fragment equal to it.
            synthesizer = ASTSynthesizer()
            quoted = synthesizer.quote(host_value)
            synthesizer.finalize()
            return quoted

        # Host function: have it translated and refer to it by the name the
        # callback hands back.
        quoted_name = self.quote_function(host_value)
        return asttyped.NameT(id=quoted_name, ctx=None,
                              type=self.globals[quoted_name],
                              loc=node.loc)
103
140
104
141
class Stitcher :
105
142
def __init__ (self , engine = None ):
@@ -108,50 +145,93 @@ def __init__(self, engine=None):
108
145
else :
109
146
self .engine = engine
110
147
111
- self .asttyped_rewriter = StitchingASTTypedRewriter (
112
- engine = self .engine , globals = prelude .globals ())
113
- self .inferencer = Inferencer (engine = self .engine )
148
+ self .name = ""
149
+ self .typedtree = []
150
+ self .prelude = prelude .globals ()
151
+ self .globals = {}
114
152
115
- self .name = "stitched"
116
- self .typedtree = None
117
- self .globals = self .asttyped_rewriter .globals
153
+ self .functions = {}
118
154
119
155
self .rpc_map = {}
120
156
121
157
def _iterate(self):
    """Type-check the collected definitions, then wrap them in a module.

    Inference is rerun until one complete pass leaves ``inference_finished``
    set. Anything that adds a new definition (``_quote_function`` clears the
    flag) forces another pass. Afterwards ``self.typedtree``, which was a
    list of nodes, is replaced by a ``ModuleT`` that uses that list as its
    body and ``self.globals`` as its typing environment.
    """
    type_inferencer = Inferencer(engine=self.engine)

    # Fixed-point iteration, written as "run a pass, then check the flag".
    while True:
        self.inference_finished = True
        type_inferencer.visit(self.typedtree)
        if self.inference_finished:
            break

    # Every function is known now; give them a single module as their home.
    self.typedtree = asttyped.ModuleT(
        typing_env=self.globals, globals_in_scope=set(),
        body=self.typedtree, loc=None)
127
170
128
def _quote_embedded_function(self, function):
    """Parse an embedded host function and rewrite it into typed form.

    The function's source is read back with ``inspect.getsource``, parsed,
    renamed to ``<module basename>.<function name>`` (everything is placed in
    one synthesized module), and rewritten by a ``StitchingASTTypedRewriter``.
    Free names are resolved against the device globals and, failing that,
    against the function's host environment: its module globals, overlaid
    with the current values of its closure variables.

    :param function: host object carrying an ``artiq_embedded`` attribute
        whose ``.function`` is the Python function to translate.
    :returns: the typed AST node of the function definition.
    :raises ValueError: if ``function`` has no ``artiq_embedded`` attribute.
    """
    if not hasattr(function, "artiq_embedded"):
        raise ValueError("{} is not an embedded function".format(repr(function)))

    # Extract function source.
    # NOTE(review): inspect.getsource keeps the original indentation, so a
    # method or nested function probably will not parse as-is — confirm.
    embedded_function = function.artiq_embedded.function
    source_code = inspect.getsource(embedded_function)
    filename = embedded_function.__code__.co_filename
    module_name, _ = os.path.splitext(os.path.basename(filename))
    first_line = embedded_function.__code__.co_firstlineno

    # Extract function environment: module globals first, then closure
    # variables, which shadow globals exactly as they do in Python.
    host_environment = dict()
    host_environment.update(embedded_function.__globals__)
    cells = embedded_function.__closure__ or ()
    cell_names = embedded_function.__code__.co_freevars
    for cell_name, cell in zip(cell_names, cells):
        # A closure cell only *holds* the captured variable. The embedded code
        # refers to the value, so store the contents, not the cell object.
        try:
            host_environment[cell_name] = cell.cell_contents
        except ValueError:
            # Still unassigned in the enclosing scope. Python would raise
            # NameError here rather than fall back to a global, so make the
            # name unbound and let the rewriter report it.
            host_environment.pop(cell_name, None)

    # Parse.
    source_buffer = source.Buffer(source_code, filename, first_line)
    parsetree, _comments = parse_buffer(source_buffer, engine=self.engine)
    function_node = parsetree.body[0]

    # Mangle the name, since we put everything into a single module.
    function_node.name = "{}.{}".format(module_name, function_node.name)

    # Normally, LocalExtractor would populate the typing environment
    # of the module with the function name. However, since we run
    # ASTTypedRewriter on the function node directly, we need to do it
    # explicitly.
    self.globals[function_node.name] = types.TVar()

    # Memoize the function before typing it to handle recursive
    # invocations.
    self.functions[function] = function_node.name

    # Rewrite into typed form.
    asttyped_rewriter = StitchingASTTypedRewriter(
        engine=self.engine, prelude=self.prelude,
        globals=self.globals, host_environment=host_environment,
        quote_function=self._quote_function)
    return asttyped_rewriter.visit(function_node)
213
+
214
def _quote_function(self, function):
    """Return the global name under which ``function``'s translation lives.

    A function that has already been quoted is answered from
    ``self.functions``. Otherwise it is translated, and its typed definition
    is prepended to ``self.typedtree``; the position does not matter as long
    as it is before the final call. ``inference_finished`` is then cleared so
    that ``_iterate`` runs another inference pass.
    """
    # Stored names are always non-empty strings, so None means "not quoted yet".
    known_name = self.functions.get(function)
    if known_name is not None:
        return known_name

    quoted_node = self._quote_embedded_function(function)
    self.typedtree.insert(0, quoted_node)
    self.inference_finished = False
    return quoted_node.name
146
225
147
226
def stitch_call(self, function, args, kwargs):
    """Translate ``function`` and an entry-point call to it, then type-check.

    The typed definition of ``function`` is appended to ``self.typedtree``.
    It is followed by a synthesized call node built from ``args`` and
    ``kwargs``, whose exact expected form is whatever ``ASTSynthesizer.call``
    accepts. Finally ``_iterate`` runs inference and wraps the result in a
    module.
    """
    callee_node = self._quote_embedded_function(function)
    self.typedtree.append(callee_node)

    # The entry-point call gets synthesized source text of its own, so that
    # diagnostics pointing at it have something meaningful to display.
    call_synthesizer = ASTSynthesizer()
    entry_call = call_synthesizer.call(callee_node, args, kwargs)
    call_synthesizer.finalize()
    self.typedtree.append(entry_call)

    self._iterate()
0 commit comments