Skip to content

Commit 5522378

Browse files
committed Dec 19, 2014
support units in lists
1 parent 0d10ae7 commit 5522378

File tree

3 files changed

+35
-13
lines changed

3 files changed

+35
-13
lines changed
 

artiq/__init__.py

+1
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from artiq.language.core import *
22
from artiq.language.context import *
3+
from artiq.language.units import check_unit
34
from artiq.language.units import ps, ns, us, ms, s
45
from artiq.language.units import Hz, kHz, MHz, GHz

artiq/test/full_stack.py

+11 -11
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
from fractions import Fraction
55

66
from artiq import *
7-
from artiq.language.units import *
7+
from artiq.language.units import DimensionError
88
from artiq.coredevice import comm_serial, core, runtime_exceptions, rtio
99
from artiq.sim import devices as sim_devices
1010

@@ -51,18 +51,19 @@ def build(self):
5151
self.input = 84
5252
self.inhomogeneous_units = []
5353
self.al = [1, 2, 3, 4, 5]
54+
self.list_copy_in = [2*Hz, 10*MHz]
5455

5556
@kernel
5657
def run(self):
# Kernel entry point of this test fixture (presumably _Misc: test_misc
# calls uut.run() and then checks every attribute assigned below).
# NOTE(review): this text is a rendered diff -- bare numeric lines are
# gutter line numbers; statements following an "NN-" marker were removed
# by this commit and statements following an "NN+" marker were added.
5758
self.half_input = self.input//2
58-
decimal_fraction = Fraction("1.2")
59-
self.decimal_fraction_n = int(decimal_fraction.numerator)
60-
self.decimal_fraction_d = int(decimal_fraction.denominator)
61-
self.inhomogeneous_units.append(Quantity(1000, "Hz"))
62-
self.inhomogeneous_units.append(Quantity(10, "s"))
59+
# The Fraction itself is now stored, replacing the separate
# numerator/denominator ints removed above.
self.decimal_fraction = Fraction("1.2")
60+
# Values with different units (Hz, then s) go into one list;
# test_misc expects [1000*Hz, 10*s].
self.inhomogeneous_units.append(1000*Hz)
61+
self.inhomogeneous_units.append(10*s)
6362
# Index-based sum of self.al; test_misc checks acc == sum(al).
self.acc = 0
6463
for i in range(len(self.al)):
6564
self.acc += self.al[i]
65+
# Round-trips a list of unit-carrying values (build sets
# list_copy_in = [2*Hz, 10*MHz]); test_misc asserts both lists are equal.
self.list_copy_out = self.list_copy_in
66+
# List comprehension whose elements carry MHz units; test_misc compares
# it with [1*MHz for _ in range(3)].
self.unit_comp = [1*MHz for _ in range(3)]
6667

6768
@kernel
6869
def dimension_error1(self):
@@ -184,12 +185,11 @@ def test_misc(self):
184185
uut = _Misc(core=coredev)
185186
uut.run()
186187
self.assertEqual(uut.half_input, 42)
187-
self.assertEqual(Fraction(uut.decimal_fraction_n,
188-
uut.decimal_fraction_d),
189-
Fraction("1.2"))
190-
self.assertEqual(uut.inhomogeneous_units, [
191-
Quantity(1000, "Hz"), Quantity(10, "s")])
188+
self.assertEqual(uut.decimal_fraction, Fraction("1.2"))
189+
self.assertEqual(uut.inhomogeneous_units, [1000*Hz, 10*s])
192190
self.assertEqual(uut.acc, sum(uut.al))
191+
self.assertEqual(uut.list_copy_in, uut.list_copy_out)
192+
self.assertEqual(uut.unit_comp, [1*MHz for _ in range(3)])
193193
with self.assertRaises(DimensionError):
194194
uut.dimension_error1()
195195
with self.assertRaises(DimensionError):

artiq/transforms/lower_units.py

+23 -2
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,15 @@
88

99
def _add_units(f, unit_list):
1010
def wrapper(*args):
11-
new_args = [arg if unit is None else units.Quantity(arg, unit)
12-
for arg, unit in zip(args, unit_list)]
11+
new_args = []
12+
for arg, unit in zip(args, unit_list):
13+
if unit is None:
14+
new_args.append(arg)
15+
else:
16+
if isinstance(arg, list):
17+
new_args.append([units.Quantity(x, unit) for x in arg])
18+
else:
19+
new_args.append(units.Quantity(arg, unit))
1320
return f(*new_args)
1421
return wrapper
1522

@@ -80,6 +87,20 @@ def visit_Attribute(self, node):
8087
else:
8188
return node
8289

90+
def visit_List(self, node):
    """Give a list literal the unit shared by all of its elements.

    After the children are visited, each element's ``unit`` attribute
    (``None`` when absent) is collected. All elements must carry the same
    unit, which is then stored as ``node.unit``. Otherwise
    ``units.DimensionError`` is raised. An empty list literal gets
    ``unit = None`` instead of failing with an IndexError.

    Returns the (mutated) node.
    """
    self.generic_visit(node)
    # Named elt_units rather than "us" so it cannot be confused with the
    # "us" time unit exported by artiq.language.units.
    elt_units = [getattr(elt, "unit", None) for elt in node.elts]
    # An empty literal has no element to take a unit from; treat it like a
    # list of plain numbers (unit None).
    first = elt_units[0] if elt_units else None
    if not all(u == first for u in elt_units[1:]):
        raise units.DimensionError
    node.unit = first
    return node
97+
98+
def visit_ListComp(self, node):
    """Propagate the unit of a comprehension's element expression.

    The children are visited first. If the element expression ended up with
    a ``unit`` attribute, the comprehension node receives the same unit.
    Otherwise the node is left without one. Returns the node.
    """
    self.generic_visit(node)
    try:
        element_unit = node.elt.unit
    except AttributeError:
        # Element is dimensionless (or never annotated); leave node as-is.
        pass
    else:
        node.unit = element_unit
    return node
103+
83104
def visit_Call(self, node):
84105
self.generic_visit(node)
85106
if node.func.id == "Quantity":

0 commit comments

Comments
 (0)
Please sign in to comment.