Skip to content

Commit

Permalink
Add support for referring to host values in embedded functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
whitequark committed Aug 7, 2015
1 parent 353f454 commit 50448ef
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 27 deletions.
126 changes: 103 additions & 23 deletions artiq/compiler/embedding.py
Expand Up @@ -5,7 +5,7 @@
annotated as ``@kernel`` when they are referenced.
"""

import inspect
import inspect, os
from pythonparser import ast, source, diagnostic, parse_buffer
from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer
Expand All @@ -28,11 +28,12 @@ def _add(self, fragment):

def quote(self, value):
"""Construct an AST fragment equal to `value`."""
if value in (None, True, False):
if node.value is True or node.value is False:
typ = builtins.TBool()
elif node.value is None:
typ = builtins.TNone()
if value is None:
typ = builtins.TNone()
return asttyped.NameConstantT(value=value, type=typ,
loc=self._add(repr(value)))
elif value is True or value is False:
typ = builtins.TBool()
return asttyped.NameConstantT(value=value, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, (int, float)):
Expand All @@ -45,12 +46,12 @@ def quote(self, value):
elif isinstance(value, list):
begin_loc = self._add("[")
elts = []
for index, elt in value:
for index, elt in enumerate(value):
elts.append(self.quote(elt))
if index < len(value) - 1:
self._add(", ")
end_loc = self._add("]")
return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(),
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
else:
Expand Down Expand Up @@ -99,7 +100,43 @@ def call(self, function_node, args, kwargs):
loc=name_loc.join(end_loc))

class StitchingASTTypedRewriter(ASTTypedRewriter):
    """
    An :class:`ASTTypedRewriter` that resolves otherwise unknown names
    against a host environment.

    A name found in one of the typing environments on ``env_stack`` is
    rewritten into a ``NameT`` node carrying the type found there. A name
    found only in ``host_environment`` is either registered as an embedded
    function (through ``quote_function``) or turned into a synthesized AST
    fragment (through ``ASTSynthesizer.quote``).
    """

    def __init__(self, engine, prelude, globals, host_environment, quote_function):
        """
        :param engine: diagnostic engine that receives reported errors
        :param prelude: outermost typing environment, handed to the base class
        :param globals: typing environment shared by the stitched functions;
            pushed onto ``env_stack`` so that it is searched as well
        :param host_environment: mapping from names to host values
        :param quote_function: callable that takes a host function and
            returns the name under which its type is stored in ``globals``
        """
        super().__init__(engine, prelude)
        self.globals = globals
        self.env_stack.append(self.globals)

        self.host_environment = host_environment
        self.quote_function = quote_function

    def visit_Name(self, node):
        """Rewrite a name, falling back to the host environment if needed."""
        typ = super()._try_find_name(node.id)
        if typ is not None:
            # Value from device environment.
            return asttyped.NameT(type=typ, id=node.id, ctx=node.ctx,
                                  loc=node.loc)

        if node.id not in self.host_environment:
            # Neither the device nor the host knows this name.
            diag = diagnostic.Diagnostic("fatal",
                "name '{name}' is not bound to anything", {"name":node.id},
                node.loc)
            self.engine.process(diag)
            return None

        # The value lives on the host; quote it.
        value = self.host_environment[node.id]
        if inspect.isfunction(value):
            # It's a function. We need to translate the function and insert
            # a reference to it.
            function_name = self.quote_function(value)
            return asttyped.NameT(id=function_name, ctx=None,
                                  type=self.globals[function_name],
                                  loc=node.loc)

        # It's just a value. Quote it.
        synthesizer = ASTSynthesizer()
        quoted = synthesizer.quote(value)
        synthesizer.finalize()
        return quoted

class Stitcher:
def __init__(self, engine=None):
Expand All @@ -108,50 +145,93 @@ def __init__(self, engine=None):
else:
self.engine = engine

self.asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, globals=prelude.globals())
self.inferencer = Inferencer(engine=self.engine)
self.name = ""
self.typedtree = []
self.prelude = prelude.globals()
self.globals = {}

self.name = "stitched"
self.typedtree = None
self.globals = self.asttyped_rewriter.globals
self.functions = {}

self.rpc_map = {}

def _iterate(self):
    """
    Run type inference over the stitched tree until it settles, then wrap
    the resulting list of nodes into a single ``ModuleT``.

    ``self.inference_finished`` is set to ``True`` before every pass; the
    loop repeats whenever something clears the flag during the pass (in this
    file, ``_quote_function`` does so after inserting a new function).
    """
    inferencer = Inferencer(engine=self.engine)

    # Iterate inference to fixed point.
    self.inference_finished = False
    while not self.inference_finished:
        self.inference_finished = True
        inferencer.visit(self.typedtree)

    # After we have found all functions, synthesize a module to hold them.
    self.typedtree = asttyped.ModuleT(
        typing_env=self.globals, globals_in_scope=set(),
        body=self.typedtree, loc=None)

def _quote_embedded_function(self, function):
    """
    Parse an embedded function and rewrite it into typed form.

    The function is registered in ``self.globals`` (with a fresh type
    variable) and in ``self.functions`` under a module-qualified name before
    it is rewritten, so that recursive references resolve to the same entry.

    :param function: host callable carrying an ``artiq_embedded`` attribute;
        the underlying Python function is read from
        ``function.artiq_embedded.function``
    :return: the node produced by :class:`StitchingASTTypedRewriter` for the
        function definition
    :raise ValueError: if ``function`` has no ``artiq_embedded`` attribute
    """
    if not hasattr(function, "artiq_embedded"):
        raise ValueError("{} is not an embedded function".format(repr(function)))

    # Extract function source.
    embedded_function = function.artiq_embedded.function
    source_code = inspect.getsource(embedded_function)
    filename = embedded_function.__code__.co_filename
    module_name, _ = os.path.splitext(os.path.basename(filename))
    first_line = embedded_function.__code__.co_firstlineno

    # Extract function environment: module globals first, then closure
    # variables, which take precedence over globals of the same name.
    host_environment = dict(embedded_function.__globals__)
    cell_names = embedded_function.__code__.co_freevars
    # __closure__ is None for functions without free variables.
    cells = embedded_function.__closure__ or ()
    for var, cell in zip(cell_names, cells):
        try:
            # Store the value held by the cell, not the cell object itself;
            # otherwise closed-over functions and constants are never
            # recognized as such.
            host_environment[var] = cell.cell_contents
        except ValueError:
            # The cell is still empty. The name is unbound in the function,
            # so it must not fall through to a global of the same name.
            host_environment.pop(var, None)

    # Parse.
    source_buffer = source.Buffer(source_code, filename, first_line)
    parsetree, _comments = parse_buffer(source_buffer, engine=self.engine)
    function_node = parsetree.body[0]

    # Mangle the name, since we put everything into a single module.
    function_node.name = "{}.{}".format(module_name, function_node.name)

    # Normally, LocalExtractor would populate the typing environment
    # of the module with the function name. However, since we run
    # ASTTypedRewriter on the function node directly, we need to do it
    # explicitly.
    self.globals[function_node.name] = types.TVar()

    # Memoize the function before typing it to handle recursive
    # invocations.
    self.functions[function] = function_node.name

    # Rewrite into typed form.
    asttyped_rewriter = StitchingASTTypedRewriter(
        engine=self.engine, prelude=self.prelude,
        globals=self.globals, host_environment=host_environment,
        quote_function=self._quote_function)
    return asttyped_rewriter.visit(function_node)

def _quote_function(self, function):
    """
    Return the name under which ``function`` is known in the stitched
    module, translating the function first if it has not been seen yet.

    :param function: host function referenced from an embedded function
    :return: the name recorded for ``function``
    """
    if function in self.functions:
        # Already translated (or currently being translated).
        return self.functions[function]

    # Insert the typed AST for the new function and restart inference.
    # It doesn't really matter where we insert as long as it is before
    # the final call.
    function_node = self._quote_embedded_function(function)
    self.typedtree.insert(0, function_node)
    self.inference_finished = False
    return function_node.name

def stitch_call(self, function, args, kwargs):
    """
    Stitch an embedded function together with a synthesized call to it,
    then run inference over the result.

    :param function: embedded host function serving as the entry point
    :param args: positional arguments of the initial call
    :param kwargs: keyword arguments of the initial call
    """
    function_node = self._quote_embedded_function(function)
    self.typedtree.append(function_node)

    # We synthesize source code for the initial call so that
    # diagnostics would have something meaningful to display to the user.
    synthesizer = ASTSynthesizer()
    call_node = synthesizer.call(function_node, args, kwargs)
    synthesizer.finalize()
    self.typedtree.append(call_node)

    self._iterate()
5 changes: 4 additions & 1 deletion artiq/compiler/module.py
Expand Up @@ -67,7 +67,10 @@ def build_llvm_ir(self, target):

def entry_point(self):
    """Return the name of the function that is the entry point of this module."""
    # A module with an empty name has no prefix to qualify the entry
    # point with, so the bare name is used.
    if self.name != "":
        return self.name + ".__modinit__"
    else:
        return "__modinit__"

def __repr__(self):
printer = types.TypePrinter()
Expand Down
2 changes: 1 addition & 1 deletion artiq/compiler/transforms/artiq_ir_generator.py
Expand Up @@ -70,7 +70,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def __init__(self, module_name, engine):
self.engine = engine
self.functions = []
self.name = [module_name]
self.name = [module_name] if module_name != "" else []
self.current_loc = None
self.current_function = None
self.current_globals = set()
Expand Down
4 changes: 2 additions & 2 deletions artiq/compiler/transforms/asttyped_rewriter.py
Expand Up @@ -185,10 +185,10 @@ class ASTTypedRewriter(algorithm.Transformer):
via :class:`LocalExtractor`.
"""

def __init__(self, engine, prelude):
    """
    :param engine: diagnostic engine that receives reported errors
    :param prelude: outermost typing environment; becomes the bottom
        entry of ``env_stack``
    """
    self.engine = engine
    self.globals = None
    self.env_stack = [prelude]

def _try_find_name(self, name):
for typing_env in reversed(self.env_stack):
Expand Down

0 comments on commit 50448ef

Please sign in to comment.