@@ -34,10 +34,11 @@ def retrieve(self, obj_key):
34
34
return self .forward_map [obj_key ]
35
35
36
36
class ASTSynthesizer :
37
def __init__(self, type_map, value_map, quote_function=None, expanded_from=None):
    """Set up an empty synthesized source buffer.

    :param type_map: type mapping shared with the caller.
    :param value_map: value mapping shared with the caller.
    :param quote_function: callable used when quoting host functions or
        methods; it is invoked as ``quote_function(value, expanded_from)``
        and must return a ``(name, type)`` pair.
    :param expanded_from: location the synthesized code is attributed to,
        or ``None``.
    """
    self.source = ""
    self.source_buffer = source.Buffer(self.source, "<synthesized>")
    self.type_map = type_map
    self.value_map = value_map
    self.quote_function = quote_function
    self.expanded_from = expanded_from
42
43
43
44
def finalize (self ):
@@ -82,6 +83,10 @@ def quote(self, value):
82
83
return asttyped .ListT (elts = elts , ctx = None , type = builtins .TList (),
83
84
begin_loc = begin_loc , end_loc = end_loc ,
84
85
loc = begin_loc .join (end_loc ))
86
+ elif inspect .isfunction (value ) or inspect .ismethod (value ):
87
+ function_name , function_type = self .quote_function (value , self .expanded_from )
88
+ return asttyped .NameT (id = function_name , ctx = None , type = function_type ,
89
+ loc = self ._add (repr (value )))
85
90
else :
86
91
quote_loc = self ._add ('`' )
87
92
repr_loc = self ._add (repr (value ))
@@ -155,6 +160,36 @@ def call(self, function_node, args, kwargs):
155
160
begin_loc = begin_loc , end_loc = end_loc , star_loc = None , dstar_loc = None ,
156
161
loc = name_loc .join (end_loc ))
157
162
163
def assign_local(self, var_name, value):
    """Synthesize the statement ``var_name = value``.

    Each piece of text is emitted through ``_add``, whose returned location
    is attached to the corresponding node so diagnostics can point at it.

    :param var_name: name of the local variable being assigned.
    :param value: host value quoted as the right-hand side.
    :returns: an ``ast.Assign`` whose single target is typed like the
        quoted value.
    """
    target_loc = self._add(var_name)
    self._add(" ")
    eq_loc = self._add("=")
    self._add(" ")
    rhs = self.quote(value)

    target = asttyped.NameT(id=var_name, ctx=None, type=rhs.type, loc=target_loc)
    return ast.Assign(targets=[target], value=rhs,
                      op_locs=[eq_loc], loc=target_loc.join(rhs.loc))
175
+
176
def assign_attribute(self, obj, attr_name, value):
    """Synthesize the statement ``obj.attr_name = value``.

    Each piece of text is emitted through ``_add``, whose returned location
    is attached to the corresponding node so diagnostics can point at it.

    :param obj: host object quoted as the base of the attribute target.
    :param attr_name: name of the attribute being assigned.
    :param value: host value quoted as the right-hand side.
    :returns: an ``ast.Assign`` whose single target is an ``AttributeT``
        typed like the quoted value.
    """
    obj_node = self.quote(obj)
    dot_loc = self._add(".")
    name_loc = self._add(attr_name)
    self._add(" ")
    equals_loc = self._add("=")
    self._add(" ")
    value_node = self.quote(value)

    attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None,
                                    type=value_node.type,
                                    dot_loc=dot_loc, attr_loc=name_loc,
                                    loc=obj_node.loc.join(name_loc))

    # The statement text starts with the quoted object, not the attribute
    # name, so the assignment's location must span the whole target.
    return ast.Assign(targets=[attr_node], value=value_node,
                      op_locs=[equals_loc], loc=attr_node.loc.join(value_node.loc))
192
+
158
193
class StitchingASTTypedRewriter (ASTTypedRewriter ):
159
194
def __init__ (self , engine , prelude , globals , host_environment , quote ):
160
195
super ().__init__ (engine , prelude )
@@ -221,7 +256,20 @@ def visit_AttributeT(self, node):
221
256
# overhead (i.e. synthesizing a source buffer), but has the advantage
222
257
# of having the host-to-ARTIQ mapping code in only one place and
223
258
# also immediately getting proper diagnostics on type errors.
224
- ast = self .quote (getattr (object_value , node .attr ), object_loc .expanded_from )
259
+ attr_value = getattr (object_value , node .attr )
260
+ if (inspect .ismethod (attr_value ) and hasattr (attr_value .__func__ , 'artiq_embedded' )
261
+ and types .is_instance (object_type )):
262
+ # In cases like:
263
+ # class c:
264
+ # @kernel
265
+ # def f(self): pass
266
+ # we want f to be defined on the class, not on the instance.
267
+ attributes = object_type .constructor .attributes
268
+ attr_value = attr_value .__func__
269
+ else :
270
+ attributes = object_type .attributes
271
+
272
+ ast = self .quote (attr_value , None )
225
273
226
274
def proxy_diagnostic (diag ):
227
275
note = diagnostic .Diagnostic ("note" ,
@@ -238,17 +286,17 @@ def proxy_diagnostic(diag):
238
286
Inferencer (engine = proxy_engine ).visit (ast )
239
287
IntMonomorphizer (engine = proxy_engine ).visit (ast )
240
288
241
- if node .attr not in object_type . attributes :
289
+ if node .attr not in attributes :
242
290
# We just figured out what the type should be. Add it.
243
- object_type . attributes [node .attr ] = ast .type
244
- elif object_type . attributes [node .attr ] != ast .type :
291
+ attributes [node .attr ] = ast .type
292
+ elif attributes [node .attr ] != ast .type :
245
293
# Does this conflict with an earlier guess?
246
294
printer = types .TypePrinter ()
247
295
diag = diagnostic .Diagnostic ("error" ,
248
296
"host object has an attribute of type {typea}, which is"
249
297
" different from previously inferred type {typeb}" ,
250
298
{"typea" : printer .name (ast .type ),
251
- "typeb" : printer .name (object_type . attributes [node .attr ])},
299
+ "typeb" : printer .name (attributes [node .attr ])},
252
300
object_loc )
253
301
self .engine .process (diag )
254
302
@@ -261,11 +309,9 @@ def freeze(obj):
261
309
return self .visit (obj )
262
310
elif isinstance (obj , types .Type ):
263
311
return hash (obj .find ())
264
- elif isinstance (obj , list ):
265
- return tuple (obj )
266
312
else :
267
- assert obj is None or isinstance ( obj , ( bool , int , float , str ))
268
- return obj
313
+ # We don't care; only types change during inference.
314
+ pass
269
315
270
316
fields = node ._fields
271
317
if hasattr (node , '_types' ):
@@ -281,6 +327,7 @@ def __init__(self, engine=None):
281
327
282
328
self .name = ""
283
329
self .typedtree = []
330
+ self .inject_at = 0
284
331
self .prelude = prelude .globals ()
285
332
self .globals = {}
286
333
@@ -290,6 +337,17 @@ def __init__(self, engine=None):
290
337
self .type_map = {}
291
338
self .value_map = defaultdict (lambda : [])
292
339
340
def stitch_call(self, function, args, kwargs):
    """Embed ``function`` and append a synthesized call to it.

    The embedded function's typed AST is appended to ``self.typedtree``,
    followed by a call node built from ``args`` and ``kwargs``.

    :param function: host function to embed (must be accepted by
        ``_quote_embedded_function``).
    :param args: positional arguments for the synthesized call.
    :param kwargs: keyword arguments for the synthesized call.
    """
    embedded = self._quote_embedded_function(function)
    self.typedtree.append(embedded)

    # The call is rendered as real source text so that diagnostics have
    # something meaningful to show the user.
    synth = self._synthesizer()
    call = synth.call(embedded, args, kwargs)
    synth.finalize()
    self.typedtree.append(call)
350
+
293
351
def finalize (self ):
294
352
inferencer = StitchingInferencer (engine = self .engine ,
295
353
value_map = self .value_map ,
@@ -306,12 +364,50 @@ def finalize(self):
306
364
break
307
365
old_typedtree_hash = typedtree_hash
308
366
367
+ # For every host class we embed, add an appropriate constructor
368
+ # as a global. This is necessary for method lookup, which uses
369
+ # the getconstructor instruction.
370
+ for instance_type , constructor_type in list (self .type_map .values ()):
371
+ # Do we have any direct reference to a constructor?
372
+ if len (self .value_map [constructor_type ]) > 0 :
373
+ # Yes, use it.
374
+ constructor , _constructor_loc = self .value_map [constructor_type ][0 ]
375
+ else :
376
+ # No, extract one from a reference to an instance.
377
+ instance , _instance_loc = self .value_map [instance_type ][0 ]
378
+ constructor = type (instance )
379
+
380
+ self .globals [constructor_type .name ] = constructor_type
381
+
382
+ synthesizer = self ._synthesizer ()
383
+ ast = synthesizer .assign_local (constructor_type .name , constructor )
384
+ synthesizer .finalize ()
385
+ self ._inject (ast )
386
+
387
+ for attr in constructor_type .attributes :
388
+ if types .is_function (constructor_type .attributes [attr ]):
389
+ synthesizer = self ._synthesizer ()
390
+ ast = synthesizer .assign_attribute (constructor , attr ,
391
+ getattr (constructor , attr ))
392
+ synthesizer .finalize ()
393
+ self ._inject (ast )
394
+
309
395
# After we have found all functions, synthesize a module to hold them.
310
396
source_buffer = source .Buffer ("" , "<synthesized>" )
311
397
self .typedtree = asttyped .ModuleT (
312
398
typing_env = self .globals , globals_in_scope = set (),
313
399
body = self .typedtree , loc = source .Range (source_buffer , 0 , 0 ))
314
400
401
+ def _inject (self , node ):
402
+ self .typedtree .insert (self .inject_at , node )
403
+ self .inject_at += 1
404
+
405
def _synthesizer(self, expanded_from=None):
    """Create an ``ASTSynthesizer`` that shares this object's mappings.

    :param expanded_from: location the synthesized code is attributed to,
        or ``None``.
    :returns: a fresh synthesizer using ``self.type_map`` and
        ``self.value_map``, which quotes host functions through
        ``self._quote_function``.
    """
    return ASTSynthesizer(self.type_map, self.value_map,
                          quote_function=self._quote_function,
                          expanded_from=expanded_from)
410
+
315
411
def _quote_embedded_function (self , function ):
316
412
if not hasattr (function , "artiq_embedded" ):
317
413
raise ValueError ("{} is not an embedded function" .format (repr (function )))
@@ -414,10 +510,7 @@ def _type_of_param(self, function, loc, param, is_syscall):
414
510
# This is tricky, because the default value might not have
415
511
# a well-defined type in APython.
416
512
# In this case, we bail out, but mention why we do it.
417
- synthesizer = ASTSynthesizer (type_map = self .type_map ,
418
- value_map = self .value_map )
419
- ast = synthesizer .quote (param .default )
420
- synthesizer .finalize ()
513
+ ast = self ._quote (param .default , None )
421
514
422
515
def proxy_diagnostic (diag ):
423
516
note = diagnostic .Diagnostic ("note" ,
@@ -499,20 +592,21 @@ def _quote_foreign_function(self, function, loc, syscall):
499
592
self .globals [function_name ] = function_type
500
593
self .functions [function ] = function_name
501
594
502
- return function_name
595
+ return function_name , function_type
503
596
504
597
def _quote_function (self , function , loc ):
505
598
if function in self .functions :
506
- return self .functions [function ]
599
+ function_name = self .functions [function ]
600
+ return function_name , self .globals [function_name ]
507
601
508
602
if hasattr (function , "artiq_embedded" ):
509
603
if function .artiq_embedded .function is not None :
510
604
# Insert the typed AST for the new function and restart inference.
511
605
# It doesn't really matter where we insert as long as it is before
512
606
# the final call.
513
607
function_node = self ._quote_embedded_function (function )
514
- self .typedtree . insert ( 0 , function_node )
515
- return function_node .name
608
+ self ._inject ( function_node )
609
+ return function_node .name , self . globals [ function_node . name ]
516
610
elif function .artiq_embedded .syscall is not None :
517
611
# Insert a storage-less global whose type instructs the compiler
518
612
# to perform a system call instead of a regular call.
@@ -527,31 +621,7 @@ def _quote_function(self, function, loc):
527
621
syscall = None )
528
622
529
623
def _quote (self , value , loc ):
530
- if inspect .isfunction (value ) or inspect .ismethod (value ):
531
- # It's a function. We need to translate the function and insert
532
- # a reference to it.
533
- function_name = self ._quote_function (value , loc )
534
- return asttyped .NameT (id = function_name , ctx = None ,
535
- type = self .globals [function_name ],
536
- loc = loc )
537
-
538
- else :
539
- # It's just a value. Quote it.
540
- synthesizer = ASTSynthesizer (expanded_from = loc ,
541
- type_map = self .type_map ,
542
- value_map = self .value_map )
543
- node = synthesizer .quote (value )
544
- synthesizer .finalize ()
545
- return node
546
-
547
- def stitch_call (self , function , args , kwargs ):
548
- function_node = self ._quote_embedded_function (function )
549
- self .typedtree .append (function_node )
550
-
551
- # We synthesize source code for the initial call so that
552
- # diagnostics would have something meaningful to display to the user.
553
- synthesizer = ASTSynthesizer (type_map = self .type_map ,
554
- value_map = self .value_map )
555
- call_node = synthesizer .call (function_node , args , kwargs )
624
+ synthesizer = self ._synthesizer (loc )
625
+ node = synthesizer .quote (value )
556
626
synthesizer .finalize ()
557
- self . typedtree . append ( call_node )
627
+ return node
0 commit comments