Skip to content

Commit

Permalink
transforms/inline: object attribute writeback
Browse files (browse the repository at this point in the history)
sbourdeauducq committed Nov 3, 2014
1 parent f54a2f9 commit e9e12ad
Showing 2 changed files with 67 additions and 12 deletions.
61 changes: 50 additions & 11 deletions artiq/transforms/inline.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
import builtins
from fractions import Fraction
from collections import OrderedDict
from functools import partial

from artiq.language import core as core_language
from artiq.language import units
@@ -371,6 +372,52 @@ def get_map(self):
return {encoding: obj for i, (encoding, obj) in self._d.items()}


def get_attr_init(attribute_namespace, loc_node):
    """Build assignments that seed the kernel's copies of host attributes.

    For every ``(_, attr)`` key in ``attribute_namespace`` whose host object
    (``attr_info.obj``) currently has ``attr``, emit the statement
    ``<attr_info.mangled_name> = <value>``. The value is converted to an AST
    node with ``value_to_ast``. Attributes the host object does not have
    (e.g. ones the kernel creates itself) produce no assignment.

    :param attribute_namespace: mapping of ``(_, attr)`` keys to info objects
        exposing ``obj`` and ``mangled_name``.
    :param loc_node: AST node whose source location is copied onto every
        generated node.
    :returns: list of ``ast.Assign`` nodes, in the mapping's iteration order.
    """
    def located(node):
        return ast.copy_location(node, loc_node)

    assignments = []
    for (_, name), info in attribute_namespace.items():
        if not hasattr(info.obj, name):
            continue
        rhs = located(value_to_ast(getattr(info.obj, name)))
        lhs = located(ast.Name(info.mangled_name, ast.Store()))
        assignments.append(located(ast.Assign([lhs], rhs)))
    return assignments


def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
    """Build RPC calls that copy kernel-side attribute values back to the host.

    For each read-write entry (``attr_info.read_write``) of
    ``attribute_namespace``, register the setter
    ``partial(setattr, attr_info.obj, attr)`` with ``rpc_mapper.encode``.
    Then emit the expression statement
    ``syscall("rpc", <encoded setter>, <attr_info.mangled_name>)``.

    HACK/FIXME: RPC of non-int values is not supported yet. An attribute whose
    current host value is not an ``int``, or is a ``core_language.int64``, is
    therefore skipped. This breaks if the kernel promotes an int attribute to
    int64. Attributes the host object does not have yet are written back
    without any type check.

    :param attribute_namespace: mapping of ``(_, attr)`` keys to info objects
        exposing ``obj``, ``mangled_name`` and ``read_write``.
    :param rpc_mapper: object whose ``encode`` turns the setter callable into
        a value passed to ``value_to_ast``.
    :param loc_node: AST node whose source location is copied onto every
        generated node.
    :returns: list of ``ast.Expr`` nodes, in the mapping's iteration order.
    """
    def located(node):
        return ast.copy_location(node, loc_node)

    def rpc_compatible(obj, name):
        # Nothing on the host to inspect: do not filter.
        if not hasattr(obj, name):
            return True
        current = getattr(obj, name)
        # NOTE(review): bool subclasses int, so bool attributes pass this
        # filter too -- confirm the RPC path accepts them.
        return (isinstance(current, int)
                and not isinstance(current, core_language.int64))

    statements = []
    for (_, name), info in attribute_namespace.items():
        if not info.read_write or not rpc_compatible(info.obj, name):
            continue
        setter_id = rpc_mapper.encode(partial(setattr, info.obj, name))
        call = ast.Call(
            func=located(ast.Name("syscall", ast.Load())),
            args=[located(ast.Str("rpc")),
                  located(value_to_ast(setter_id)),
                  located(ast.Name(info.mangled_name, ast.Load()))],
            keywords=[], starargs=None, kwargs=None)
        statements.append(located(ast.Expr(located(call))))
    return statements


def inline(core, k_function, k_args, k_kwargs):
if k_kwargs:
raise NotImplementedError(
@@ -392,16 +439,8 @@ def inline(core, k_function, k_args, k_kwargs):
func=k_function,
args=k_args)

param_init = []
for (_, attr), attr_info in attribute_namespace.items():
value = getattr(attr_info.obj, attr)
value = ast.copy_location(value_to_ast(value), func_def)
target = ast.copy_location(ast.Name(attr_info.mangled_name,
ast.Store()),
func_def)
assign = ast.copy_location(ast.Assign([target], value),
func_def)
param_init.append(assign)
func_def.body[0:0] = param_init
func_def.body[0:0] = get_attr_init(attribute_namespace, func_def)
func_def.body += get_attr_writeback(attribute_namespace, mappers.rpc,
func_def)

return func_def, mappers.rpc.get_map(), mappers.exception.get_map()
18 changes: 17 additions & 1 deletion test/full_stack.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,15 @@ def run(self):
self.output_list.append(x)


class _Attributes(AutoContext):
    """Test fixture for kernel attribute initialization and writeback.

    ``input`` is set on the host in ``build`` and read inside the kernel.
    ``result`` is only assigned inside the kernel, so its value reaches the
    host solely through attribute writeback.
    """
    def build(self):
        # Host-side value that the kernel reads through ``self.input``.
        self.input = 84

    @kernel
    def run(self):
        # ``result`` is not set by build(); the host sees it only after
        # the kernel's attribute writeback.
        self.result = self.input//2


class _PulseLogger(AutoContext):
parameters = "output_list name"

@@ -123,13 +132,20 @@ def run(self):
self.trace.append(104)


class SimCompareCase(unittest.TestCase):
class ExecutionCase(unittest.TestCase):
def test_primes(self):
    """_Primes must produce identical output on the core device and the host."""
    device_primes = []
    host_primes = []
    # Device run first, then the host reference run.
    for runner, sink in ((_run_on_device, device_primes),
                         (_run_on_host, host_primes)):
        runner(_Primes, max=100, output_list=sink)
    self.assertEqual(device_primes, host_primes)

def test_attributes(self):
    """Host attributes reach the kernel and kernel-set ones come back.

    ``_Attributes.run`` reads ``input`` (84) and assigns ``result``, which
    the host only observes through attribute writeback.
    """
    # Open a comm_serial.Comm connection for the duration of the kernel run.
    with comm_serial.Comm() as comm:
        coredev = core.Core(comm)
        uut = _Attributes(core=coredev)
        uut.run()
    # 84 // 2 computed in the kernel.
    self.assertEqual(uut.result, 42)

def test_pulses(self):
l_device, l_host = [], []
_run_on_device(_Pulses, output_list=l_device)

0 comments on commit e9e12ad

Please sign in to comment.