Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: m-labs/artiq
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: e22301ea0589
Choose a base ref
...
head repository: m-labs/artiq
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 46165f3b5074
Choose a head ref
  • 2 commits
  • 2 files changed
  • 1 contributor

Commits on Oct 8, 2014

  1. Unverified

    This user has not yet uploaded their public signing key.
    Copy the full SHA
    2920ac8 View commit details
  2. Copy the full SHA
    46165f3 View commit details
Showing with 120 additions and 63 deletions.
  1. +97 −53 artiq/transforms/inline.py
  2. +23 −10 test/full_stack.py
150 changes: 97 additions & 53 deletions artiq/transforms/inline.py
Original file line number Diff line number Diff line change
@@ -33,71 +33,91 @@ def get_map(self):
_UserVariable = namedtuple("_UserVariable", "name")


def _is_kernel_attr(value, attr):
return hasattr(value, "kernel_attr") and attr in value.kernel_attr.split()


class _ReferenceManager:
def __init__(self):
# (id(obj), func_name, local_name) or (id(obj), kernel_attr_name)
# -> _UserVariable(name) / ast / constant_object
self.to_inlined = dict()
# inlined_name -> use_count
self.use_count = dict()
self.rpc_mapper = _HostObjectMapper()
self.exception_mapper = _HostObjectMapper(core_language.first_user_eid)
self.kernel_attr_init = []

# (id(obj), func_name, ref_name) or (id(obj), kernel_attr_name)
# -> _UserVariable(name) / ast / constant_object
self._to_inlined = dict()
# inlined_name -> use_count
self._use_count = dict()
# reserved names
for kg in core_language.kernel_globals:
self.use_count[kg] = 1
self._use_count[kg] = 1
for name in ("int", "round", "int64", "round64", "float", "array",
"range", "Fraction", "Quantity", "EncodedException"):
self.use_count[name] = 1
self._use_count[name] = 1

# node_or_value can be a AST node, used to inline function parameter values
# that can be simplified later through constant folding.
def register_replace(self, obj, func_name, ref_name, node_or_value):
self._to_inlined[(id(obj), func_name, ref_name)] = node_or_value

def new_name(self, base_name):
if base_name[-1].isdigit():
base_name += "_"
if base_name in self.use_count:
r = base_name + str(self.use_count[base_name])
self.use_count[base_name] += 1
if base_name in self._use_count:
r = base_name + str(self._use_count[base_name])
self._use_count[base_name] += 1
return r
else:
self.use_count[base_name] = 1
self._use_count[base_name] = 1
return base_name

def get(self, obj, func_name, ref):
if isinstance(ref, ast.Name):
key = (id(obj), func_name, ref.id)
try:
return self.to_inlined[key]
except KeyError:
if isinstance(ref.ctx, ast.Store):
ival = _UserVariable(self.new_name(ref.id))
self.to_inlined[key] = ival
return ival
else:
try:
return inspect.getmodule(obj).__dict__[ref.id]
except KeyError:
return getattr(builtins, ref.id)
elif isinstance(ref, ast.Attribute):
target = self.get(obj, func_name, ref.value)
if hasattr(target, "kernel_attr") and ref.attr in target.kernel_attr.split():
key = (id(target), ref.attr)
try:
ival = self.to_inlined[key]
assert(isinstance(ival, _UserVariable))
except KeyError:
iname = self.new_name(ref.attr)
ival = _UserVariable(iname)
self.to_inlined[key] = ival
a = value_to_ast(getattr(target, ref.attr))
if a is None:
raise NotImplementedError(
"Cannot represent initial value"
" of kernel attribute")
self.kernel_attr_init.append(ast.Assign(
[ast.Name(iname, ast.Store())], a))
def resolve_name(self, obj, func_name, ref_name, store):
key = (id(obj), func_name, ref_name)
try:
return self._to_inlined[key]
except KeyError:
if store:
ival = _UserVariable(self.new_name(ref_name))
self._to_inlined[key] = ival
return ival
else:
return getattr(target, ref.attr)
try:
return inspect.getmodule(obj).__dict__[ref_name]
except KeyError:
return getattr(builtins, ref_name)

def resolve_attr(self, value, attr):
if _is_kernel_attr(value, attr):
key = (id(value), attr)
try:
ival = self._to_inlined[key]
assert(isinstance(ival, _UserVariable))
except KeyError:
iname = self.new_name(attr)
ival = _UserVariable(iname)
self._to_inlined[key] = ival
a = value_to_ast(getattr(value, attr))
if a is None:
raise NotImplementedError(
"Cannot represent initial value"
" of kernel attribute")
self.kernel_attr_init.append(ast.Assign(
[ast.Name(iname, ast.Store())], a))
return ival
else:
return getattr(value, attr)

def resolve_constant(self, obj, func_name, node):
if isinstance(node, ast.Name):
c = self.resolve_name(obj, func_name, node.id, False)
if isinstance(c, (_UserVariable, ast.AST)):
raise ValueError("Not a constant")
return c
elif isinstance(node, ast.Attribute):
value = self.resolve_constant(obj, func_name, node.value)
if _is_kernel_attr(value, node.attr):
raise ValueError("Not a constant")
return getattr(value, node.attr)
else:
raise NotImplementedError

@@ -156,9 +176,9 @@ def generic_visit(self, node):
setattr(node, field, new_node)
return node

def visit_ref(self, node):
def visit_Name(self, node):
store = isinstance(node.ctx, ast.Store)
ival = self.rm.get(self.obj, self.func_name, node)
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, store)
if isinstance(ival, _UserVariable):
newnode = ast.Name(ival.name, node.ctx)
elif isinstance(ival, ast.AST):
@@ -175,11 +195,34 @@ def visit_ref(self, node):
"Cannot represent inlined value")
return ast.copy_location(newnode, node)

visit_Name = visit_ref
visit_Attribute = visit_ref
def _resolve_attribute(self, node):
if isinstance(node, ast.Name):
ival = self.rm.resolve_name(self.obj, self.func_name, node.id, False)
if isinstance(ival, _UserVariable):
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
else:
return ival
elif isinstance(node, ast.Attribute):
value = self._resolve_attribute(node.value)
if isinstance(value, ast.AST):
node.value = value
return node
else:
return self.rm.resolve_attr(value, node.attr)
else:
return self.visit(node)

def visit_Attribute(self, node):
ival = self._resolve_attribute(node)
if isinstance(ival, ast.AST):
return ival
elif isinstance(ival, _UserVariable):
return ast.copy_location(ast.Name(ival.name, ast.Load()), node)
else:
return value_to_ast(ival)

def visit_Call(self, node):
func = self.rm.get(self.obj, self.func_name, node.func)
func = self.rm.resolve_constant(self.obj, self.func_name, node.func)
new_args = [self.visit(arg) for arg in node.args]

if _is_embeddable(func):
@@ -240,7 +283,7 @@ def visit_FunctionDef(self, node):
return node

def _encode_exception(self, e):
exception_class = self.rm.get(self.obj, self.func_name, e)
exception_class = self.rm.resolve_constant(self.obj, self.func_name, e)
if not inspect.isclass(exception_class):
raise NotImplementedError("Exception type must be a class")
if issubclass(exception_class, core_language.RuntimeException):
@@ -301,9 +344,10 @@ def _initialize_function_params(func_def, k_args, k_kwargs, rm):
for arg_ast, arg_value in zip(func_def.args.args, k_args):
arg_name = arg_ast.arg
if arg_name in rop:
rm.to_inlined[(id(obj), func_name, arg_name)] = arg_value
rm.register_replace(obj, func_name, arg_name, arg_value)
else:
target = rm.get(obj, func_name, ast.Name(arg_name, ast.Store()))
uservar = rm.resolve_name(obj, func_name, arg_name, True)
target = ast.Name(uservar.name, ast.Store())
value = value_to_ast(arg_value)
param_init.append(ast.Assign(targets=[target], value=value))
return param_init
33 changes: 23 additions & 10 deletions test/full_stack.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from operator import itemgetter

from artiq import *
from artiq.devices import corecom_serial, core, runtime_exceptions, rtio_core
@@ -36,22 +37,29 @@ def run(self):


class _PulseLogger(AutoContext):
parameters = "name"
parameters = "output_list name"

def print_on(self, t, f):
print("{} ON:{:4} @{}".format(self.name, f, t))
def _append(self, t, l, f):
if not hasattr(self, "first_timestamp"):
self.first_timestamp = t
self.output_list.append((self.name, t-self.first_timestamp, l, f))

def print_off(self, t):
print("{} OFF @{}".format(self.name, t))
def on(self, t, f):
self._append(t, True, f)

def off(self, t):
self._append(t, False, 0)

@kernel
def pulse(self, f, duration):
self.print_on(int(now()), f)
self.on(int(now().amount*1000000000), f)
delay(duration)
self.print_off(int(now()))
self.off(int(now().amount*1000000000))


class _Pulses(AutoContext):
parameters = "output_list"

def build(self):
for name in "a", "b", "c", "d":
pl = _PulseLogger(self, name=name)
@@ -123,9 +131,14 @@ def test_primes(self):
self.assertEqual(l_device, l_host)

def test_pulses(self):
# TODO: compare results on host and device
# (this requires better unit management in the compiler)
_run_on_device(_Pulses)
l_device, l_host = [], []
_run_on_device(_Pulses, output_list=l_device)
_run_on_host(_Pulses, output_list=l_host)
l_host = sorted(l_host, key=itemgetter(1))
for channel in "a", "b", "c", "d":
c_device = [x for x in l_device if x[0] == channel]
c_host = [x for x in l_host if x[0] == channel]
self.assertEqual(c_device, c_host)

def test_exceptions(self):
t_device, t_host = [], []