Skip to content

Commit

Permalink
Implement sending RPCs.
Browse files Browse the repository at this point in the history
  • Loading branch information
whitequark committed Aug 8, 2015
1 parent 22457bc commit b26af5d
Show file tree
Hide file tree
Showing 11 changed files with 432 additions and 130 deletions.
4 changes: 2 additions & 2 deletions artiq/compiler/builtins.py
Expand Up @@ -163,7 +163,7 @@ def is_bool(typ):

def is_int(typ, width=None):
    """Return whether ``typ`` is the builtin ``int`` type.

    :param typ: the type to test.
    :param width: if not ``None``, additionally require the ``width`` type
        parameter to match it (e.g. ``types.TValue(32)``).
    """
    if width is not None:
        # Type parameters are passed to ``types.is_mono`` as keyword arguments.
        return types.is_mono(typ, "int", width=width)
    else:
        return types.is_mono(typ, "int")

Expand All @@ -184,7 +184,7 @@ def is_numeric(typ):

def is_list(typ, elt=None):
    """Return whether ``typ`` is the builtin ``list`` type.

    :param typ: the type to test.
    :param elt: if not ``None``, additionally require the ``elt`` (element)
        type parameter to match it.
    """
    if elt is not None:
        # Type parameters are passed to ``types.is_mono`` as keyword arguments.
        return types.is_mono(typ, "list", elt=elt)
    else:
        return types.is_mono(typ, "list")

Expand Down
125 changes: 114 additions & 11 deletions artiq/compiler/embedding.py
Expand Up @@ -5,10 +5,13 @@
annotated as ``@kernel`` when they are referenced.
"""

import inspect, os
import os, re, linecache, inspect
from collections import OrderedDict

from pythonparser import ast, source, diagnostic, parse_buffer

from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer


class ASTSynthesizer:
Expand Down Expand Up @@ -45,6 +48,9 @@ def quote(self, value):
typ = builtins.TFloat()
return asttyped.NumT(n=value, ctx=None, type=typ,
loc=self._add(repr(value)))
elif isinstance(value, str):
return asttyped.StrT(s=value, ctx=None, type=builtins.TStr(),
loc=self._add(repr(value)))
elif isinstance(value, list):
begin_loc = self._add("[")
elts = []
Expand Down Expand Up @@ -123,7 +129,7 @@ def visit_Name(self, node):
if inspect.isfunction(value):
# It's a function. We need to translate the function and insert
# a reference to it.
function_name = self.quote_function(value)
function_name = self.quote_function(value, node.loc)
return asttyped.NameT(id=function_name, ctx=None,
type=self.globals[function_name],
loc=node.loc)
Expand Down Expand Up @@ -154,7 +160,19 @@ def __init__(self, engine=None):

self.functions = {}

self.next_rpc = 0
self.rpc_map = {}
self.inverse_rpc_map = {}

def _map(self, obj):
    """Return the RPC service number for ``obj``, allocating one on first use.

    Service numbers start at 1 and grow by one for every object not seen
    before. Objects are identified by ``id()``. ``rpc_map`` keeps a
    reference to every mapped object, so an id cannot be reused by another
    object while the mapping exists.

    :param obj: the host object to be invoked remotely.
    :return: the integer service number.
    """
    key = id(obj)
    try:
        return self.inverse_rpc_map[key]
    except KeyError:
        pass

    self.next_rpc += 1
    service = self.next_rpc
    self.rpc_map[service] = obj
    self.inverse_rpc_map[key] = service
    return service

def _iterate(self):
inferencer = Inferencer(engine=self.engine)
Expand Down Expand Up @@ -213,17 +231,102 @@ def _quote_embedded_function(self, function):
quote_function=self._quote_function)
return asttyped_rewriter.visit(function_node)

def _quote_function(self, function):
def _function_def_note(self, function):
    """Build a ``note`` diagnostic pointing at the definition of ``function``.

    The note is a zero-width range at the ``def`` (or ``lambda``) keyword of
    the function's definition line, read through :mod:`linecache`.

    :param function: a Python function object (must have ``__code__``).
    :return: a ``diagnostic.Diagnostic`` of level ``"note"``.
    """
    code = function.__code__
    filename = code.co_filename
    name = code.co_name

    # For decorated functions co_firstlineno is the line of the first
    # decorator, not the ``def`` line. Skip single-line decorators so the
    # note points at the definition itself.
    line = code.co_firstlineno
    source_line = linecache.getline(filename, line)
    while source_line.lstrip().startswith("@"):
        next_line = linecache.getline(filename, line + 1)
        if not next_line:
            break
        line += 1
        source_line = next_line

    # The source may be unavailable (e.g. code created via exec), or the
    # keyword may not be on this line (multi-line decorators). Fall back to
    # column 0 rather than crashing while reporting another diagnostic.
    # \b keeps "def" inside identifiers such as "undefined" from matching.
    match = re.search(r"\b(?:def|lambda)\b", source_line)
    column = match.start(0) if match else 0

    source_buffer = source.Buffer(source_line, filename, line)
    loc = source.Range(source_buffer, column, column)
    return diagnostic.Diagnostic("note",
        "definition of function '{function}'",
        {"function": name},
        loc)

def _type_of_param(self, function, loc, param):
    """Choose a type for parameter ``param`` of RPC function ``function``.

    Without a default value the parameter gets a fresh type variable, and
    the rest of the program decides its type. With a default value, the
    value is quoted into a typed AST and run through ``Inferencer`` and
    ``IntMonomorphizer``; the resulting type is returned. Any diagnostic
    raised there is forwarded to ``self.engine`` with two extra notes: one
    at ``loc`` explaining the inference attempt, and one at the function's
    definition.

    :param function: the host function the parameter belongs to.
    :param loc: source location used for the explanatory note.
    :param param: an ``inspect.Parameter`` of ``function``.
    :return: the inferred type, or a ``types.TVar()``.
    """
    if param.default is inspect.Parameter.empty:
        # Let the rest of the program decide.
        return types.TVar()

    # Infer the type from the default value. The value may have no
    # well-defined APython type; in that case inference reports it, and the
    # notes below explain why it was attempted.
    synthesizer = ASTSynthesizer()
    default_node = synthesizer.quote(param.default)
    synthesizer.finalize()

    def forward_diagnostic(diag):
        diag.notes.append(diagnostic.Diagnostic("note",
            "expanded from here while trying to infer a type for an"
            " unannotated optional argument '{param_name}' from its default value",
            {"param_name": param.name},
            loc))
        diag.notes.append(self._function_def_note(function))
        self.engine.process(diag)

    proxy_engine = diagnostic.Engine()
    proxy_engine.process = forward_diagnostic

    Inferencer(engine=proxy_engine).visit(default_node)
    IntMonomorphizer(engine=proxy_engine).visit(default_node)
    return default_node.type

def _quote_rpc_function(self, function, loc):
    """Register ``function`` as an RPC and return the name of its global.

    A storage-less global named ``__rpc_<service>__`` is added to
    ``self.globals`` with a ``types.TRPCFunction`` type. That type tells the
    compiler to perform a remote call instead of a regular call. The name is
    also recorded in ``self.functions``.

    Only positional parameters are considered. ``*args``, keyword-only
    parameters and ``**kwargs`` are ignored. This is always safe, because
    the function can still be called without passing anything into them,
    though it is sometimes constraining. Accepting POSITIONAL_ONLY is fine
    because the compiler desugars keyword arguments into positional ones.

    :param function: the host function.
    :param loc: source location used for diagnostics about parameter types.
    :return: the name of the global representing the RPC.
    """
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY,
                        inspect.Parameter.POSITIONAL_OR_KEYWORD)

    arg_types = OrderedDict()
    optarg_types = OrderedDict()
    for param in inspect.signature(function).parameters.values():
        if param.kind not in positional_kinds:
            continue
        required = param.default is inspect.Parameter.empty
        destination = arg_types if required else optarg_types
        destination[param.name] = self._type_of_param(function, loc, param)

    # The return type is fixed for now.
    rpc_type = types.TRPCFunction(arg_types, optarg_types,
                                  builtins.TInt(types.TValue(32)),
                                  service=self._map(function))

    rpc_name = "__rpc_{}__".format(rpc_type.service)
    self.globals[rpc_name] = rpc_type
    self.functions[function] = rpc_name
    return rpc_name

def _quote_function(self, function, loc):
    """Return the global name through which kernel code refers to ``function``.

    - A name already recorded in ``self.functions`` is returned unchanged.
    - A function carrying the ``artiq_embedded`` attribute is translated
      into a typed AST. The AST is inserted into ``self.typedtree`` and
      inference is restarted.
    - Any other function becomes an RPC, via ``_quote_rpc_function``.

    :param function: the Python function being referenced.
    :param loc: source location of the reference, used for diagnostics.
    :return: the name of the global bound to ``function``.
    """
    if function in self.functions:
        return self.functions[function]

    if hasattr(function, "artiq_embedded"):
        # Insert the typed AST for the new function and restart inference.
        # It doesn't really matter where we insert as long as it is before
        # the final call.
        function_node = self._quote_embedded_function(function)
        self.typedtree.insert(0, function_node)
        self.inference_finished = False
        return function_node.name
    else:
        # Insert a storage-less global whose type instructs the compiler
        # to perform an RPC instead of a regular call.
        return self._quote_rpc_function(function, loc)

def stitch_call(self, function, args, kwargs):
function_node = self._quote_embedded_function(function)
Expand Down
19 changes: 11 additions & 8 deletions artiq/compiler/module.py
Expand Up @@ -41,14 +41,16 @@ def from_filename(cls, filename, engine=None):

class Module:
def __init__(self, src):
int_monomorphizer = transforms.IntMonomorphizer(engine=src.engine)
inferencer = transforms.Inferencer(engine=src.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=src.engine)
escape_validator = validators.EscapeValidator(engine=src.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=src.engine,
self.engine = src.engine

int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=self.engine)
artiq_ir_generator = transforms.ARTIQIRGenerator(engine=self.engine,
module_name=src.name)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=src.engine)
local_access_validator = validators.LocalAccessValidator(engine=src.engine)
dead_code_eliminator = transforms.DeadCodeEliminator(engine=self.engine)
local_access_validator = validators.LocalAccessValidator(engine=self.engine)

self.name = src.name
self.globals = src.globals
Expand All @@ -62,7 +64,8 @@ def __init__(self, src):

def build_llvm_ir(self, target):
    """Compile the module to LLVM IR for the specified target.

    The module's diagnostic engine is passed to the LLVM IR generator so
    that code generation can report errors, such as argument types that
    cannot be sent in a remote procedure call.

    :param target: the compilation target (provides the LLVM context).
    :return: the result of ``LLVMIRGenerator.process`` on ``self.artiq_ir``.
    """
    llvm_ir_generator = transforms.LLVMIRGenerator(engine=self.engine,
                                                   module_name=self.name,
                                                   target=target)
    return llvm_ir_generator.process(self.artiq_ir)

def entry_point(self):
Expand Down
98 changes: 89 additions & 9 deletions artiq/compiler/transforms/llvm_ir_generator.py
Expand Up @@ -3,12 +3,13 @@
into LLVM intermediate representation.
"""

from pythonparser import ast
from pythonparser import ast, diagnostic
from llvmlite_artiq import ir as ll
from .. import types, builtins, ir

class LLVMIRGenerator:
def __init__(self, module_name, target):
def __init__(self, engine, module_name, target):
self.engine = engine
self.target = target
self.llcontext = target.llcontext
self.llmodule = ll.Module(context=self.llcontext, name=module_name)
Expand All @@ -21,6 +22,11 @@ def llty_of_type(self, typ, bare=False, for_return=False):
typ = typ.find()
if types.is_tuple(typ):
return ll.LiteralStructType([self.llty_of_type(eltty) for eltty in typ.elts])
elif types.is_rpc_function(typ):
if for_return:
return ll.VoidType()
else:
return ll.LiteralStructType([])
elif types.is_function(typ):
envarg = ll.IntType(8).as_pointer()
llty = ll.FunctionType(args=[envarg] +
Expand Down Expand Up @@ -89,10 +95,13 @@ def llconst_of_const(self, const):
return ll.Constant(llty, False)
elif isinstance(const.value, (int, float)):
return ll.Constant(llty, const.value)
elif isinstance(const.value, str):
assert "\0" not in const.value
elif isinstance(const.value, (str, bytes)):
if isinstance(const.value, str):
assert "\0" not in const.value
as_bytes = (const.value + "\0").encode("utf-8")
else:
as_bytes = const.value

as_bytes = (const.value + "\0").encode("utf-8")
if ir.is_exn_typeinfo(const.type):
# Exception typeinfo; should be merged with identical others
name = "__artiq_exn_" + const.value
Expand Down Expand Up @@ -144,6 +153,9 @@ def llbuiltin(self, name):
llty = ll.FunctionType(ll.VoidType(), [self.llty_of_type(builtins.TException())])
elif name == "__artiq_reraise":
llty = ll.FunctionType(ll.VoidType(), [])
elif name == "rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()],
var_arg=True)
else:
assert False

Expand Down Expand Up @@ -546,11 +558,79 @@ def process_Closure(self, insn):
name=insn.name)
return llvalue

# See session.c:send_rpc_value.
def _rpc_tag(self, typ, root_type, root_loc):
    """Encode ``typ`` as an RPC type tag (a byte string).

    Keep in sync with ``session.c:send_rpc_value``. Encoding:

    - tuple: ``b"t"``, one byte with the element count, then the element tags
    - None: ``b"n"``
    - bool: ``b"b"``
    - 32-bit int: ``b"i"``
    - 64-bit int: ``b"I"``
    - float: ``b"f"``
    - str: ``b"s"``
    - list: ``b"l"`` followed by the element tag
    - range: ``b"r"`` followed by the element tag
    - option: ``b"o"`` followed by the inner tag

    :param typ: the type to encode (may be nested inside ``root_type``).
    :param root_type: the full type of the RPC argument; used only in the
        diagnostic note.
    :param root_loc: location of the argument for diagnostics (``None`` for
        constants, which never fail to serialize).
    :return: the tag bytes, or ``b""`` if ``typ`` is unsupported. In that
        case an error diagnostic has been reported to ``self.engine``.
    """
    if types.is_tuple(typ):
        # The element count must fit in a single byte.
        assert len(typ.elts) < 256
        return b"t" + bytes([len(typ.elts)]) + \
            b"".join([self._rpc_tag(elt_type, root_type, root_loc)
                      for elt_type in typ.elts])
    elif builtins.is_none(typ):
        return b"n"
    elif builtins.is_bool(typ):
        return b"b"
    elif builtins.is_int(typ, types.TValue(32)):
        return b"i"
    elif builtins.is_int(typ, types.TValue(64)):
        return b"I"
    elif builtins.is_float(typ):
        return b"f"
    elif builtins.is_str(typ):
        return b"s"
    elif builtins.is_list(typ):
        return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
                                    root_type, root_loc)
    elif builtins.is_range(typ):
        return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
                                    root_type, root_loc)
    elif ir.is_option(typ):
        return b"o" + self._rpc_tag(typ.params["inner"],
                                    root_type, root_loc)
    else:
        printer = types.TypePrinter()
        note = diagnostic.Diagnostic("note",
            "value of type {type}",
            {"type": printer.name(root_type)},
            root_loc)
        # Attach the note so the user sees which argument the unsupported
        # (possibly nested) type came from.
        diag = diagnostic.Diagnostic("error",
            "type {type} is not supported in remote procedure calls",
            {"type": printer.name(typ)},
            root_loc, notes=[note])
        self.engine.process(diag)
        # The engine only raises if it treats errors as fatal. Return an
        # empty tag so that callers concatenating tags (``tag += ...``)
        # do not crash on None.
        return b""

def _build_rpc(self, service, args, return_type):
    """Lower a call to an RPC function into a call of the ``rpc`` builtin.

    :param service: the RPC service number (``TRPCFunction.service``).
    :param args: the IR values passed to the remote procedure.
    :param return_type: the RPC's return type. NOTE(review): currently
        unused; the tag always ends in ':' followed by a NUL byte, i.e. no
        return-type tag is emitted.
    :return: ``(llfun, llargs)``. ``llfun`` is the ``rpc`` builtin, and
        ``llargs`` holds the service number, the type-tag constant, and one
        stack slot per argument.
    """
    llservice = ll.Constant(ll.IntType(32), service)

    # Concatenate one encoded type per argument, then ":" and a NUL
    # terminator. bytes constants are emitted as-is by llconst_of_const
    # (no implicit NUL), so the terminator is added here.
    tag = b""
    for arg in args:
        if isinstance(arg, ir.Constant):
            # Constants don't have locations, but conveniently
            # they also never fail to serialize.
            tag += self._rpc_tag(arg.type, arg.type, None)
        else:
            # assumes non-constant IR values carry a ``loc`` — TODO confirm
            tag += self._rpc_tag(arg.type, arg.type, arg.loc)
    tag += b":\x00"
    lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))

    # Each argument is spilled to a stack slot and passed by pointer.
    # NOTE(review): these allocas are emitted at the current insertion
    # point, not in the entry block. Inside a loop, the stack grows on
    # every iteration until the function returns. Consider wrapping the
    # call in llvm.stacksave/llvm.stackrestore.
    llargs = []
    for arg in args:
        llarg = self.map(arg)
        llargslot = self.llbuilder.alloca(llarg.type)
        self.llbuilder.store(llarg, llargslot)
        llargs.append(llargslot)

    return self.llbuiltin("rpc"), [llservice, lltag] + llargs

def prepare_call(self, insn):
    """Return ``(llfun, llargs)`` for emitting the call performed by ``insn``.

    If the target has an RPC function type, the call is lowered through
    ``_build_rpc``: the callee is the ``rpc`` builtin, and the arguments
    carry the service number, a type tag, and argument slots. Otherwise the
    target is a closure. Its environment (field 0) is passed as the first
    argument to its function pointer (field 1), followed by the mapped
    call arguments.
    """
    if types.is_rpc_function(insn.target_function().type):
        return self._build_rpc(insn.target_function().type.service,
                               insn.arguments(),
                               insn.target_function().type.ret)
    else:
        llclosure = self.map(insn.target_function())
        llargs = [self.map(arg) for arg in insn.arguments()]
        llenv = self.llbuilder.extract_value(llclosure, 0)
        llfun = self.llbuilder.extract_value(llclosure, 1)
        return llfun, [llenv] + llargs

def process_Call(self, insn):
llfun, llargs = self.prepare_call(insn)
Expand Down

0 comments on commit b26af5d

Please sign in to comment.