Skip to content

Commit 933ea53

Browse files
author
whitequark
committed Jul 6, 2016
compiler: add basic numpy array support (#424).
1 parent 906db87 commit 933ea53

File tree

13 files changed (+109 / -30 lines changed)
 

Diff for: artiq/compiler/builtins.py

+25 -3
Original file line number | Diff line number | Diff line change
@@ -44,12 +44,12 @@ def TInt32():
4444
def TInt64():
4545
return TInt(types.TValue(64))
4646

47-
def _int_printer(typ, depth, max_depth):
47+
def _int_printer(typ, printer, depth, max_depth):
4848
if types.is_var(typ["width"]):
4949
return "numpy.int?"
5050
else:
5151
return "numpy.int{}".format(types.get_value(typ.find()["width"]))
52-
types.TypePrinter.custom_printers['int'] = _int_printer
52+
types.TypePrinter.custom_printers["int"] = _int_printer
5353

5454
class TFloat(types.TMono):
5555
def __init__(self):
@@ -73,6 +73,16 @@ def __init__(self, elt=None):
7373
elt = types.TVar()
7474
super().__init__("list", {"elt": elt})
7575

76+
class TArray(types.TMono):
77+
def __init__(self, elt=None):
78+
if elt is None:
79+
elt = types.TVar()
80+
super().__init__("array", {"elt": elt})
81+
82+
def _array_printer(typ, printer, depth, max_depth):
83+
return "numpy.array(elt={})".format(printer.name(typ["elt"], depth, max_depth))
84+
types.TypePrinter.custom_printers["array"] = _array_printer
85+
7686
class TRange(types.TMono):
7787
def __init__(self, elt=None):
7888
if elt is None:
@@ -124,6 +134,9 @@ def fn_str():
124134
def fn_list():
125135
return types.TConstructor(TList())
126136

137+
def fn_array():
138+
return types.TConstructor(TArray())
139+
127140
def fn_Exception():
128141
return types.TExceptionConstructor(TException("Exception"))
129142

@@ -231,6 +244,15 @@ def is_list(typ, elt=None):
231244
else:
232245
return types.is_mono(typ, "list")
233246

247+
def is_array(typ, elt=None):
248+
if elt is not None:
249+
return types.is_mono(typ, "array", elt=elt)
250+
else:
251+
return types.is_mono(typ, "array")
252+
253+
def is_listish(typ, elt=None):
254+
return is_list(typ, elt) or is_array(typ, elt)
255+
234256
def is_range(typ, elt=None):
235257
if elt is not None:
236258
return types.is_mono(typ, "range", {"elt": elt})
@@ -247,7 +269,7 @@ def is_exception(typ, name=None):
247269
def is_iterable(typ):
248270
typ = typ.find()
249271
return isinstance(typ, types.TMono) and \
250-
typ.name in ('list', 'range')
272+
typ.name in ('list', 'array', 'range')
251273

252274
def get_iterable_elt(typ):
253275
if is_iterable(typ):

Diff for: artiq/compiler/embedding.py

+12
Original file line number | Diff line number | Diff line change
@@ -187,6 +187,18 @@ def quote(self, value):
187187
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
188188
begin_loc=begin_loc, end_loc=end_loc,
189189
loc=begin_loc.join(end_loc))
190+
elif isinstance(value, numpy.ndarray):
191+
begin_loc = self._add("numpy.array([")
192+
elts = []
193+
for index, elt in enumerate(value):
194+
elts.append(self.quote(elt))
195+
if index < len(value) - 1:
196+
self._add(", ")
197+
end_loc = self._add("])")
198+
199+
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TArray(),
200+
begin_loc=begin_loc, end_loc=end_loc,
201+
loc=begin_loc.join(end_loc))
190202
elif inspect.isfunction(value) or inspect.ismethod(value) or \
191203
isinstance(value, pytypes.BuiltinFunctionType) or \
192204
isinstance(value, SpecializedFunction):

Diff for: artiq/compiler/prelude.py

+1
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@ def globals():
1212
"int": builtins.fn_int(),
1313
"float": builtins.fn_float(),
1414
"list": builtins.fn_list(),
15+
"array": builtins.fn_array(),
1516
"range": builtins.fn_range(),
1617

1718
# Exception constructors

Diff for: artiq/compiler/transforms/artiq_ir_generator.py

+18 -10
Original file line number | Diff line number | Diff line change
@@ -477,7 +477,7 @@ def visit_While(self, node):
477477
self.continue_target = old_continue
478478

479479
def iterable_len(self, value, typ=_size_type):
480-
if builtins.is_list(value.type):
480+
if builtins.is_listish(value.type):
481481
return self.append(ir.Builtin("len", [value], typ,
482482
name="{}.len".format(value.name)))
483483
elif builtins.is_range(value.type):
@@ -492,7 +492,7 @@ def iterable_len(self, value, typ=_size_type):
492492

493493
def iterable_get(self, value, index):
494494
# Assuming the value is within bounds.
495-
if builtins.is_list(value.type):
495+
if builtins.is_listish(value.type):
496496
return self.append(ir.GetElem(value, index))
497497
elif builtins.is_range(value.type):
498498
start = self.append(ir.GetAttr(value, "start"))
@@ -1322,7 +1322,7 @@ def visit_BinOpT(self, node):
13221322
for index, elt in enumerate(node.right.type.elts):
13231323
elts.append(self.append(ir.GetAttr(rhs, index)))
13241324
return self.append(ir.Alloc(elts, node.type))
1325-
elif builtins.is_list(node.left.type) and builtins.is_list(node.right.type):
1325+
elif builtins.is_listish(node.left.type) and builtins.is_listish(node.right.type):
13261326
lhs_length = self.iterable_len(lhs)
13271327
rhs_length = self.iterable_len(rhs)
13281328

@@ -1355,9 +1355,9 @@ def body_gen(index):
13551355
assert False
13561356
elif isinstance(node.op, ast.Mult): # list * int, int * list
13571357
lhs, rhs = self.visit(node.left), self.visit(node.right)
1358-
if builtins.is_list(lhs.type) and builtins.is_int(rhs.type):
1358+
if builtins.is_listish(lhs.type) and builtins.is_int(rhs.type):
13591359
lst, num = lhs, rhs
1360-
elif builtins.is_int(lhs.type) and builtins.is_list(rhs.type):
1360+
elif builtins.is_int(lhs.type) and builtins.is_listish(rhs.type):
13611361
lst, num = rhs, lhs
13621362
else:
13631363
assert False
@@ -1412,7 +1412,7 @@ def polymorphic_compare_pair_order(self, op, lhs, rhs):
14121412
result = self.append(ir.Select(result, elt_result,
14131413
ir.Constant(False, builtins.TBool())))
14141414
return result
1415-
elif builtins.is_list(lhs.type) and builtins.is_list(rhs.type):
1415+
elif builtins.is_listish(lhs.type) and builtins.is_listish(rhs.type):
14161416
head = self.current_block
14171417
lhs_length = self.iterable_len(lhs)
14181418
rhs_length = self.iterable_len(rhs)
@@ -1606,7 +1606,7 @@ def visit_builtin_call(self, node):
16061606
return self.append(ir.Coerce(arg, node.type))
16071607
else:
16081608
assert False
1609-
elif types.is_builtin(typ, "list"):
1609+
elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
16101610
if len(node.args) == 0 and len(node.keywords) == 0:
16111611
length = ir.Constant(0, builtins.TInt32())
16121612
return self.append(ir.Alloc([length], node.type))
@@ -1968,8 +1968,13 @@ def flush():
19681968
else:
19691969
format_string += "%s"
19701970
args.append(value)
1971-
elif builtins.is_list(value.type):
1972-
format_string += "["; flush()
1971+
elif builtins.is_listish(value.type):
1972+
if builtins.is_list(value.type):
1973+
format_string += "["; flush()
1974+
elif builtins.is_array(value.type):
1975+
format_string += "array(["; flush()
1976+
else:
1977+
assert False
19731978

19741979
length = self.iterable_len(value)
19751980
last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type)))
@@ -1992,7 +1997,10 @@ def body_gen(index):
19921997
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)),
19931998
body_gen)
19941999

1995-
format_string += "]"
2000+
if builtins.is_list(value.type):
2001+
format_string += "]"
2002+
elif builtins.is_array(value.type):
2003+
format_string += "])"
19962004
elif builtins.is_range(value.type):
19972005
format_string += "range("; flush()
19982006

Diff for: artiq/compiler/transforms/inferencer.py

+21 -9
Original file line number | Diff line number | Diff line change
@@ -671,14 +671,25 @@ def simple_form(info, arg_types=[], return_type=builtins.TNone()):
671671
pass
672672
else:
673673
diagnose(valid_forms())
674-
elif types.is_builtin(typ, "list"):
675-
valid_forms = lambda: [
676-
valid_form("list() -> list(elt='a)"),
677-
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
678-
]
674+
elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
675+
if types.is_builtin(typ, "list"):
676+
valid_forms = lambda: [
677+
valid_form("list() -> list(elt='a)"),
678+
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
679+
]
679680

680-
self._unify(node.type, builtins.TList(),
681-
node.loc, None)
681+
self._unify(node.type, builtins.TList(),
682+
node.loc, None)
683+
elif types.is_builtin(typ, "array"):
684+
valid_forms = lambda: [
685+
valid_form("array() -> array(elt='a)"),
686+
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
687+
]
688+
689+
self._unify(node.type, builtins.TArray(),
690+
node.loc, None)
691+
else:
692+
assert False
682693

683694
if len(node.args) == 0 and len(node.keywords) == 0:
684695
pass # []
@@ -708,7 +719,8 @@ def makenotes(printer, typea, typeb, loca, locb):
708719
{"type": types.TypePrinter().name(arg.type)},
709720
arg.loc)
710721
diag = diagnostic.Diagnostic("error",
711-
"the argument of list() must be of an iterable type", {},
722+
"the argument of {builtin}() must be of an iterable type",
723+
{"builtin": typ.find().name},
712724
node.func.loc, notes=[note])
713725
self.engine.process(diag)
714726
else:
@@ -743,7 +755,7 @@ def makenotes(printer, typea, typeb, loca, locb):
743755
if builtins.is_range(arg.type):
744756
self._unify(node.type, builtins.get_iterable_elt(arg.type),
745757
node.loc, None)
746-
elif builtins.is_list(arg.type):
758+
elif builtins.is_listish(arg.type):
747759
# TODO: should be ssize_t-sized
748760
self._unify(node.type, builtins.TInt32(),
749761
node.loc, None)

Diff for: artiq/compiler/transforms/llvm_ir_generator.py

+8 -5
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,7 @@ def llty_of_type(self, typ, bare=False, for_return=False):
218218
return lldouble
219219
elif builtins.is_str(typ) or ir.is_exn_typeinfo(typ):
220220
return llptr
221-
elif builtins.is_list(typ):
221+
elif builtins.is_listish(typ):
222222
lleltty = self.llty_of_type(builtins.get_iterable_elt(typ))
223223
return ll.LiteralStructType([lli32, lleltty.as_pointer()])
224224
elif builtins.is_range(typ):
@@ -610,7 +610,7 @@ def process_Alloc(self, insn):
610610
name=insn.name)
611611
else:
612612
assert False
613-
elif builtins.is_list(insn.type):
613+
elif builtins.is_listish(insn.type):
614614
llsize = self.map(insn.operands[0])
615615
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
616616
llvalue = self.llbuilder.insert_value(llvalue, llsize, 0)
@@ -1162,6 +1162,9 @@ def _rpc_tag(self, typ, error_handler):
11621162
elif builtins.is_list(typ):
11631163
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
11641164
error_handler)
1165+
elif builtins.is_array(typ):
1166+
return b"a" + self._rpc_tag(builtins.get_iterable_elt(typ),
1167+
error_handler)
11651168
elif builtins.is_range(typ):
11661169
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
11671170
error_handler)
@@ -1405,13 +1408,13 @@ def _quote_attributes():
14051408
elif builtins.is_str(typ):
14061409
assert isinstance(value, (str, bytes))
14071410
return self.llstr_of_str(value)
1408-
elif builtins.is_list(typ):
1409-
assert isinstance(value, list)
1411+
elif builtins.is_listish(typ):
1412+
assert isinstance(value, (list, numpy.ndarray))
14101413
elt_type = builtins.get_iterable_elt(typ)
14111414
llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)])
14121415
for i in range(len(value))]
14131416
lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)),
1414-
llelts)
1417+
list(llelts))
14151418

14161419
llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type,
14171420
self.llmodule.scope.deduplicate("quoted.list"))

Diff for: artiq/compiler/types.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -697,7 +697,7 @@ def name(self, typ, depth=0, max_depth=1):
697697
return "<instance {} {{}}>".format(typ.name)
698698
elif isinstance(typ, TMono):
699699
if typ.name in self.custom_printers:
700-
return self.custom_printers[typ.name](typ, depth + 1, max_depth)
700+
return self.custom_printers[typ.name](typ, self, depth + 1, max_depth)
701701
elif typ.params == {}:
702702
return typ.name
703703
else:

Diff for: artiq/coredevice/comm_generic.py

+3
Original file line number | Diff line number | Diff line change
@@ -331,6 +331,9 @@ def _receive_rpc_value(self, embedding_map):
331331
elif tag == "l":
332332
length = self._read_int32()
333333
return [self._receive_rpc_value(embedding_map) for _ in range(length)]
334+
elif tag == "a":
335+
length = self._read_int32()
336+
return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)])
334337
elif tag == "r":
335338
start = self._receive_rpc_value(embedding_map)
336339
stop = self._receive_rpc_value(embedding_map)

Diff for: artiq/runtime/session.c

+6 -2
Original file line number | Diff line number | Diff line change
@@ -607,6 +607,7 @@ static void skip_rpc_value(const char **tag) {
607607
}
608608

609609
case 'l':
610+
case 'a':
610611
skip_rpc_value(tag);
611612
break;
612613

@@ -650,6 +651,7 @@ static int sizeof_rpc_value(const char **tag)
650651
return sizeof(char *);
651652

652653
case 'l': // list(elt='a)
654+
case 'a': // array(elt='a)
653655
skip_rpc_value(tag);
654656
return sizeof(struct { int32_t length; struct {} *elements; });
655657

@@ -733,7 +735,8 @@ static int receive_rpc_value(const char **tag, void **slot)
733735
break;
734736
}
735737

736-
case 'l': { // list(elt='a)
738+
case 'l': // list(elt='a)
739+
case 'a': { // array(elt='a)
737740
struct { int32_t length; struct {} *elements; } *list = *slot;
738741
list->length = in_packet_int32();
739742

@@ -824,7 +827,8 @@ static int send_rpc_value(const char **tag, void **value)
824827
return out_packet_string(*((*(const char***)value)++));
825828
}
826829

827-
case 'l': { // list(elt='a)
830+
case 'l': // list(elt='a)
831+
case 'a': { // array(elt='a)
828832
struct { uint32_t length; struct {} *elements; } *list = *value;
829833
void *element = list->elements;
830834

Diff for: artiq/test/coredevice/test_embedding.py

+3
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,9 @@ def test_str(self):
4141
def test_list(self):
4242
self.assertRoundtrip([10])
4343

44+
def test_array(self):
45+
self.assertRoundtrip(numpy.array([10]))
46+
4447
def test_object(self):
4548
obj = object()
4649
self.assertRoundtrip(obj)

Diff for: artiq/test/lit/inferencer/unify.py

+3
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,9 @@
5757
k = "x"
5858
# CHECK-L: k:str
5959

60+
l = array([1])
61+
# CHECK-L: l:numpy.array(elt=numpy.int?)
62+
6063
IndexError()
6164
# CHECK-L: IndexError:<constructor IndexError {}>():IndexError
6265

Diff for: artiq/test/lit/integration/array.py

+5
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
# RUN: %python -m artiq.compiler.testbench.jit %s
2+
# REQUIRES: exceptions
3+
4+
ary = array([1, 2, 3])
5+
assert [x*x for x in ary] == [1, 4, 9]

Diff for: artiq/test/lit/integration/print.py

+3
Original file line number | Diff line number | Diff line change
@@ -30,3 +30,6 @@
3030

3131
# CHECK-L: range(0, 10, 1)
3232
print(range(10))
33+
34+
# CHECK-L: array([1, 2])
35+
print(array([1, 2]))

0 commit comments

Comments (0)
Please sign in to comment.