Skip to content

Commit

Permalink
Add basic support for embedded functions with new compiler.
Browse files Browse the repository at this point in the history
  • Loading branch information
whitequark committed Aug 7, 2015
1 parent b6e2613 commit 353f454
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 345 deletions.
1 change: 1 addition & 0 deletions artiq/compiler/__init__.py
@@ -1 +1,2 @@
from .module import Module, Source
from .embedding import Stitcher
2 changes: 2 additions & 0 deletions artiq/compiler/asttyped.py
Expand Up @@ -93,3 +93,5 @@ class YieldFromT(ast.YieldFrom, commontyped):
# Novel typed nodes
class CoerceT(ast.expr, commontyped):
    """
    Typed expression node whose only traversed child field is ``value``.

    NOTE(review): judging by the name this presumably represents a coercion
    of ``value``; confirm against its users. An ``other_value`` attribute is
    deliberately kept out of ``_fields``, presumably so generic AST
    traversal does not visit it.
    """
    _fields = ("value",)  # other_value deliberately not in _fields
class QuoteT(ast.expr, commontyped):
    """
    Typed expression node with a single child field, ``value``.

    NOTE(review): judging by the name this presumably embeds a host-side
    value directly in the typed tree; confirm against its users.
    """
    _fields = ("value",)
157 changes: 157 additions & 0 deletions artiq/compiler/embedding.py
@@ -0,0 +1,157 @@
"""
The :class:`Stitcher` class allows to transparently combine compiled
Python code and Python code executed on the host system: it resolves
the references to the host objects and translates the functions
annotated as ``@kernel`` when they are referenced.
"""

import inspect
import textwrap

from pythonparser import ast, source, diagnostic, parse_buffer

from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer


class ASTSynthesizer:
    """
    Build typed AST fragments together with matching synthetic source text.

    Every fragment's text is appended to :attr:`source`, and every node
    receives a :class:`source.Range` into :attr:`source_buffer`, so that
    diagnostics pointing into synthesized code have real text to display.
    Call :meth:`finalize` once all fragments have been produced.
    """

    def __init__(self):
        self.source = ""
        # The buffer is created before any text exists; finalize() stores the
        # accumulated text in it. Ranges returned by _add() point into it.
        self.source_buffer = source.Buffer(self.source, "<synthesized>")

    def finalize(self):
        """Store the accumulated text in the source buffer and return the buffer."""
        self.source_buffer.source = self.source
        return self.source_buffer

    def _add(self, fragment):
        """Append ``fragment`` to the synthesized source and return its range."""
        range_from = len(self.source)
        self.source += fragment
        range_to = len(self.source)
        return source.Range(self.source_buffer, range_from, range_to)

    def quote(self, value):
        """
        Construct a typed AST fragment equal to ``value``.

        Supported values are ``None``, ``True``, ``False``, ``int``, ``float``
        and (possibly nested) lists of these.

        :raises TypeError: if ``value`` (or a list element) has an
            unsupported type.
        """
        # Use identity checks: ``value in (None, True, False)`` compares with
        # ``==``, so 0/1 and 0.0/1.0 would be mistaken for False/True.
        if value is None:
            return asttyped.NameConstantT(value=value, type=builtins.TNone(),
                                          loc=self._add(repr(value)))
        elif value is True or value is False:
            return asttyped.NameConstantT(value=value, type=builtins.TBool(),
                                          loc=self._add(repr(value)))
        elif isinstance(value, int):
            return asttyped.NumT(n=value, ctx=None, type=builtins.TInt(),
                                 loc=self._add(repr(value)))
        elif isinstance(value, float):
            return asttyped.NumT(n=value, ctx=None, type=builtins.TFloat(),
                                 loc=self._add(repr(value)))
        elif isinstance(value, list):
            begin_loc = self._add("[")
            elts = []
            for index, elt in enumerate(value):
                elts.append(self.quote(elt))
                # Separators go between elements only, never after the last.
                if index < len(value) - 1:
                    self._add(", ")
            end_loc = self._add("]")
            return asttyped.ListT(elts=elts, ctx=None, type=types.TVar(),
                                  begin_loc=begin_loc, end_loc=end_loc,
                                  loc=begin_loc.join(end_loc))
        else:
            raise TypeError("cannot quote value {!r} of type {}"
                            .format(value, type(value).__name__))

    def call(self, function_node, args, kwargs):
        """
        Construct a typed AST fragment calling a function.

        The synthesized text has the form ``name(arg, ..., kw=value, ...)``.
        Every argument value is converted with :meth:`quote`.

        :param function_node: typed AST node of the callee; its ``name`` and
            ``signature_type`` are used for the callee reference.
        :param args: sequence of positional argument values.
        :param kwargs: mapping of keyword argument names to values.
        :return: an ``asttyped.CallT`` node with a fresh type variable.
        """
        arg_nodes = []
        kwarg_nodes = []
        kwarg_locs = []

        name_loc = self._add(function_node.name)
        begin_loc = self._add("(")
        for index, arg in enumerate(args):
            arg_nodes.append(self.quote(arg))
            if index < len(args) - 1:
                self._add(", ")
        # Test whether the containers are non-empty. ``any(...)`` would test
        # the truthiness of the argument values (e.g. args=[0] is "empty").
        if args and kwargs:
            self._add(", ")
        for index, kw in enumerate(kwargs):
            arg_loc = self._add(kw)
            equals_loc = self._add("=")
            kwarg_locs.append((arg_loc, equals_loc))
            kwarg_nodes.append(self.quote(kwargs[kw]))
            if index < len(kwargs) - 1:
                self._add(", ")
        end_loc = self._add(")")

        return asttyped.CallT(
            func=asttyped.NameT(id=function_node.name, ctx=None,
                                type=function_node.signature_type,
                                loc=name_loc),
            args=arg_nodes,
            keywords=[ast.keyword(arg=kw, value=value,
                                  arg_loc=arg_loc, equals_loc=equals_loc,
                                  loc=arg_loc.join(value.loc))
                      for kw, value, (arg_loc, equals_loc)
                      in zip(kwargs, kwarg_nodes, kwarg_locs)],
            starargs=None, kwargs=None,
            type=types.TVar(),
            begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
            loc=name_loc.join(end_loc))

class StitchingASTTypedRewriter(ASTTypedRewriter):
    """
    :class:`ASTTypedRewriter` subclass used by :class:`Stitcher`.

    It currently adds no behavior of its own.
    NOTE(review): presumably a placeholder for stitching-specific rewriting
    (such as resolving references to host objects); confirm.
    """

class Stitcher:
    """
    Parse an embedded host function, append a synthesized call to it, and
    run type inference over the resulting typed tree.

    After :meth:`stitch_call`, :attr:`typedtree` holds the typed module.
    """

    def __init__(self, engine=None):
        """
        :param engine: diagnostic engine to report errors to. If ``None``, a
            new engine that treats all errors as fatal is created.
        """
        if engine is None:
            self.engine = diagnostic.Engine(all_errors_are_fatal=True)
        else:
            self.engine = engine

        self.asttyped_rewriter = StitchingASTTypedRewriter(
            engine=self.engine, globals=prelude.globals())
        self.inferencer = Inferencer(engine=self.engine)

        self.name = "stitched"
        self.typedtree = None
        # NOTE(review): this snapshots the rewriter's ``globals`` attribute at
        # construction time, before any tree has been rewritten. The
        # rewriter's __init__ sets it to None. If the rewriter only fills it
        # in while visiting a module, this attribute stays None; confirm.
        self.globals = self.asttyped_rewriter.globals

        # Nothing in this class populates rpc_map yet; it is returned to
        # callers as-is.
        self.rpc_map = {}

    def _iterate(self):
        # Iterate inference to fixed point.
        # NOTE(review): nothing visible here resets ``inference_finished`` to
        # False, so the loop currently runs the inferencer exactly once.
        self.inference_finished = False
        while not self.inference_finished:
            self.inference_finished = True
            self.inferencer.visit(self.typedtree)

    def _parse_embedded_function(self, function):
        """
        Parse and type-annotate the source of an embedded function.

        :param function: an object with an ``artiq_embedded`` attribute whose
            ``function`` attribute is the host Python function to compile.
        :return: ``(typedtree, function_node)``, where ``function_node`` is
            the first statement of the parsed module.
        :raises ValueError: if ``function`` has no ``artiq_embedded`` attribute.
        """
        if not hasattr(function, "artiq_embedded"):
            raise ValueError("{} is not an embedded function".format(repr(function)))

        # Extract function source.
        embedded_function = function.artiq_embedded.function
        source_code = inspect.getsource(embedded_function)
        filename = embedded_function.__code__.co_filename
        first_line = embedded_function.__code__.co_firstlineno

        # inspect.getsource() keeps the original indentation, which is
        # non-zero for methods and nested functions and would be rejected by
        # a module-level parse. Dedenting keeps line numbers intact; columns
        # in diagnostics become relative to the dedented text.
        source_code = textwrap.dedent(source_code)

        # Parse.
        source_buffer = source.Buffer(source_code, filename, first_line)
        parsetree, comments = parse_buffer(source_buffer, engine=self.engine)

        # Rewrite into typed form.
        typedtree = self.asttyped_rewriter.visit(parsetree)

        return typedtree, typedtree.body[0]

    def stitch_call(self, function, args, kwargs):
        """
        Prepare :attr:`typedtree` for compiling ``function(*args, **kwargs)``.

        The embedded function is parsed, and a call to it with the quoted
        argument values is appended to the module body. Type inference then
        runs over the tree.
        """
        self.typedtree, function_node = self._parse_embedded_function(function)

        # We synthesize fake source code for the initial call so that
        # diagnostics would have something meaningful to display to the user.
        synthesizer = ASTSynthesizer()
        call_node = synthesizer.call(function_node, args, kwargs)
        synthesizer.finalize()
        self.typedtree.body.append(call_node)

        self._iterate()
7 changes: 6 additions & 1 deletion artiq/compiler/transforms/asttyped_rewriter.py
Expand Up @@ -190,10 +190,15 @@ def __init__(self, engine, globals):
self.globals = None
self.env_stack = [globals]

def _find_name(self, name, loc):
def _try_find_name(self, name):
for typing_env in reversed(self.env_stack):
if name in typing_env:
return typing_env[name]

def _find_name(self, name, loc):
typ = self._try_find_name(name)
if typ is not None:
return typ
diag = diagnostic.Diagnostic("fatal",
"name '{name}' is not bound to anything", {"name":name}, loc)
self.engine.process(diag)
Expand Down
25 changes: 8 additions & 17 deletions artiq/coredevice/comm_generic.py
Expand Up @@ -3,9 +3,7 @@
from enum import Enum
from fractions import Fraction

from artiq.coredevice import runtime_exceptions
from artiq.language import core as core_language
from artiq.coredevice.rpc_wrapper import RPCWrapper


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -198,35 +196,28 @@ def _receive_rpc_values(self):
else:
r.append(self._receive_rpc_value(type_tag))

# NOTE(review): this hunk is a diff rendering with the +/- markers stripped.
# The first `def` line is the pre-commit signature and the second the
# post-commit one; likewise for the two `rpc_wrapper.run_rpc` calls below.
def _serve_rpc(self, rpc_wrapper, rpc_map, user_exception_map):
def _serve_rpc(self, rpc_map):
rpc_num = struct.unpack(">l", self.read(4))[0]
args = self._receive_rpc_values()
logger.debug("rpc service: %d %r", rpc_num, args)
eid, r = rpc_wrapper.run_rpc(
user_exception_map, rpc_map[rpc_num], args)
# NOTE(review): in the post-commit version `rpc_wrapper` is no longer a
# parameter and is not assigned in this function. Unless it is a
# module-level name, this call raises NameError on the first RPC
# request; confirm.
eid, r = rpc_wrapper.run_rpc(rpc_map[rpc_num], args)
self._write_header(9+2*4, _H2DMsgType.RPC_REPLY)
self.write(struct.pack(">ll", eid, r))
logger.debug("rpc service: %d %r == %r (eid %d)", rpc_num, args,
r, eid)

def _serve_exception(self, rpc_wrapper, user_exception_map):
def _serve_exception(self):
eid, p0, p1, p2 = struct.unpack(">lqqq", self.read(4+3*8))
rpc_wrapper.filter_rpc_exception(eid)
if eid < core_language.first_user_eid:
exception = runtime_exceptions.exception_map[eid]
raise exception(self.core, p0, p1, p2)
else:
exception = user_exception_map[eid]
raise exception

def serve(self, rpc_map, user_exception_map):
rpc_wrapper = RPCWrapper()
raise exception(self.core, p0, p1, p2)

def serve(self, rpc_map):
while True:
_, ty = self._read_header()
if ty == _D2HMsgType.RPC_REQUEST:
self._serve_rpc(rpc_wrapper, rpc_map, user_exception_map)
self._serve_rpc(rpc_map)
elif ty == _D2HMsgType.KERNEL_EXCEPTION:
self._serve_exception(rpc_wrapper, user_exception_map)
self._serve_exception()
elif ty == _D2HMsgType.KERNEL_FINISHED:
return
else:
Expand Down
125 changes: 33 additions & 92 deletions artiq/coredevice/core.py
@@ -1,50 +1,20 @@
import os
import os, sys, tempfile

from pythonparser import diagnostic

from artiq.language.core import *
from artiq.language.units import ns

from artiq.transforms.inline import inline
from artiq.transforms.quantize_time import quantize_time
from artiq.transforms.remove_inter_assigns import remove_inter_assigns
from artiq.transforms.fold_constants import fold_constants
from artiq.transforms.remove_dead_code import remove_dead_code
from artiq.transforms.unroll_loops import unroll_loops
from artiq.transforms.interleave import interleave
from artiq.transforms.lower_time import lower_time
from artiq.transforms.unparse import unparse

from artiq.coredevice.runtime import Runtime

from artiq.py2llvm import get_runtime_binary


def _announce_unparse(label, node):
print("*** Unparsing: "+label)
print(unparse(node))

from artiq.compiler import Stitcher, Module
from artiq.compiler.targets import OR1KTarget

def _make_debug_unparse(final):
try:
env = os.environ["ARTIQ_UNPARSE"]
except KeyError:
env = ""
selected_labels = set(env.split())
if "all" in selected_labels:
return _announce_unparse
else:
if "final" in selected_labels:
selected_labels.add(final)
# Import for side effects (creating the exception classes).
from artiq.coredevice import exceptions

def _filtered_unparse(label, node):
if label in selected_labels:
_announce_unparse(label, node)
return _filtered_unparse


def _no_debug_unparse(label, node):
class CompileError(Exception):
pass


class Core:
def __init__(self, dmgr, ref_period=8*ns, external_clock=False):
self.comm = dmgr.get("comm")
Expand All @@ -54,69 +24,40 @@ def __init__(self, dmgr, ref_period=8*ns, external_clock=False):
self.first_run = True
self.core = self
self.comm.core = self
self.runtime = Runtime()

def transform_stack(self, func_def, rpc_map, exception_map,
debug_unparse=_no_debug_unparse):
remove_inter_assigns(func_def)
debug_unparse("remove_inter_assigns_1", func_def)

quantize_time(func_def, self.ref_period)
debug_unparse("quantize_time", func_def)

fold_constants(func_def)
debug_unparse("fold_constants_1", func_def)

unroll_loops(func_def, 500)
debug_unparse("unroll_loops", func_def)
def compile(self, function, args, kwargs, with_attr_writeback=True):
try:
engine = diagnostic.Engine(all_errors_are_fatal=True)

interleave(func_def)
debug_unparse("interleave", func_def)
stitcher = Stitcher(engine=engine)
stitcher.stitch_call(function, args, kwargs)

lower_time(func_def)
debug_unparse("lower_time", func_def)
module = Module(stitcher)
library = OR1KTarget().compile_and_link([module])

remove_inter_assigns(func_def)
debug_unparse("remove_inter_assigns_2", func_def)
return library, stitcher.rpc_map
except diagnostic.Error as error:
print("\n".join(error.diagnostic.render(colored=True)), file=sys.stderr)
raise CompileError() from error

fold_constants(func_def)
debug_unparse("fold_constants_2", func_def)

remove_dead_code(func_def)
debug_unparse("remove_dead_code_1", func_def)

remove_inter_assigns(func_def)
debug_unparse("remove_inter_assigns_3", func_def)

fold_constants(func_def)
debug_unparse("fold_constants_3", func_def)

remove_dead_code(func_def)
debug_unparse("remove_dead_code_2", func_def)

def compile(self, k_function, k_args, k_kwargs, with_attr_writeback=True):
debug_unparse = _make_debug_unparse("remove_dead_code_2")

func_def, rpc_map, exception_map = inline(
self, k_function, k_args, k_kwargs, with_attr_writeback)
debug_unparse("inline", func_def)
self.transform_stack(func_def, rpc_map, exception_map, debug_unparse)

binary = get_runtime_binary(self.runtime, func_def)

return binary, rpc_map, exception_map

def run(self, k_function, k_args, k_kwargs):
def run(self, function, args, kwargs):
if self.first_run:
self.comm.check_ident()
self.comm.switch_clock(self.external_clock)
self.first_run = False

kernel_library, rpc_map = self.compile(function, args, kwargs)

try:
self.comm.load(kernel_library)
except Exception as error:
shlib_temp = tempfile.NamedTemporaryFile(suffix=".so", delete=False)
shlib_temp.write(kernel_library)
shlib_temp.close()
raise RuntimeError("shared library dumped to {}".format(shlib_temp.name)) from error

binary, rpc_map, exception_map = self.compile(
k_function, k_args, k_kwargs)
self.comm.load(binary)
self.comm.run(k_function.__name__)
self.comm.serve(rpc_map, exception_map)
self.first_run = False
self.comm.run()
self.comm.serve(rpc_map)

@kernel
def get_rtio_counter_mu(self):
Expand Down

0 comments on commit 353f454

Please sign in to comment.