Skip to content

Commit

Permalink
units: error checking
Browse files Browse the repository at this point in the history
sbourdeauducq committed Nov 23, 2014
1 parent d59d110 commit a3f9817
Showing 6 changed files with 172 additions and 32 deletions.
5 changes: 4 additions & 1 deletion artiq/language/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
"""Core ARTIQ extensions to the Python language."""
"""
Core ARTIQ extensions to the Python language.
"""

from collections import namedtuple as _namedtuple
from copy import copy as _copy
70 changes: 64 additions & 6 deletions artiq/language/units.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,48 @@
"""
Definition and management of physical units.
"""

from fractions import Fraction as _Fraction


# Unit prefix characters; "_" appears to denote "no prefix" (compare the
# prefix strings passed to _register_unit, e.g. "pnum_" for seconds).
# _smallest_prefix is presumably the scale of the first character
# ("p", 10**-12) — TODO confirm against _format.
_prefixes_str = "pnum_kMG"
_smallest_prefix = _Fraction(1, 10**12)
class DimensionError(Exception):
    """Signals an operation between values whose units are incompatible.

    On the core device all units are resolved statically at compile time.
    When the functions of this module raise this exception there, it is
    raised by the compiler and therefore cannot be caught inside a kernel.
    """


def mul_dimension(l, r):
    """Return the unit of the product of unit ``l`` and unit ``r``.

    ``None`` stands for dimensionless. A dimensionless operand leaves the
    other unit unchanged. Of the dimensioned combinations, only ``Hz``
    times ``s`` (in either order) is supported; it yields dimensionless.

    Raises ``DimensionError`` for any other combination.
    """
    if l is not None and r is not None:
        if {l, r} != {"Hz", "s"}:
            raise DimensionError
        return None
    return l if r is None else r


def _rmul_dimension(l, r):
    """Reflected form of ``mul_dimension``: the unit of ``r * l``."""
    left, right = r, l
    return mul_dimension(left, right)


def div_dimension(l, r):
"""Returns the unit obtained by dividing unit ``l`` with unit ``r``.
Raises ``DimensionError`` if the resulting unit is not implemented.
"""
if l == r:
return None
if r is None:
@@ -28,17 +52,28 @@ def div_dimension(l, r):
return "Hz"
if r == "Hz":
return "s"
raise DimensionError


def _rdiv_dimension(l, r):
    """Reflected form of ``div_dimension``: the unit of ``r / l``."""
    numerator, denominator = r, l
    return div_dimension(numerator, denominator)


def addsub_dimension(x, y):
    """Return the unit obtained by adding or subtracting unit ``x`` and
    unit ``y``.

    Both operands must carry the same unit (``None`` meaning
    dimensionless); that common unit is returned.

    Raises ``DimensionError`` if ``x`` and ``y`` are different.
    """
    if x == y:
        return x
    # Mismatched units are an error; there is no dimensionless fallback.
    raise DimensionError


# Unit prefix characters; "_" appears to denote "no prefix" (compare the
# prefix strings passed to _register_unit, e.g. "pnum_" for seconds).
# _smallest_prefix is presumably the scale of the first character
# ("p", 10**-12) — TODO confirm against _format.
_prefixes_str = "pnum_kMG"
_smallest_prefix = _Fraction(1, 10**12)


def _format(amount, unit):
@@ -139,9 +174,9 @@ def __rmod__(self, other):

# comparisons
def _cmp(self, other, opf_name):
    """Compare this quantity with ``other`` through the amount's method
    named ``opf_name`` (e.g. ``"__lt__"``).

    ``other`` must be a ``Quantity`` with the same unit as ``self``. The
    result of the amount's method is returned as-is, which may be
    ``NotImplemented`` for mixed amount types.

    Raises ``DimensionError`` if ``other`` is not a ``Quantity`` or has a
    different unit.
    """
    # Comparing against a bare number or a different unit is a dimension
    # error rather than a comparison of raw amounts.
    if not isinstance(other, Quantity) or other.unit != self.unit:
        raise DimensionError
    return getattr(self.amount, opf_name)(other.amount)

def __lt__(self, other):
    """Implement ``self < other`` by delegating to ``_cmp``."""
    operator_name = "__lt__"
    return self._cmp(other, operator_name)
@@ -173,3 +208,26 @@ def _register_unit(unit, prefixes):

# Register the base units together with the prefix characters each may take
# ("_" appears to mean the unprefixed unit). Presumably this defines
# module-level names such as ``s`` and ``Hz`` (kernels in test/full_stack.py
# write ``1*Hz + 1*s``) — confirm in _register_unit.
_register_unit("s", "pnum_")
_register_unit("Hz", "_kMG")


def check_unit(value, unit):
    """Verify that ``value`` carries the unit ``unit``.

    ``unit`` is the bare unit string without any prefix (e.g. ``s`` or
    ``Hz``). Pass ``None`` to require a dimensionless value, i.e. anything
    that is not a ``Quantity`` instance.

    Raises ``DimensionError`` when the units do not match.

    This can be used in kernels, where it runs at compilation time. Since
    arithmetic on quantities already checks units, it is mostly needed after
    taking the raw number out of a ``Quantity`` via its ``amount`` property.
    """
    is_quantity = isinstance(value, Quantity)
    if unit is None:
        matches = not is_quantity
    else:
        matches = is_quantity and value.unit == unit
    if not matches:
        raise DimensionError
22 changes: 3 additions & 19 deletions artiq/transforms/inline.py
Original file line number Diff line number Diff line change
@@ -3,14 +3,15 @@
import ast
import types
import builtins
from copy import copy
from fractions import Fraction
from collections import OrderedDict
from functools import partial
from itertools import zip_longest, chain

from artiq.language import core as core_language
from artiq.language import units
from artiq.transforms.tools import value_to_ast, NotASTRepresentable
from artiq.transforms.tools import *


def new_mangled_name(in_use_names, name):
@@ -35,23 +36,6 @@ def __init__(self, obj, mangled_name, read_write):
self.read_write = read_write


# Host callables that kernel code may reference by name (presumably handled
# natively by the compiler rather than inlined or turned into RPCs — confirm).
# inline() seeds its set of in-use names with their __name__s.
embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, core_language.EncodedException
)


def is_embeddable(func):
    """Tell whether ``func`` is one of the objects in ``embeddable_funcs``.

    Membership is tested by identity (``is``), so a custom ``__eq__`` on
    ``func`` is never invoked.
    """
    return any(func is candidate for candidate in embeddable_funcs)


def is_inlinable(core, func):
if hasattr(func, "k_function_info"):
if func.k_function_info.core_name == "":
@@ -493,7 +477,7 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
def inline(core, k_function, k_args, k_kwargs):
# OrderedDict prevents non-determinism in attribute init
attribute_namespace = OrderedDict()
in_use_names = {func.__name__ for func in embeddable_funcs}
in_use_names = copy(embeddable_func_names)
mappers = types.SimpleNamespace(
rpc=HostObjectMapper(),
exception=HostObjectMapper(core_language.first_user_eid)
63 changes: 58 additions & 5 deletions artiq/transforms/lower_units.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from copy import copy

from artiq.language import units
from artiq.transforms.tools import embeddable_func_names


def _add_units(f, unit_list):
@@ -16,6 +17,7 @@ def wrapper(*args):
class _UnitsLowerer(ast.NodeTransformer):
def __init__(self, rpc_map):
    """Set up the state used while lowering units.

    ``rpc_map`` is stored unchanged (presumably the RPC number to host
    object mapping — confirm against the caller).
    """
    self.rpc_map = rpc_map
    # Units of variables, keyed by variable name (see _update_target).
    self.variable_units = {}
    # (original rpc number, (unit list)) -> new rpc number. The factory runs
    # before the missing key is inserted, so numbers are handed out as
    # 0, 1, 2, ... in first-seen order.
    self.rpc_remap = defaultdict(lambda: len(self.rpc_remap))

@@ -29,6 +31,22 @@ def visit_Name(self, node):
node.unit = unit
return node

def visit_BoolOp(self, node):
    """Require every operand of an ``and``/``or`` to carry the same unit.

    Operands without a ``unit`` attribute count as dimensionless (``None``).
    Raises ``units.DimensionError`` on a mismatch; otherwise the node is
    returned unchanged.
    """
    self.generic_visit(node)
    operand_units = [getattr(operand, "unit", None) for operand in node.values]
    for operand_unit in operand_units[1:]:
        if not operand_unit == operand_units[0]:
            raise units.DimensionError
    return node

def visit_Compare(self, node):
    """Require all operands of a (possibly chained) comparison to share a
    unit.

    Operands without a ``unit`` attribute count as dimensionless (``None``).
    Raises ``units.DimensionError`` on a mismatch; otherwise the node is
    returned unchanged.
    """
    self.generic_visit(node)
    reference = getattr(node.left, "unit", None)
    for comparator in node.comparators:
        if not getattr(comparator, "unit", None) == reference:
            raise units.DimensionError
    return node

def visit_UnaryOp(self, node):
self.generic_visit(node)
if hasattr(node.operand, "unit"):
@@ -47,6 +65,8 @@ def visit_BinOp(self, node):
elif op in (ast.Div, ast.FloorDiv):
unit = units.div_dimension(left_unit, right_unit)
else:
if left_unit is not None or right_unit is not None:
raise units.DimensionError
unit = None
if unit is not None:
node.unit = unit
@@ -66,14 +86,47 @@ def visit_Call(self, node):
amount, unit = node.args
amount.unit = unit.s
return amount
elif node.func.id == "now":
elif node.func.id in ("now", "cycles_to_time"):
node.unit = "s"
elif node.func.id == "syscall" and node.args[0].s == "rpc":
unit_list = tuple(getattr(arg, "unit", None) for arg in node.args[2:])
rpc_n = node.args[1].n
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
elif node.func.id == "syscall":
# only RPCs can have units
if node.args[0].s == "rpc":
unit_list = tuple(getattr(arg, "unit", None)
for arg in node.args[2:])
rpc_n = node.args[1].n
node.args[1].n = self.rpc_remap[(rpc_n, (unit_list))]
else:
if any(hasattr(arg, "unit") for arg in node.args):
raise units.DimensionError
elif node.func.id in ("delay", "at", "time_to_cycles"):
if getattr(node.args[0], "unit", None) != "s":
raise units.DimensionError
elif node.func.id == "check_unit":
self.generic_visit(node)
elif node.func.id in embeddable_func_names:
# must be last (some embeddable funcs may have units)
if any(hasattr(arg, "unit") for arg in node.args):
raise units.DimensionError
return node

def visit_Expr(self, node):
    """Evaluate and remove ``check_unit(value, unit)`` statements.

    ``check_unit`` is resolved at compile time. The unit attached to its
    first argument (``None`` when the node has no ``unit`` attribute) is
    compared with the expected unit. The statement is then dropped from the
    tree by returning ``None``. Any other expression statement is returned
    unchanged.

    The expected unit must be the literal ``None`` (the value must be
    dimensionless) or a string literal such as ``"s"``.

    Raises ``units.DimensionError`` on a mismatch and
    ``NotImplementedError`` when the expected unit is not such a literal.
    """
    self.generic_visit(node)
    call = node.value
    # Only plain-name calls can be check_unit; calls such as ``obj.f()``
    # have an ast.Attribute as func, which has no ``id``.
    is_check_unit = (isinstance(call, ast.Call)
                     and isinstance(call.func, ast.Name)
                     and call.func.id == "check_unit")
    if not is_check_unit:
        return node
    # The arguments live on the Call node itself (``call.args``).
    value, expected = call.args[0], call.args[1]
    if isinstance(expected, ast.NameConstant) and expected.value is None:
        if hasattr(value, "unit"):
            raise units.DimensionError
    elif isinstance(expected, ast.Str):
        if getattr(value, "unit", None) != expected.s:
            raise units.DimensionError
    else:
        raise NotImplementedError
    return None

def _update_target(self, target, unit):
if isinstance(target, ast.Name):
if target.id in self.variable_units:
18 changes: 18 additions & 0 deletions artiq/transforms/tools.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,24 @@
from artiq.language import units


# Host callables that kernel code may reference by name (presumably handled
# natively by the compiler rather than inlined or turned into RPCs — confirm).
embeddable_funcs = (
core_language.delay, core_language.at, core_language.now,
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
Fraction, units.Quantity, units.check_unit, core_language.EncodedException
)
# The same callables by __name__, for passes that only see ast.Name ids
# (lower_units checks call names against it; inline copies it as its
# initial set of in-use names).
embeddable_func_names = {func.__name__ for func in embeddable_funcs}


def is_embeddable(func):
    """Tell whether ``func`` is one of the objects in ``embeddable_funcs``.

    Membership is tested by identity (``is``), so a custom ``__eq__`` on
    ``func`` is never invoked.
    """
    return any(func is candidate for candidate in embeddable_funcs)


def eval_ast(expr, symdict=dict()):
if not isinstance(expr, ast.Expression):
expr = ast.copy_location(ast.Expression(expr), expr)
26 changes: 25 additions & 1 deletion test/full_stack.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
from fractions import Fraction

from artiq import *
from artiq.language.units import Quantity
from artiq.language.units import *
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
from artiq.sim import devices as sim_devices

@@ -56,6 +56,22 @@ def run(self):
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
self.inhomogeneous_units.append(Quantity(10, "s"))

# Adds a frequency to a time; test_misc expects this to raise DimensionError.
@kernel
def dimension_error1(self):
print(1*Hz + 1*s)

# Compares a frequency with a time; test_misc expects DimensionError.
@kernel
def dimension_error2(self):
print(1*Hz < 1*s)

# check_unit on a frequency while expecting "s"; test_misc expects
# DimensionError.
@kernel
def dimension_error3(self):
check_unit(1*Hz, "s")

# Passes a frequency to delay(), which takes a time; test_misc expects
# DimensionError.
@kernel
def dimension_error4(self):
delay(10*Hz)


class _PulseLogger(AutoContext):
parameters = "output_list name"
@@ -163,6 +179,14 @@ def test_misc(self):
Fraction("1.2"))
self.assertEqual(uut.inhomogeneous_units, [
Quantity(1000, "Hz"), Quantity(10, "s")])
with self.assertRaises(DimensionError):
uut.dimension_error1()
with self.assertRaises(DimensionError):
uut.dimension_error2()
with self.assertRaises(DimensionError):
uut.dimension_error3()
with self.assertRaises(DimensionError):
uut.dimension_error4()

def test_pulses(self):
l_device, l_host = [], []

0 comments on commit a3f9817

Please sign in to comment.