6
6
"""
7
7
8
8
import os , re , linecache , inspect
9
- from collections import OrderedDict
9
+ from collections import OrderedDict , defaultdict
10
10
11
11
from pythonparser import ast , source , diagnostic , parse_buffer
12
12
13
13
from . import types , builtins , asttyped , prelude
14
14
from .transforms import ASTTypedRewriter , Inferencer , IntMonomorphizer
15
+ from .validators import MonomorphismValidator
15
16
16
17
17
18
class ObjectMap :
@@ -34,10 +35,11 @@ def retrieve(self, obj_key):
34
35
return self .forward_map [obj_key ]
35
36
36
37
class ASTSynthesizer :
37
- def __init__ (self , type_map , expanded_from = None ):
38
+ def __init__ (self , type_map , value_map , expanded_from = None ):
38
39
self .source = ""
39
40
self .source_buffer = source .Buffer (self .source , "<synthesized>" )
40
- self .type_map , self .expanded_from = type_map , expanded_from
41
+ self .type_map , self .value_map = type_map , value_map
42
+ self .expanded_from = expanded_from
41
43
42
44
def finalize (self ):
43
45
self .source_buffer .source = self .source
@@ -82,6 +84,11 @@ def quote(self, value):
82
84
begin_loc = begin_loc , end_loc = end_loc ,
83
85
loc = begin_loc .join (end_loc ))
84
86
else :
87
+ quote_loc = self ._add ('`' )
88
+ repr_loc = self ._add (repr (value ))
89
+ unquote_loc = self ._add ('`' )
90
+ loc = quote_loc .join (unquote_loc )
91
+
85
92
if isinstance (value , type ):
86
93
typ = value
87
94
else :
@@ -98,16 +105,14 @@ def quote(self, value):
98
105
99
106
self .type_map [typ ] = instance_type , constructor_type
100
107
101
- quote_loc = self ._add ('`' )
102
- repr_loc = self ._add (repr (value ))
103
- unquote_loc = self ._add ('`' )
104
-
105
108
if isinstance (value , type ):
109
+ self .value_map [constructor_type ].append ((value , loc ))
106
110
return asttyped .QuoteT (value = value , type = constructor_type ,
107
- loc = quote_loc . join ( unquote_loc ) )
111
+ loc = loc )
108
112
else :
113
+ self .value_map [instance_type ].append ((value , loc ))
109
114
return asttyped .QuoteT (value = value , type = instance_type ,
110
- loc = quote_loc . join ( unquote_loc ) )
115
+ loc = loc )
111
116
112
117
def call (self , function_node , args , kwargs ):
113
118
"""
@@ -151,14 +156,16 @@ def call(self, function_node, args, kwargs):
151
156
loc = name_loc .join (end_loc ))
152
157
153
158
class StitchingASTTypedRewriter (ASTTypedRewriter ):
154
- def __init__ (self , engine , prelude , globals , host_environment , quote_function , type_map ):
159
+ def __init__ (self , engine , prelude , globals , host_environment , quote_function ,
160
+ type_map , value_map ):
155
161
super ().__init__ (engine , prelude )
156
162
self .globals = globals
157
163
self .env_stack .append (self .globals )
158
164
159
165
self .host_environment = host_environment
160
166
self .quote_function = quote_function
161
167
self .type_map = type_map
168
+ self .value_map = value_map
162
169
163
170
def visit_Name (self , node ):
164
171
typ = super ()._try_find_name (node .id )
@@ -180,7 +187,9 @@ def visit_Name(self, node):
180
187
181
188
else :
182
189
# It's just a value. Quote it.
183
- synthesizer = ASTSynthesizer (expanded_from = node .loc , type_map = self .type_map )
190
+ synthesizer = ASTSynthesizer (expanded_from = node .loc ,
191
+ type_map = self .type_map ,
192
+ value_map = self .value_map )
184
193
node = synthesizer .quote (value )
185
194
synthesizer .finalize ()
186
195
return node
@@ -190,6 +199,83 @@ def visit_Name(self, node):
190
199
node .loc )
191
200
self .engine .process (diag )
192
201
202
class StitchingInferencer(Inferencer):
    """
    Type inferencer that additionally infers the types of attributes of
    host (interpreter-side) objects that were quoted into the typed AST.

    :param engine: diagnostic engine that receives every error and note.
    :param type_map: mapping from a host class to its
        ``(instance_type, constructor_type)`` pair; shared with
        :class:`ASTSynthesizer`, which populates it while quoting.
    :param value_map: mapping from an ARTIQ type to a list of
        ``(host_value, loc)`` pairs, one per quoted host object of that
        type; shared with :class:`ASTSynthesizer`, which appends to it.
    """

    def __init__(self, engine, type_map, value_map):
        super().__init__(engine)
        self.type_map, self.value_map = type_map, value_map

    def visit_AttributeT(self, node):
        self.generic_visit(node)
        object_type = node.value.type.find()

        # The inferencer can only observe types, not values; however,
        # when we work with host objects, we have to get the values
        # somewhere, since host interpreter does not have types.
        # Since we have categorized every host object we quoted according to
        # its type, we now interrogate every host object we have to ensure
        # that we can successfully serialize the value of the attribute we
        # are now adding at the code generation stage.
        #
        # FIXME: We perform exhaustive checks of every known host object every
        # time an attribute access is visited, which is potentially quadratic.
        # This is done because it is simpler than performing the checks only when:
        #   * a previously unknown attribute is encountered,
        #   * a previously unknown host object is encountered;
        # which would be the optimal solution.
        #
        # Iterate over a snapshot: quoting an attribute value (see
        # _quote_host_attribute) appends that value to value_map, and when it
        # is a host object of this same type (e.g. a `next`/`parent` link) it
        # would land in the very list being iterated; a cyclic object graph
        # would then never terminate. Objects registered during this visit
        # are checked on a later visit. Using .get() also avoids inserting
        # empty entries into value_map (a defaultdict in Stitcher) for types
        # that have no host objects at all.
        host_objects = list(self.value_map.get(object_type, ()))

        # Inference is iterated to a fixed point and each visit re-quotes
        # attribute values, so one host object can be registered many times.
        # Check each object once per visit to avoid redundant quoting and
        # duplicate diagnostics. (The snapshot keeps the objects alive, so
        # their ids are stable for the duration of this loop.)
        seen_ids = set()
        for object_value, object_loc in host_objects:
            if id(object_value) in seen_ids:
                continue
            seen_ids.add(id(object_value))

            if not hasattr(object_value, node.attr):
                note = diagnostic.Diagnostic("note",
                    "attribute accessed here", {},
                    node.loc)
                diag = diagnostic.Diagnostic("error",
                    "host object does not have an attribute '{attr}'",
                    {"attr": node.attr},
                    object_loc, notes=[note])
                self.engine.process(diag)
                # Without the attribute there is no type to infer; the base
                # visitor is intentionally not invoked in this case.
                return

            attr_node = self._quote_host_attribute(node, object_value)
            self._record_attribute_type(node, object_type, attr_node.type,
                                        object_loc)

        super().visit_AttributeT(node)

    def _quote_host_attribute(self, node, object_value):
        """
        Quote the value of attribute ``node.attr`` of host object
        ``object_value`` and infer its ARTIQ type.

        Quoting reuses the host-to-ARTIQ mapping of :class:`ASTSynthesizer`
        (at the cost of synthesizing a source buffer), so that mapping lives
        in one place and type errors immediately get proper diagnostics.
        Diagnostics produced while typing the quoted value are forwarded to
        ``self.engine`` with an extra note pointing at ``node``.

        :return: the typed AST node of the quoted attribute value.
        """
        synthesizer = ASTSynthesizer(type_map=self.type_map,
                                     value_map=self.value_map)
        # Named attr_node (not `ast`) to avoid shadowing the module import.
        attr_node = synthesizer.quote(getattr(object_value, node.attr))
        synthesizer.finalize()

        def proxy_diagnostic(diag):
            note = diagnostic.Diagnostic("note",
                "expanded from here while trying to infer a type for an"
                " attribute '{attr}' of a host object",
                {"attr": node.attr},
                node.loc)
            diag.notes.append(note)

            self.engine.process(diag)

        proxy_engine = diagnostic.Engine()
        proxy_engine.process = proxy_diagnostic
        Inferencer(engine=proxy_engine).visit(attr_node)
        IntMonomorphizer(engine=proxy_engine).visit(attr_node)
        MonomorphismValidator(engine=proxy_engine).visit(attr_node)
        return attr_node

    def _record_attribute_type(self, node, object_type, attr_type, object_loc):
        """
        Record ``attr_type`` as the type of attribute ``node.attr`` of
        ``object_type``; if a different type was recorded earlier, report
        an error located at the host object (``object_loc``) instead.
        """
        attributes = object_type.attributes
        if node.attr not in attributes:
            # We just figured out what the type should be. Add it.
            attributes[node.attr] = attr_type
        elif attributes[node.attr] != attr_type:
            # Does this conflict with an earlier guess?
            printer = types.TypePrinter()
            diag = diagnostic.Diagnostic("error",
                "host object has an attribute of type {typea}, which is"
                " different from previously inferred type {typeb}",
                {"typea": printer.name(attr_type),
                 "typeb": printer.name(attributes[node.attr])},
                object_loc)
            self.engine.process(diag)
+
193
279
class Stitcher :
194
280
def __init__ (self , engine = None ):
195
281
if engine is None :
@@ -206,9 +292,11 @@ def __init__(self, engine=None):
206
292
207
293
self .object_map = ObjectMap ()
208
294
self .type_map = {}
295
+ self .value_map = defaultdict (lambda : [])
209
296
210
297
def finalize (self ):
211
- inferencer = Inferencer (engine = self .engine )
298
+ inferencer = StitchingInferencer (engine = self .engine ,
299
+ type_map = self .type_map , value_map = self .value_map )
212
300
213
301
# Iterate inference to fixed point.
214
302
self .inference_finished = False
@@ -262,7 +350,8 @@ def _quote_embedded_function(self, function):
262
350
asttyped_rewriter = StitchingASTTypedRewriter (
263
351
engine = self .engine , prelude = self .prelude ,
264
352
globals = self .globals , host_environment = host_environment ,
265
- quote_function = self ._quote_function , type_map = self .type_map )
353
+ quote_function = self ._quote_function ,
354
+ type_map = self .type_map , value_map = self .value_map )
266
355
return asttyped_rewriter .visit (function_node )
267
356
268
357
def _function_loc (self , function ):
@@ -324,7 +413,8 @@ def _type_of_param(self, function, loc, param, is_syscall):
324
413
# This is tricky, because the default value might not have
325
414
# a well-defined type in APython.
326
415
# In this case, we bail out, but mention why we do it.
327
- synthesizer = ASTSynthesizer (type_map = self .type_map )
416
+ synthesizer = ASTSynthesizer (type_map = self .type_map ,
417
+ value_map = self .value_map )
328
418
ast = synthesizer .quote (param .default )
329
419
synthesizer .finalize ()
330
420
@@ -442,7 +532,8 @@ def stitch_call(self, function, args, kwargs):
442
532
443
533
# We synthesize source code for the initial call so that
444
534
# diagnostics would have something meaningful to display to the user.
445
- synthesizer = ASTSynthesizer (type_map = self .type_map )
535
+ synthesizer = ASTSynthesizer (type_map = self .type_map ,
536
+ value_map = self .value_map )
446
537
call_node = synthesizer .call (function_node , args , kwargs )
447
538
synthesizer .finalize ()
448
539
self .typedtree .append (call_node )
0 commit comments