Skip to content

Commit

Permalink
support units in lists
Browse files Browse the repository at this point in the history
sbourdeauducq committed Dec 19, 2014
1 parent 0d10ae7 commit 5522378
Showing 3 changed files with 35 additions and 13 deletions.
1 change: 1 addition & 0 deletions artiq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from artiq.language.core import *
from artiq.language.context import *
from artiq.language.units import check_unit
from artiq.language.units import ps, ns, us, ms, s
from artiq.language.units import Hz, kHz, MHz, GHz
22 changes: 11 additions & 11 deletions artiq/test/full_stack.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@
from fractions import Fraction

from artiq import *
from artiq.language.units import *
from artiq.language.units import DimensionError
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
from artiq.sim import devices as sim_devices

@@ -51,18 +51,19 @@ def build(self):
self.input = 84
self.inhomogeneous_units = []
self.al = [1, 2, 3, 4, 5]
self.list_copy_in = [2*Hz, 10*MHz]

@kernel
def run(self):
# Kernel exercising miscellaneous language features; test_misc calls it and
# then asserts on the attributes it sets.
# NOTE(review): this page renders the diff without +/- markers, so removed and
# added lines are interleaved below. Judging by the assertions in test_misc,
# the decimal_fraction_n/_d assignments and the Quantity(...) appends look like
# the pre-commit lines, replaced by the Fraction / 1000*Hz / 10*s lines —
# confirm against the repository before relying on this text.
# Integer floor division of a value set in build.
self.half_input = self.input//2
# Fraction parsed from a decimal string, exported as numerator/denominator.
decimal_fraction = Fraction("1.2")
self.decimal_fraction_n = int(decimal_fraction.numerator)
self.decimal_fraction_d = int(decimal_fraction.denominator)
# Quantities with different units (Hz and s) appended to the same list.
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
self.inhomogeneous_units.append(Quantity(10, "s"))
self.decimal_fraction = Fraction("1.2")
self.inhomogeneous_units.append(1000*Hz)
self.inhomogeneous_units.append(10*s)
# Sum self.al element by element using index access.
self.acc = 0
for i in range(len(self.al)):
self.acc += self.al[i]
# Assign the list of quantities set up in build (units in lists) to another
# attribute; test_misc compares the two.
self.list_copy_out = self.list_copy_in
# List comprehension whose elements all carry the same unit (MHz).
self.unit_comp = [1*MHz for _ in range(3)]

@kernel
def dimension_error1(self):
@@ -184,12 +185,11 @@ def test_misc(self):
uut = _Misc(core=coredev)
uut.run()
self.assertEqual(uut.half_input, 42)
self.assertEqual(Fraction(uut.decimal_fraction_n,
uut.decimal_fraction_d),
Fraction("1.2"))
self.assertEqual(uut.inhomogeneous_units, [
Quantity(1000, "Hz"), Quantity(10, "s")])
self.assertEqual(uut.decimal_fraction, Fraction("1.2"))
self.assertEqual(uut.inhomogeneous_units, [1000*Hz, 10*s])
self.assertEqual(uut.acc, sum(uut.al))
self.assertEqual(uut.list_copy_in, uut.list_copy_out)
self.assertEqual(uut.unit_comp, [1*MHz for _ in range(3)])
with self.assertRaises(DimensionError):
uut.dimension_error1()
with self.assertRaises(DimensionError):
25 changes: 23 additions & 2 deletions artiq/transforms/lower_units.py
Original file line number Diff line number Diff line change
@@ -8,8 +8,15 @@

def _add_units(f, unit_list):
    """Wrap *f* so that its positional arguments are turned into quantities.

    Each positional argument is paired with the entry of *unit_list* at the
    same position:

    * unit ``None``: the argument is passed through unchanged;
    * otherwise a scalar becomes ``units.Quantity(arg, unit)`` and a list has
      the conversion applied to each item, recursively, so nested lists of
      quantities are converted element-wise as well.

    Returns the wrapping function, which calls *f* with the converted
    arguments and returns its result.

    NOTE(review): pairing is done with ``zip``, so positional arguments beyond
    ``len(unit_list)`` are silently dropped — confirm callers always pass
    exactly ``len(unit_list)`` arguments.
    """
    def attach(value, unit):
        # A list is tagged with the unit shared by its elements (the list
        # lowering stores the elements' unit on the list node), so the unit
        # is applied to each element instead of to the list itself.
        if isinstance(value, list):
            return [attach(item, unit) for item in value]
        return units.Quantity(value, unit)

    def wrapper(*args):
        new_args = [arg if unit is None else attach(arg, unit)
                    for arg, unit in zip(args, unit_list)]
        return f(*new_args)
    return wrapper

@@ -80,6 +87,20 @@ def visit_Attribute(self, node):
else:
return node

def visit_List(self, node):
    """Propagate the unit shared by a list literal's elements to the list.

    Children are visited first so that their ``unit`` attributes are set.
    Elements without a ``unit`` attribute count as unitless (``None``).
    All elements must have the same unit, otherwise ``units.DimensionError``
    is raised; the common unit is stored in ``node.unit``.

    An empty list literal has no element to take a unit from and is marked
    unitless (``node.unit = None``), like any other list of unitless values.

    Returns the (same) node.
    """
    self.generic_visit(node)
    # Named elt_units rather than ``us`` to avoid confusion with the
    # microsecond unit ``us`` exported by artiq.language.units.
    elt_units = [getattr(elt, "unit", None) for elt in node.elts]
    if not elt_units:
        # Indexing elt_units[0] would raise IndexError on ``[]``.
        node.unit = None
        return node
    first = elt_units[0]
    if not all(u == first for u in elt_units[1:]):
        raise units.DimensionError
    node.unit = first
    return node

def visit_ListComp(self, node):
    """Give a list comprehension the unit of its element expression.

    Children are visited first. If the element expression ends up with a
    ``unit`` attribute, the comprehension node receives the same unit;
    otherwise the node is left without one. Returns the (same) node.
    """
    self.generic_visit(node)
    try:
        element_unit = node.elt.unit
    except AttributeError:
        # Element expression is unitless: leave the comprehension untagged.
        pass
    else:
        node.unit = element_unit
    return node

def visit_Call(self, node):
self.generic_visit(node)
if node.func.id == "Quantity":

0 comments on commit 5522378

Please sign in to comment.