Skip to content

Commit

Permalink
compiler.embedding: support calling methods marked as @kernel.
Browse files Browse the repository at this point in the history
whitequark committed Aug 28, 2015

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent d0fd618 commit c21387d
Showing 3 changed files with 125 additions and 51 deletions.
160 changes: 115 additions & 45 deletions artiq/compiler/embedding.py
Original file line number Diff line number Diff line change
@@ -34,10 +34,11 @@ def retrieve(self, obj_key):
return self.forward_map[obj_key]

class ASTSynthesizer:
def __init__(self, type_map, value_map, expanded_from=None):
def __init__(self, type_map, value_map, quote_function=None, expanded_from=None):
self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.type_map, self.value_map = type_map, value_map
self.quote_function = quote_function
self.expanded_from = expanded_from

def finalize(self):
@@ -82,6 +83,10 @@ def quote(self, value):
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
elif inspect.isfunction(value) or inspect.ismethod(value):
function_name, function_type = self.quote_function(value, self.expanded_from)
return asttyped.NameT(id=function_name, ctx=None, type=function_type,
loc=self._add(repr(value)))
else:
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
@@ -155,6 +160,36 @@ def call(self, function_node, args, kwargs):
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc))

def assign_local(self, var_name, value):
    """Synthesize the assignment statement ``var_name = <quoted value>``.

    The pieces of the statement are emitted through ``self._add`` in
    textual order (name, space, ``=``, space) and the right-hand side is
    produced by ``self.quote(value)``.

    :param var_name: name of the variable being assigned
    :param value: host value to quote as the right-hand side
    :return: an ``ast.Assign`` node whose single target is a ``NameT``
        typed with the type of the quoted value
    """
    # NOTE(review): the calls below must stay in this order; each one is
    # presumably appending to the synthesized source buffer -- confirm
    # against ``_add``.
    target_loc = self._add(var_name)
    self._add(" ")
    eq_loc = self._add("=")
    self._add(" ")
    rhs = self.quote(value)

    # The target takes its type from the quoted right-hand side.
    target = asttyped.NameT(id=var_name, ctx=None, type=rhs.type,
                            loc=target_loc)

    statement_loc = target_loc.join(rhs.loc)
    return ast.Assign(targets=[target], value=rhs,
                      op_locs=[eq_loc], loc=statement_loc)

def assign_attribute(self, obj, attr_name, value):
    """Synthesize the assignment statement ``<quoted obj>.attr_name = <quoted value>``.

    The object is quoted first, then ``.``, the attribute name, a space,
    ``=`` and another space are emitted through ``self._add``, and finally
    the value is quoted.

    :param obj: host object whose attribute is being assigned
    :param attr_name: name of the attribute
    :param value: host value to quote as the right-hand side
    :return: an ``ast.Assign`` node whose single target is an
        ``AttributeT`` typed with the type of the quoted value
    """
    # NOTE(review): the calls below must stay in this order; each one is
    # presumably appending to the synthesized source buffer -- confirm
    # against ``_add``.
    obj_node = self.quote(obj)
    dot_loc = self._add(".")
    name_loc = self._add(attr_name)
    self._add(" ")
    equals_loc = self._add("=")
    self._add(" ")
    value_node = self.quote(value)

    attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None,
                                    type=value_node.type,
                                    dot_loc=dot_loc, attr_loc=name_loc,
                                    loc=obj_node.loc.join(name_loc))

    # The statement spans from the start of the target (the object
    # expression), not merely from the attribute name, up to the end of
    # the value; this mirrors assign_local, where the statement location
    # starts at the beginning of the target.
    return ast.Assign(targets=[attr_node], value=value_node,
                      op_locs=[equals_loc],
                      loc=attr_node.loc.join(value_node.loc))

class StitchingASTTypedRewriter(ASTTypedRewriter):
def __init__(self, engine, prelude, globals, host_environment, quote):
super().__init__(engine, prelude)
@@ -221,7 +256,20 @@ def visit_AttributeT(self, node):
# overhead (i.e. synthesizing a source buffer), but has the advantage
# of having the host-to-ARTIQ mapping code in only one place and
# also immediately getting proper diagnostics on type errors.
ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from)
attr_value = getattr(object_value, node.attr)
if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded')
and types.is_instance(object_type)):
# In cases like:
# class c:
# @kernel
# def f(self): pass
# we want f to be defined on the class, not on the instance.
attributes = object_type.constructor.attributes
attr_value = attr_value.__func__
else:
attributes = object_type.attributes

ast = self.quote(attr_value, None)

def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note",
@@ -238,17 +286,17 @@ def proxy_diagnostic(diag):
Inferencer(engine=proxy_engine).visit(ast)
IntMonomorphizer(engine=proxy_engine).visit(ast)

if node.attr not in object_type.attributes:
if node.attr not in attributes:
# We just figured out what the type should be. Add it.
object_type.attributes[node.attr] = ast.type
elif object_type.attributes[node.attr] != ast.type:
attributes[node.attr] = ast.type
elif attributes[node.attr] != ast.type:
# Does this conflict with an earlier guess?
printer = types.TypePrinter()
diag = diagnostic.Diagnostic("error",
"host object has an attribute of type {typea}, which is"
" different from previously inferred type {typeb}",
{"typea": printer.name(ast.type),
"typeb": printer.name(object_type.attributes[node.attr])},
"typeb": printer.name(attributes[node.attr])},
object_loc)
self.engine.process(diag)

@@ -261,11 +309,9 @@ def freeze(obj):
return self.visit(obj)
elif isinstance(obj, types.Type):
return hash(obj.find())
elif isinstance(obj, list):
return tuple(obj)
else:
assert obj is None or isinstance(obj, (bool, int, float, str))
return obj
# We don't care; only types change during inference.
pass

fields = node._fields
if hasattr(node, '_types'):
@@ -281,6 +327,7 @@ def __init__(self, engine=None):

self.name = ""
self.typedtree = []
self.inject_at = 0
self.prelude = prelude.globals()
self.globals = {}

@@ -290,6 +337,17 @@ def __init__(self, engine=None):
self.type_map = {}
self.value_map = defaultdict(lambda: [])

def stitch_call(self, function, args, kwargs):
    """Append the embedded ``function`` and a synthesized call to it.

    The typed node obtained from ``self._quote_embedded_function`` is
    appended to ``self.typedtree`` first, followed by a call node built
    by a fresh synthesizer from ``args`` and ``kwargs``.

    :param function: the embedded function that is the entry point
    :param args: positional arguments of the initial call
    :param kwargs: keyword arguments of the initial call
    """
    entry_node = self._quote_embedded_function(function)
    self.typedtree.append(entry_node)

    # Source code is synthesized for the initial call so that diagnostics
    # have something meaningful to display to the user.
    synth = self._synthesizer()
    initial_call = synth.call(entry_node, args, kwargs)
    synth.finalize()
    self.typedtree.append(initial_call)

def finalize(self):
inferencer = StitchingInferencer(engine=self.engine,
value_map=self.value_map,
@@ -306,12 +364,50 @@ def finalize(self):
break
old_typedtree_hash = typedtree_hash

# For every host class we embed, add an appropriate constructor
# as a global. This is necessary for method lookup, which uses
# the getconstructor instruction.
for instance_type, constructor_type in list(self.type_map.values()):
# Do we have any direct reference to a constructor?
if len(self.value_map[constructor_type]) > 0:
# Yes, use it.
constructor, _constructor_loc = self.value_map[constructor_type][0]
else:
# No, extract one from a reference to an instance.
instance, _instance_loc = self.value_map[instance_type][0]
constructor = type(instance)

self.globals[constructor_type.name] = constructor_type

synthesizer = self._synthesizer()
ast = synthesizer.assign_local(constructor_type.name, constructor)
synthesizer.finalize()
self._inject(ast)

for attr in constructor_type.attributes:
if types.is_function(constructor_type.attributes[attr]):
synthesizer = self._synthesizer()
ast = synthesizer.assign_attribute(constructor, attr,
getattr(constructor, attr))
synthesizer.finalize()
self._inject(ast)

# After we have found all functions, synthesize a module to hold them.
source_buffer = source.Buffer("", "<synthesized>")
self.typedtree = asttyped.ModuleT(
typing_env=self.globals, globals_in_scope=set(),
body=self.typedtree, loc=source.Range(source_buffer, 0, 0))

def _inject(self, node):
    """Insert ``node`` into ``self.typedtree`` at the injection point.

    The injection index ``self.inject_at`` is advanced by one afterwards,
    so consecutive injections keep their relative order.

    :param node: the node to insert
    """
    position = self.inject_at
    self.typedtree.insert(position, node)
    self.inject_at = position + 1

def _synthesizer(self, expanded_from=None):
    """Create an ``ASTSynthesizer`` wired to this object's state.

    The synthesizer receives this object's ``type_map`` and ``value_map``
    and uses ``self._quote_function`` as its function-quoting callback.

    :param expanded_from: passed through unchanged to the synthesizer
    :return: a new ``ASTSynthesizer`` instance
    """
    options = dict(
        type_map=self.type_map,
        value_map=self.value_map,
        quote_function=self._quote_function,
        expanded_from=expanded_from,
    )
    return ASTSynthesizer(**options)

def _quote_embedded_function(self, function):
if not hasattr(function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(function)))
@@ -414,10 +510,7 @@ def _type_of_param(self, function, loc, param, is_syscall):
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
synthesizer = ASTSynthesizer(type_map=self.type_map,
value_map=self.value_map)
ast = synthesizer.quote(param.default)
synthesizer.finalize()
ast = self._quote(param.default, None)

def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note",
@@ -499,20 +592,21 @@ def _quote_foreign_function(self, function, loc, syscall):
self.globals[function_name] = function_type
self.functions[function] = function_name

return function_name
return function_name, function_type

def _quote_function(self, function, loc):
if function in self.functions:
return self.functions[function]
function_name = self.functions[function]
return function_name, self.globals[function_name]

if hasattr(function, "artiq_embedded"):
if function.artiq_embedded.function is not None:
# Insert the typed AST for the new function and restart inference.
# It doesn't really matter where we insert as long as it is before
# the final call.
function_node = self._quote_embedded_function(function)
self.typedtree.insert(0, function_node)
return function_node.name
self._inject(function_node)
return function_node.name, self.globals[function_node.name]
elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call.
@@ -527,31 +621,7 @@ def _quote_function(self, function, loc):
syscall=None)

def _quote(self, value, loc):
if inspect.isfunction(value) or inspect.ismethod(value):
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self._quote_function(value, loc)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=loc)

else:
# It's just a value. Quote it.
synthesizer = ASTSynthesizer(expanded_from=loc,
type_map=self.type_map,
value_map=self.value_map)
node = synthesizer.quote(value)
synthesizer.finalize()
return node

def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)

# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = ASTSynthesizer(type_map=self.type_map,
value_map=self.value_map)
call_node = synthesizer.call(function_node, args, kwargs)
synthesizer = self._synthesizer(loc)
node = synthesizer.quote(value)
synthesizer.finalize()
self.typedtree.append(call_node)
return node
6 changes: 4 additions & 2 deletions artiq/compiler/transforms/inferencer.py
Original file line number Diff line number Diff line change
@@ -127,8 +127,10 @@ def makenotes(printer, typea, typeb, loca, locb):
when=" while inferring the type for self argument")

attr_type = types.TMethod(object_type, attr_type)
self._unify(node.type, attr_type,
node.loc, None)

if not types.is_var(attr_type):
self._unify(node.type, attr_type,
node.loc, None)
else:
if node.attr_loc.source_buffer == node.value.loc.source_buffer:
highlights, notes = [node.value.loc], []
10 changes: 6 additions & 4 deletions artiq/compiler/transforms/llvm_ir_generator.py
Original file line number Diff line number Diff line change
@@ -266,7 +266,7 @@ def llty_of_type(self, typ, bare=False, for_return=False):
elif types.is_constructor(typ):
name = "class.{}".format(typ.name)
else:
name = typ.name
name = "instance.{}".format(typ.name)

llty = self.llcontext.get_identified_type(name)
if llty.elements is None:
@@ -991,7 +991,7 @@ def _quote(self, value, typ, path):
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
lambda: path() + [attr]))

llvalue = ll.Constant.literal_struct(llfields)
llvalue = ll.Constant(llty.pointee, llfields)
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
llconst.initializer = llvalue
llconst.linkage = "private"
@@ -1012,8 +1012,10 @@ def _quote(self, value, typ, path):
elif builtins.is_str(typ):
assert isinstance(value, (str, bytes))
return self.llstr_of_str(value)
elif types.is_rpc_function(typ):
return ll.Constant.literal_struct([])
elif types.is_function(typ):
# RPC and C functions have no runtime representation; ARTIQ
# functions are initialized explicitly.
return ll.Constant(llty, ll.Undefined)
else:
print(typ)
assert False

0 comments on commit c21387d

Please sign in to comment.