Skip to content

Commit

Permalink
compiler.embedding: support calling methods marked as @kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
whitequark committed Aug 28, 2015
1 parent d0fd618 commit c21387d
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 51 deletions.
160 changes: 115 additions & 45 deletions artiq/compiler/embedding.py
Expand Up @@ -34,10 +34,11 @@ def retrieve(self, obj_key):
return self.forward_map[obj_key]

class ASTSynthesizer:
def __init__(self, type_map, value_map, expanded_from=None):
def __init__(self, type_map, value_map, quote_function=None, expanded_from=None):
self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.type_map, self.value_map = type_map, value_map
self.quote_function = quote_function
self.expanded_from = expanded_from

def finalize(self):
Expand Down Expand Up @@ -82,6 +83,10 @@ def quote(self, value):
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
elif inspect.isfunction(value) or inspect.ismethod(value):
function_name, function_type = self.quote_function(value, self.expanded_from)
return asttyped.NameT(id=function_name, ctx=None, type=function_type,
loc=self._add(repr(value)))
else:
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
Expand Down Expand Up @@ -155,6 +160,36 @@ def call(self, function_node, args, kwargs):
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc))

def assign_local(self, var_name, value):
    """
    Synthesize the source text ``var_name = value`` and return the
    matching ``ast.Assign`` statement node.

    :param var_name: name of the local variable being assigned.
    :param value: host value; it is quoted via ``self.quote`` and the
        resulting node's type becomes the type of the local.
    :return: ``ast.Assign`` whose single target is a ``NameT`` for
        ``var_name`` and whose location spans from the name to the end
        of the quoted value.
    """
    target_loc = self._add(var_name)
    self._add(" ")
    op_loc = self._add("=")
    self._add(" ")
    rhs = self.quote(value)

    # The local is typed after whatever the quoted value was inferred as.
    target = asttyped.NameT(id=var_name, ctx=None, type=rhs.type,
                            loc=target_loc)

    return ast.Assign(targets=[target], value=rhs,
                      op_locs=[op_loc], loc=target_loc.join(rhs.loc))

def assign_attribute(self, obj, attr_name, value):
    """
    Synthesize the source text ``obj.attr_name = value`` and return the
    matching ``ast.Assign`` statement node.

    :param obj: host object whose attribute is assigned; it is quoted via
        ``self.quote`` to form the target's base expression.
    :param attr_name: name of the attribute being assigned.
    :param value: host value; it is quoted via ``self.quote`` and the
        resulting node's type becomes the attribute's type.
    :return: ``ast.Assign`` whose single target is an ``AttributeT``. Its
        location spans the whole statement, from the start of the quoted
        ``obj`` to the end of the quoted ``value``.
    """
    obj_node = self.quote(obj)
    dot_loc = self._add(".")
    name_loc = self._add(attr_name)
    _ = self._add(" ")
    equals_loc = self._add("=")
    _ = self._add(" ")
    value_node = self.quote(value)

    attr_node = asttyped.AttributeT(value=obj_node, attr=attr_name, ctx=None,
                                    type=value_node.type,
                                    dot_loc=dot_loc, attr_loc=name_loc,
                                    loc=obj_node.loc.join(name_loc))

    # The statement starts where its target starts (the quoted object).
    # Joining from ``name_loc`` would leave the ``obj.`` prefix out of the
    # statement's range, unlike ``assign_local``, which starts at its target.
    return ast.Assign(targets=[attr_node], value=value_node,
                      op_locs=[equals_loc],
                      loc=attr_node.loc.join(value_node.loc))

class StitchingASTTypedRewriter(ASTTypedRewriter):
def __init__(self, engine, prelude, globals, host_environment, quote):
super().__init__(engine, prelude)
Expand Down Expand Up @@ -221,7 +256,20 @@ def visit_AttributeT(self, node):
# overhead (i.e. synthesizing a source buffer), but has the advantage
# of having the host-to-ARTIQ mapping code in only one place and
# also immediately getting proper diagnostics on type errors.
ast = self.quote(getattr(object_value, node.attr), object_loc.expanded_from)
attr_value = getattr(object_value, node.attr)
if (inspect.ismethod(attr_value) and hasattr(attr_value.__func__, 'artiq_embedded')
and types.is_instance(object_type)):
# In cases like:
# class c:
# @kernel
# def f(self): pass
# we want f to be defined on the class, not on the instance.
attributes = object_type.constructor.attributes
attr_value = attr_value.__func__
else:
attributes = object_type.attributes

ast = self.quote(attr_value, None)

def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note",
Expand All @@ -238,17 +286,17 @@ def proxy_diagnostic(diag):
Inferencer(engine=proxy_engine).visit(ast)
IntMonomorphizer(engine=proxy_engine).visit(ast)

if node.attr not in object_type.attributes:
if node.attr not in attributes:
# We just figured out what the type should be. Add it.
object_type.attributes[node.attr] = ast.type
elif object_type.attributes[node.attr] != ast.type:
attributes[node.attr] = ast.type
elif attributes[node.attr] != ast.type:
# Does this conflict with an earlier guess?
printer = types.TypePrinter()
diag = diagnostic.Diagnostic("error",
"host object has an attribute of type {typea}, which is"
" different from previously inferred type {typeb}",
{"typea": printer.name(ast.type),
"typeb": printer.name(object_type.attributes[node.attr])},
"typeb": printer.name(attributes[node.attr])},
object_loc)
self.engine.process(diag)

Expand All @@ -261,11 +309,9 @@ def freeze(obj):
return self.visit(obj)
elif isinstance(obj, types.Type):
return hash(obj.find())
elif isinstance(obj, list):
return tuple(obj)
else:
assert obj is None or isinstance(obj, (bool, int, float, str))
return obj
# We don't care; only types change during inference.
pass

fields = node._fields
if hasattr(node, '_types'):
Expand All @@ -281,6 +327,7 @@ def __init__(self, engine=None):

self.name = ""
self.typedtree = []
self.inject_at = 0
self.prelude = prelude.globals()
self.globals = {}

Expand All @@ -290,6 +337,17 @@ def __init__(self, engine=None):
self.type_map = {}
self.value_map = defaultdict(lambda: [])

def stitch_call(self, function, args, kwargs):
    """
    Append the typed AST of the embedded ``function`` to the typed tree,
    followed by a synthesized call to it with ``args`` and ``kwargs``.

    The call is built from synthesized source code, so diagnostics about
    the initial call have real text to point at.
    """
    quoted_function = self._quote_embedded_function(function)
    self.typedtree.append(quoted_function)

    # Source for the initial call is synthesized purely so that
    # diagnostics have something meaningful to display to the user.
    synth = self._synthesizer()
    call = synth.call(quoted_function, args, kwargs)
    synth.finalize()
    self.typedtree.append(call)

def finalize(self):
inferencer = StitchingInferencer(engine=self.engine,
value_map=self.value_map,
Expand All @@ -306,12 +364,50 @@ def finalize(self):
break
old_typedtree_hash = typedtree_hash

# For every host class we embed, add an appropriate constructor
# as a global. This is necessary for method lookup, which uses
# the getconstructor instruction.
for instance_type, constructor_type in list(self.type_map.values()):
# Do we have any direct reference to a constructor?
if len(self.value_map[constructor_type]) > 0:
# Yes, use it.
constructor, _constructor_loc = self.value_map[constructor_type][0]
else:
# No, extract one from a reference to an instance.
instance, _instance_loc = self.value_map[instance_type][0]
constructor = type(instance)

self.globals[constructor_type.name] = constructor_type

synthesizer = self._synthesizer()
ast = synthesizer.assign_local(constructor_type.name, constructor)
synthesizer.finalize()
self._inject(ast)

for attr in constructor_type.attributes:
if types.is_function(constructor_type.attributes[attr]):
synthesizer = self._synthesizer()
ast = synthesizer.assign_attribute(constructor, attr,
getattr(constructor, attr))
synthesizer.finalize()
self._inject(ast)

# After we have found all functions, synthesize a module to hold them.
source_buffer = source.Buffer("", "<synthesized>")
self.typedtree = asttyped.ModuleT(
typing_env=self.globals, globals_in_scope=set(),
body=self.typedtree, loc=source.Range(source_buffer, 0, 0))

def _inject(self, node):
    """
    Insert ``node`` into ``self.typedtree`` at index ``self.inject_at``,
    then advance that index by one.

    Nodes injected one after another therefore keep their relative order.
    """
    position = self.inject_at
    self.typedtree.insert(position, node)
    self.inject_at = position + 1

def _synthesizer(self, expanded_from=None):
    """
    Build an ``ASTSynthesizer`` for this object.

    The synthesizer shares this object's ``type_map`` and ``value_map``
    and quotes functions through ``self._quote_function``.

    :param expanded_from: passed through unchanged as the synthesizer's
        ``expanded_from``.
    """
    return ASTSynthesizer(type_map=self.type_map,
                          value_map=self.value_map,
                          quote_function=self._quote_function,
                          expanded_from=expanded_from)

def _quote_embedded_function(self, function):
if not hasattr(function, "artiq_embedded"):
raise ValueError("{} is not an embedded function".format(repr(function)))
Expand Down Expand Up @@ -414,10 +510,7 @@ def _type_of_param(self, function, loc, param, is_syscall):
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
synthesizer = ASTSynthesizer(type_map=self.type_map,
value_map=self.value_map)
ast = synthesizer.quote(param.default)
synthesizer.finalize()
ast = self._quote(param.default, None)

def proxy_diagnostic(diag):
note = diagnostic.Diagnostic("note",
Expand Down Expand Up @@ -499,20 +592,21 @@ def _quote_foreign_function(self, function, loc, syscall):
self.globals[function_name] = function_type
self.functions[function] = function_name

return function_name
return function_name, function_type

def _quote_function(self, function, loc):
if function in self.functions:
return self.functions[function]
function_name = self.functions[function]
return function_name, self.globals[function_name]

if hasattr(function, "artiq_embedded"):
if function.artiq_embedded.function is not None:
# Insert the typed AST for the new function and restart inference.
# It doesn't really matter where we insert as long as it is before
# the final call.
function_node = self._quote_embedded_function(function)
self.typedtree.insert(0, function_node)
return function_node.name
self._inject(function_node)
return function_node.name, self.globals[function_node.name]
elif function.artiq_embedded.syscall is not None:
# Insert a storage-less global whose type instructs the compiler
# to perform a system call instead of a regular call.
Expand All @@ -527,31 +621,7 @@ def _quote_function(self, function, loc):
syscall=None)

def _quote(self, value, loc):
if inspect.isfunction(value) or inspect.ismethod(value):
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self._quote_function(value, loc)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=loc)

else:
# It's just a value. Quote it.
synthesizer = ASTSynthesizer(expanded_from=loc,
type_map=self.type_map,
value_map=self.value_map)
node = synthesizer.quote(value)
synthesizer.finalize()
return node

def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)

# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = ASTSynthesizer(type_map=self.type_map,
value_map=self.value_map)
call_node = synthesizer.call(function_node, args, kwargs)
synthesizer = self._synthesizer(loc)
node = synthesizer.quote(value)
synthesizer.finalize()
self.typedtree.append(call_node)
return node
6 changes: 4 additions & 2 deletions artiq/compiler/transforms/inferencer.py
Expand Up @@ -127,8 +127,10 @@ def makenotes(printer, typea, typeb, loca, locb):
when=" while inferring the type for self argument")

attr_type = types.TMethod(object_type, attr_type)
self._unify(node.type, attr_type,
node.loc, None)

if not types.is_var(attr_type):
self._unify(node.type, attr_type,
node.loc, None)
else:
if node.attr_loc.source_buffer == node.value.loc.source_buffer:
highlights, notes = [node.value.loc], []
Expand Down
10 changes: 6 additions & 4 deletions artiq/compiler/transforms/llvm_ir_generator.py
Expand Up @@ -266,7 +266,7 @@ def llty_of_type(self, typ, bare=False, for_return=False):
elif types.is_constructor(typ):
name = "class.{}".format(typ.name)
else:
name = typ.name
name = "instance.{}".format(typ.name)

llty = self.llcontext.get_identified_type(name)
if llty.elements is None:
Expand Down Expand Up @@ -991,7 +991,7 @@ def _quote(self, value, typ, path):
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
lambda: path() + [attr]))

llvalue = ll.Constant.literal_struct(llfields)
llvalue = ll.Constant(llty.pointee, llfields)
llconst = ll.GlobalVariable(self.llmodule, llvalue.type, global_name)
llconst.initializer = llvalue
llconst.linkage = "private"
Expand All @@ -1012,8 +1012,10 @@ def _quote(self, value, typ, path):
elif builtins.is_str(typ):
assert isinstance(value, (str, bytes))
return self.llstr_of_str(value)
elif types.is_rpc_function(typ):
return ll.Constant.literal_struct([])
elif types.is_function(typ):
# RPC and C functions have no runtime representation; ARTIQ
# functions are initialized explicitly.
return ll.Constant(llty, ll.Undefined)
else:
print(typ)
assert False
Expand Down

0 comments on commit c21387d

Please sign in to comment.