Skip to content

Commit c21387d

Browse files
author
whitequark
committed Aug 28, 2015
compiler.embedding: support calling methods marked as @kernel.
1 parent d0fd618 commit c21387d

File tree

3 files changed

+125
-51
lines changed

3 files changed

+125
-51
lines changed
 

Diff for: ‎artiq/compiler/embedding.py

+115-45
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,11 @@ def retrieve(self, obj_key):
3434
return self.forward_map[obj_key]
3535

3636
class ASTSynthesizer:
37-
def __init__(self, type_map, value_map, expanded_from=None):
37+
def __init__(self, type_map, value_map, quote_function=None, expanded_from=None):
3838
self.source = ""
3939
self.source_buffer = source.Buffer(self.source, "<synthesized>")
4040
self.type_map, self.value_map = type_map, value_map
41+
self.quote_function = quote_function
4142
self.expanded_from = expanded_from
4243

4344
def finalize(self):
@@ -82,6 +83,10 @@ def quote(self, value):
8283
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
8384
begin_loc=begin_loc, end_loc=end_loc,
8485
loc=begin_loc.join(end_loc))
86+
elif inspect.isfunction(value) or inspect.ismethod(value):
87+
function_name, function_type = self.quote_function(value, self.expanded_from)
88+
return asttyped.NameT(id=function_name, ctx=None, type=function_type,
89+
loc=self._add(repr(value)))
8590
else:
8691
quote_loc = self._add('`')
8792
repr_loc = self._add(repr(value))
@@ -155,6 +160,36 @@ def call(self, function_node, args, kwargs):
155160
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
156161
loc=name_loc.join(end_loc))
157162

163+
def assign_local(self, var_name, value):
    """Synthesize the statement ``<var_name> = <value>`` and return its AST.

    Every textual piece of the statement is passed to ``self._add``, whose
    return value is used as the location of the corresponding node;
    ``value`` is converted to a node with ``self.quote``. The target name
    node is given the same type as the quoted value.

    Returns an ``ast.Assign`` whose location spans from the name to the end
    of the quoted value.
    """
    # The pieces are registered in textual order: name, " = ", value.
    target_loc = self._add(var_name)
    self._add(" ")
    assign_op_loc = self._add("=")
    self._add(" ")
    quoted = self.quote(value)

    target = asttyped.NameT(id=var_name, ctx=None, type=quoted.type,
                            loc=target_loc)

    statement_loc = target_loc.join(quoted.loc)
    return ast.Assign(targets=[target], value=quoted,
                      op_locs=[assign_op_loc], loc=statement_loc)
175+
176+
def assign_attribute(self, obj, attr_name, value):
    """Synthesize the statement ``<obj>.<attr_name> = <value>`` and return its AST.

    ``obj`` and ``value`` are converted to nodes with ``self.quote``; the
    remaining textual pieces (".", the attribute name, " = ") are passed to
    ``self._add``, whose return values are used as node locations. The
    attribute node is given the same type as the quoted value.

    Returns an ``ast.Assign`` whose location spans the whole statement,
    from the start of the quoted object to the end of the quoted value.
    """
    # The pieces are registered in textual order: object, ".", name, " = ", value.
    obj_node = self.quote(obj)
    dot_loc = self._add(".")
    name_loc = self._add(attr_name)
    self._add(" ")
    equals_loc = self._add("=")
    self._add(" ")
    value_node = self.quote(value)

    target_loc = obj_node.loc.join(name_loc)
    attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None,
                                    type=value_node.type,
                                    dot_loc=dot_loc, attr_loc=name_loc,
                                    loc=target_loc)

    # The statement begins at the quoted object rather than at the attribute
    # name, so the assignment's range is anchored on the whole target;
    # joining from name_loc would leave the "obj." prefix out of the range.
    return ast.Assign(targets=[attr_node], value=value_node,
                      op_locs=[equals_loc], loc=target_loc.join(value_node.loc))
192+
158193
class StitchingASTTypedRewriter(ASTTypedRewriter):
159194
def __init__(self, engine, prelude, globals, host_environment, quote):
160195
super().__init__(engine, prelude)
@@ -221,7 +256,20 @@ def visit_AttributeT(self, node):
221256
# overhead (i.e. synthesizing a source buffer), but has the advantage
222257
# of having the host-to-ARTIQ mapping code in only one place and
223258
# also immediately getting proper diagnostics on type errors.
224-
ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from)
259+
attr_value = getattr(object_value, node.attr)
260+
if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded')
261+
and types.is_instance(object_type)):
262+
# In cases like:
263+
# class c:
264+
# @kernel
265+
# def f(self): pass
266+
# we want f to be defined on the class, not on the instance.
267+
attributes = object_type.constructor.attributes
268+
attr_value = attr_value.__func__
269+
else:
270+
attributes = object_type.attributes
271+
272+
ast = self.quote(attr_value, None)
225273

226274
def proxy_diagnostic(diag):
227275
note = diagnostic.Diagnostic("note",
@@ -238,17 +286,17 @@ def proxy_diagnostic(diag):
238286
Inferencer(engine=proxy_engine).visit(ast)
239287
IntMonomorphizer(engine=proxy_engine).visit(ast)
240288

241-
if node.attr not in object_type.attributes:
289+
if node.attr not in attributes:
242290
# We just figured out what the type should be. Add it.
243-
object_type.attributes[node.attr] = ast.type
244-
elif object_type.attributes[node.attr] != ast.type:
291+
attributes[node.attr] = ast.type
292+
elif attributes[node.attr] != ast.type:
245293
# Does this conflict with an earlier guess?
246294
printer = types.TypePrinter()
247295
diag = diagnostic.Diagnostic("error",
248296
"host object has an attribute of type {typea}, which is"
249297
" different from previously inferred type {typeb}",
250298
{"typea": printer.name(ast.type),
251-
"typeb": printer.name(object_type.attributes[node.attr])},
299+
"typeb": printer.name(attributes[node.attr])},
252300
object_loc)
253301
self.engine.process(diag)
254302

@@ -261,11 +309,9 @@ def freeze(obj):
261309
return self.visit(obj)
262310
elif isinstance(obj, types.Type):
263311
return hash(obj.find())
264-
elif isinstance(obj, list):
265-
return tuple(obj)
266312
else:
267-
assert obj is None or isinstance(obj, (bool, int, float, str))
268-
return obj
313+
# We don't care; only types change during inference.
314+
pass
269315

270316
fields = node._fields
271317
if hasattr(node, '_types'):
@@ -281,6 +327,7 @@ def __init__(self, engine=None):
281327

282328
self.name = ""
283329
self.typedtree = []
330+
self.inject_at = 0
284331
self.prelude = prelude.globals()
285332
self.globals = {}
286333

@@ -290,6 +337,17 @@ def __init__(self, engine=None):
290337
self.type_map = {}
291338
self.value_map = defaultdict(lambda: [])
292339

340+
def stitch_call(self, function, args, kwargs):
    """Append an embedded function and a synthesized call to it to the typed tree.

    ``function`` is quoted with ``self._quote_embedded_function`` and the
    resulting node is appended to ``self.typedtree``; a call node built from
    that node, ``args`` and ``kwargs`` is appended after it.
    """
    entry_point = self._quote_embedded_function(function)
    self.typedtree.append(entry_point)

    # Source code is synthesized for the initial call so that diagnostics
    # have something meaningful to display to the user.
    call_synthesizer = self._synthesizer()
    invocation = call_synthesizer.call(entry_point, args, kwargs)
    call_synthesizer.finalize()
    self.typedtree.append(invocation)
350+
293351
def finalize(self):
294352
inferencer = StitchingInferencer(engine=self.engine,
295353
value_map=self.value_map,
@@ -306,12 +364,50 @@ def finalize(self):
306364
break
307365
old_typedtree_hash = typedtree_hash
308366

367+
# For every host class we embed, add an appropriate constructor
368+
# as a global. This is necessary for method lookup, which uses
369+
# the getconstructor instruction.
370+
for instance_type, constructor_type in list(self.type_map.values()):
371+
# Do we have any direct reference to a constructor?
372+
if len(self.value_map[constructor_type]) > 0:
373+
# Yes, use it.
374+
constructor, _constructor_loc = self.value_map[constructor_type][0]
375+
else:
376+
# No, extract one from a reference to an instance.
377+
instance, _instance_loc = self.value_map[instance_type][0]
378+
constructor = type(instance)
379+
380+
self.globals[constructor_type.name] = constructor_type
381+
382+
synthesizer = self._synthesizer()
383+
ast = synthesizer.assign_local(constructor_type.name, constructor)
384+
synthesizer.finalize()
385+
self._inject(ast)
386+
387+
for attr in constructor_type.attributes:
388+
if types.is_function(constructor_type.attributes[attr]):
389+
synthesizer = self._synthesizer()
390+
ast = synthesizer.assign_attribute(constructor, attr,
391+
getattr(constructor, attr))
392+
synthesizer.finalize()
393+
self._inject(ast)
394+
309395
# After we have found all functions, synthesize a module to hold them.
310396
source_buffer = source.Buffer("", "<synthesized>")
311397
self.typedtree = asttyped.ModuleT(
312398
typing_env=self.globals, globals_in_scope=set(),
313399
body=self.typedtree, loc=source.Range(source_buffer, 0, 0))
314400

401+
def _inject(self, node):
402+
self.typedtree.insert(self.inject_at, node)
403+
self.inject_at += 1
404+
405+
def _synthesizer(self, expanded_from=None):
    """Create an ``ASTSynthesizer`` wired to this object's maps.

    The synthesizer receives ``self.type_map`` and ``self.value_map``, uses
    ``self._quote_function`` as its ``quote_function`` callback, and is
    passed ``expanded_from`` unchanged.
    """
    synthesizer = ASTSynthesizer(type_map=self.type_map,
                                 value_map=self.value_map,
                                 quote_function=self._quote_function,
                                 expanded_from=expanded_from)
    return synthesizer
410+
315411
def _quote_embedded_function(self, function):
316412
if not hasattr(function, "artiq_embedded"):
317413
raise ValueError("{} is not an embedded function".format(repr(function)))
@@ -414,10 +510,7 @@ def _type_of_param(self, function, loc, param, is_syscall):
414510
# This is tricky, because the default value might not have
415511
# a well-defined type in APython.
416512
# In this case, we bail out, but mention why we do it.
417-
synthesizer = ASTSynthesizer(type_map=self.type_map,
418-
value_map=self.value_map)
419-
ast = synthesizer.quote(param.default)
420-
synthesizer.finalize()
513+
ast = self._quote(param.default, None)
421514

422515
def proxy_diagnostic(diag):
423516
note = diagnostic.Diagnostic("note",
@@ -499,20 +592,21 @@ def _quote_foreign_function(self, function, loc, syscall):
499592
self.globals[function_name] = function_type
500593
self.functions[function] = function_name
501594

502-
return function_name
595+
return function_name, function_type
503596

504597
def _quote_function(self, function, loc):
505598
if function in self.functions:
506-
return self.functions[function]
599+
function_name = self.functions[function]
600+
return function_name, self.globals[function_name]
507601

508602
if hasattr(function, "artiq_embedded"):
509603
if function.artiq_embedded.function is not None:
510604
# Insert the typed AST for the new function and restart inference.
511605
# It doesn't really matter where we insert as long as it is before
512606
# the final call.
513607
function_node = self._quote_embedded_function(function)
514-
self.typedtree.insert(0, function_node)
515-
return function_node.name
608+
self._inject(function_node)
609+
return function_node.name, self.globals[function_node.name]
516610
elif function.artiq_embedded.syscall is not None:
517611
# Insert a storage-less global whose type instructs the compiler
518612
# to perform a system call instead of a regular call.
@@ -527,31 +621,7 @@ def _quote_function(self, function, loc):
527621
syscall=None)
528622

529623
def _quote(self, value, loc):
530-
if inspect.isfunction(value) or inspect.ismethod(value):
531-
# It's a function. We need to translate the function and insert
532-
# a reference to it.
533-
function_name = self._quote_function(value, loc)
534-
return asttyped.NameT(id=function_name, ctx=None,
535-
type=self.globals[function_name],
536-
loc=loc)
537-
538-
else:
539-
# It's just a value. Quote it.
540-
synthesizer = ASTSynthesizer(expanded_from=loc,
541-
type_map=self.type_map,
542-
value_map=self.value_map)
543-
node = synthesizer.quote(value)
544-
synthesizer.finalize()
545-
return node
546-
547-
def stitch_call(self, function, args, kwargs):
548-
function_node = self._quote_embedded_function(function)
549-
self.typedtree.append(function_node)
550-
551-
# We synthesize source code for the initial call so that
552-
# diagnostics would have something meaningful to display to the user.
553-
synthesizer = ASTSynthesizer(type_map=self.type_map,
554-
value_map=self.value_map)
555-
call_node = synthesizer.call(function_node, args, kwargs)
624+
synthesizer = self._synthesizer(loc)
625+
node = synthesizer.quote(value)
556626
synthesizer.finalize()
557-
self.typedtree.append(call_node)
627+
return node

Diff for: ‎artiq/compiler/transforms/inferencer.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -127,8 +127,10 @@ def makenotes(printer, typea, typeb, loca, locb):
127127
when=" while inferring the type for self argument")
128128

129129
attr_type = types.TMethod(object_type, attr_type)
130-
self._unify(node.type, attr_type,
131-
node.loc, None)
130+
131+
if not types.is_var(attr_type):
132+
self._unify(node.type, attr_type,
133+
node.loc, None)
132134
else:
133135
if node.attr_loc.source_buffer == node.value.loc.source_buffer:
134136
highlights, notes = [node.value.loc], []

Diff for: ‎artiq/compiler/transforms/llvm_ir_generator.py

+6-4
Original file line number | Diff line number | Diff line change
@@ -266,7 +266,7 @@ def llty_of_type(self, typ, bare=False, for_return=False):
266266
elif types.is_constructor(typ):
267267
name = "class.{}".format(typ.name)
268268
else:
269-
name = typ.name
269+
name = "instance.{}".format(typ.name)
270270

271271
llty = self.llcontext.get_identified_type(name)
272272
if llty.elements is None:
@@ -991,7 +991,7 @@ def _quote(self, value, typ, path):
991991
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
992992
lambda: path() + [attr]))
993993

994-
llvalue = ll.Constant.literal_struct(llfields)
994+
llvalue = ll.Constant(llty.pointee, llfields)
995995
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
996996
llconst.initializer = llvalue
997997
llconst.linkage = "private"
@@ -1012,8 +1012,10 @@ def _quote(self, value, typ, path):
10121012
elif builtins.is_str(typ):
10131013
assert isinstance(value, (str, bytes))
10141014
return self.llstr_of_str(value)
1015-
elif types.is_rpc_function(typ):
1016-
return ll.Constant.literal_struct([])
1015+
elif types.is_function(typ):
1016+
# RPC and C functions have no runtime representation; ARTIQ
1017+
# functions are initialized explicitly.
1018+
return ll.Constant(llty, ll.Undefined)
10171019
else:
10181020
print(typ)
10191021
assert False

0 commit comments

Comments
 (0)
Please sign in to comment.