Skip to content

Commit

Permalink
Initial invocation of a @kernel function can now return a value (fixes …) [commit title truncated by the page]
Browse files (browse the repository at this point in the history)
whitequark committed Dec 18, 2015
1 parent e9afe5a commit 4fb1de3
Showing 4 changed files with 40 additions and 15 deletions.
25 changes: 21 additions & 4 deletions artiq/compiler/embedding.py
Original file line number Diff line number Diff line change
@@ -139,11 +139,15 @@ def quote(self, value):
return asttyped.QuoteT(value=value, type=instance_type,
loc=loc)

def call(self, function_node, args, kwargs):
def call(self, function_node, args, kwargs, callback=None):
"""
Construct an AST fragment calling a function specified by
an AST node `function_node`, with given arguments.
"""
if callback is not None:
callback_node = self.quote(callback)
cb_begin_loc = self._add("(")

arg_nodes = []
kwarg_nodes = []
kwarg_locs = []
@@ -165,7 +169,10 @@ def call(self, function_node, args, kwargs):
self._add(", ")
end_loc = self._add(")")

return asttyped.CallT(
if callback is not None:
cb_end_loc = self._add(")")

node = asttyped.CallT(
func=asttyped.NameT(id=function_node.name, ctx=None,
type=function_node.signature_type,
loc=name_loc),
@@ -180,6 +187,16 @@ def call(self, function_node, args, kwargs):
begin_loc=begin_loc, end_loc=end_loc, star_loc=None, dstar_loc=None,
loc=name_loc.join(end_loc))

if callback is not None:
node = asttyped.CallT(
func=callback_node,
args=[node], keywords=[], starargs=None, kwargs=None,
type=builtins.TNone(), iodelay=None,
begin_loc=cb_begin_loc, end_loc=cb_end_loc, star_loc=None, dstar_loc=None,
loc=callback_node.loc.join(cb_end_loc))

return node

def assign_local(self, var_name, value):
name_loc = self._add(var_name)
_ = self._add(" ")
@@ -426,14 +443,14 @@ def __init__(self, engine=None):
self.type_map = {}
self.value_map = defaultdict(lambda: [])

def stitch_call(self, function, args, kwargs):
def stitch_call(self, function, args, kwargs, callback=None):
function_node = self._quote_embedded_function(function)
self.typedtree.append(function_node)

# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = self._synthesizer()
call_node = synthesizer.call(function_node, args, kwargs)
call_node = synthesizer.call(function_node, args, kwargs, callback)
synthesizer.finalize()
self.typedtree.append(call_node)

9 changes: 6 additions & 3 deletions artiq/compiler/transforms/llvm_ir_generator.py
Original file line number Diff line number Diff line change
@@ -907,9 +907,12 @@ def ret_error_handler(typ):

llargs = []
for arg in args:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type)
self.llbuilder.store(llarg, llargslot)
if builtins.is_none(arg.type):
llargslot = self.llbuilder.alloca(ll.LiteralStructType([]))
else:
llarg = self.map(arg)
llargslot = self.llbuilder.alloca(llarg.type)
self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot)

self.llbuilder.call(self.llbuiltin("send_rpc"),
13 changes: 10 additions & 3 deletions artiq/coredevice/core.py
Original file line number Diff line number Diff line change
@@ -56,12 +56,12 @@ def __init__(self, dmgr, ref_period=8*ns, external_clock=False, comm_device="com
self.core = self
self.comm.core = self

def compile(self, function, args, kwargs, with_attr_writeback=True):
def compile(self, function, args, kwargs, set_result, with_attr_writeback=True):
try:
engine = diagnostic.Engine(all_errors_are_fatal=True)

stitcher = Stitcher(engine=engine)
stitcher.stitch_call(function, args, kwargs)
stitcher.stitch_call(function, args, kwargs, set_result)
stitcher.finalize()

module = Module(stitcher, ref_period=self.ref_period)
@@ -76,7 +76,12 @@ def compile(self, function, args, kwargs, with_attr_writeback=True):
raise CompileError(error.diagnostic) from error

def run(self, function, args, kwargs):
object_map, kernel_library, symbolizer = self.compile(function, args, kwargs)
result = None
def set_result(new_result):
nonlocal result
result = new_result

object_map, kernel_library, symbolizer = self.compile(function, args, kwargs, set_result)

if self.first_run:
self.comm.check_ident()
@@ -87,6 +92,8 @@ def run(self, function, args, kwargs):
self.comm.run()
self.comm.serve(object_map, symbolizer)

return result

@kernel
def get_rtio_counter_mu(self):
return rtio_get_counter()
8 changes: 3 additions & 5 deletions artiq/test/coredevice/embedding.py
Original file line number Diff line number Diff line change
@@ -50,12 +50,10 @@ def test(self, foo=42) -> TInt32:
return foo

@kernel
def run(self, callback):
callback(self.test())
def run(self):
return self.test()

class DefaultArgTest(ExperimentCase):
def test_default_arg(self):
exp = self.create(DefaultArg)
def callback(value):
self.assertEqual(value, 42)
exp.run(callback)
self.assertEqual(exp.run(), 42)

0 comments on commit 4fb1de3

Please sign in to comment.