Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: m-labs/artiq
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 1b81fc8a8fff
Choose a base ref
...
head repository: m-labs/artiq
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 8aebab580f9b
Choose a head ref
  • 2 commits
  • 1 file changed
  • 1 contributor

Commits on Sep 24, 2014

  1. Copy the full SHA
    82da734 View commit details
  2. Copy the full SHA
    8aebab5 View commit details
Showing with 37 additions and 21 deletions.
  1. +37 −21 artiq/transforms/inline.py
58 changes: 37 additions & 21 deletions artiq/transforms/inline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections import namedtuple, defaultdict
from collections import namedtuple
from fractions import Fraction
import inspect
import textwrap
@@ -10,13 +10,27 @@
from artiq.language import units


_UserVariable = namedtuple("_UserVariable", "name")
class _HostObjectMapper:
def __init__(self, first_encoding=0):
self._next_encoding = first_encoding
# id(object) -> (encoding, object)
# this format is required to support non-hashable host objects.
self._d = dict()

def encode(self, obj):
try:
return self._d[id(obj)][0]
except KeyError:
encoding = self._next_encoding
self._d[id(obj)] = (encoding, obj)
self._next_encoding += 1
return encoding

def get_map(self):
return {encoding: obj for i, (encoding, obj) in self._d.items()}

def _is_in_attr_list(obj, attr, al):
if not hasattr(obj, al):
return False
return attr in getattr(obj, al).split()

_UserVariable = namedtuple("_UserVariable", "name")


class _ReferenceManager:
@@ -26,8 +40,9 @@ def __init__(self):
self.to_inlined = dict()
# inlined_name -> use_count
self.use_count = dict()
self.rpc_map = defaultdict(lambda: len(self.rpc_map))
self.exception_map = defaultdict(lambda: len(self.exception_map))
self.rpc_mapper = _HostObjectMapper()
# exceptions 0-1023 are for runtime
self.exception_mapper = _HostObjectMapper(1024)
self.kernel_attr_init = []

# reserved names
@@ -66,7 +81,7 @@ def get(self, obj, func_name, ref):
return getattr(builtins, ref.id)
elif isinstance(ref, ast.Attribute):
target = self.get(obj, func_name, ref.value)
if _is_in_attr_list(target, ref.attr, "kernel_attr"):
if hasattr(target, "kernel_attr") and ref.attr in target.kernel_attr.split():
key = (id(target), ref.attr)
try:
ival = self.to_inlined[key]
@@ -89,13 +104,19 @@ def get(self, obj, func_name, ref):
raise NotImplementedError


_embeddable_calls = {
_embeddable_calls = (
core_language.delay, core_language.at, core_language.now,
core_language.syscall,
range, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, core_language.EncodedException
}
)

def _is_embeddable(call):
    """Tell whether ``call`` is one of the objects in ``_embeddable_calls``.

    The comparison is by identity (``is``), so neither ``__eq__`` nor
    ``__hash__`` of ``call`` is invoked.

    :param call: the resolved callable to check.
    :return: ``True`` if ``call`` is the very same object as an entry of
        ``_embeddable_calls``, ``False`` otherwise.
    """
    return any(call is ec for ec in _embeddable_calls)


class _ReferenceReplacer(ast.NodeVisitor):
@@ -162,7 +183,7 @@ def visit_Call(self, node):
func = self.rm.get(self.obj, self.func_name, node.func)
new_args = [self.visit(arg) for arg in node.args]

if func in _embeddable_calls:
if _is_embeddable(func):
new_func = ast.Name(func.__name__, ast.Load())
return ast.copy_location(
ast.Call(func=new_func, args=new_args,
@@ -183,7 +204,7 @@ def visit_Call(self, node):
body=inlined.body))
return ast.copy_location(ast.Name(retval_name, ast.Load()), node)
else:
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_map[func])]
args = [ast.Str("rpc"), value_to_ast(self.rm.rpc_mapper.encode(func))]
args += new_args
return ast.copy_location(
ast.Call(func=ast.Name("syscall", ast.Load()),
@@ -222,7 +243,7 @@ def visit_Raise(self, node):
exception_class = self.rm.get(self.obj, self.func_name, node.exc)
if not inspect.isclass(exception_class):
raise NotImplementedError("Exception must be a class")
exception_id = self.rm.exception_map[exception_class]
exception_id = self.rm.exception_mapper.encode(exception_class)
node.exc = ast.copy_location(
ast.Call(func=ast.Name("EncodedException", ast.Load()),
args=[value_to_ast(exception_id)],
@@ -234,7 +255,7 @@ def _encode_exception(self, e):
exception_class = self.rm.get(self.obj, self.func_name, e)
if not inspect.isclass(exception_class):
raise NotImplementedError("Exception type must be a class")
exception_id = self.rm.exception_map[exception_class]
exception_id = self.rm.exception_mapper.encode(exception_class)
return ast.copy_location(
ast.Call(func=ast.Name("EncodedException", ast.Load()),
args=[value_to_ast(exception_id)],
@@ -308,9 +329,4 @@ def inline(core, k_function, k_args, k_kwargs, rm=None, retval_name=None):
if init_kernel_attr:
func_def.body[0:0] = rm.kernel_attr_init

r_rpc_map = dict((rpc_num, rpc_fun)
for rpc_fun, rpc_num in rm.rpc_map.items())
r_exception_map = dict((exception_num, exception_class)
for exception_class, exception_num
in rm.exception_map.items())
return func_def, r_rpc_map, r_exception_map
return func_def, rm.rpc_mapper.get_map(), rm.exception_mapper.get_map()