Skip to content

Commit

Permalink
py2llvm: replace array with list
Browse files Browse the repository at this point in the history
sbourdeauducq committed Dec 17, 2014
1 parent 6ca39f7 commit f3b727b
Showing 10 changed files with 174 additions and 129 deletions.
16 changes: 0 additions & 16 deletions artiq/language/core.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
"""

from collections import namedtuple as _namedtuple
from copy import copy as _copy
from functools import wraps as _wraps

from artiq.language import units as _units
@@ -71,21 +70,6 @@ def round64(x):
return int64(round(x))


def array(element, count):
    """Creates an array.

    The array is initialized with the value of ``element`` repeated ``count``
    times. Elements can be read and written using the regular Python index
    syntax.

    For static compilation, ``count`` must be a fixed integer.

    Arrays of arrays are supported. Every element is an independent deep
    copy of ``element``, so nested arrays at any depth never share storage
    (a shallow copy would alias the innermost arrays once nesting reaches
    three levels).
    """
    # Local import keeps the module namespace free of extra public names.
    from copy import deepcopy
    return [deepcopy(element) for _ in range(count)]


_KernelFunctionInfo = _namedtuple("_KernelFunctionInfo", "core_name k_function")


70 changes: 0 additions & 70 deletions artiq/py2llvm/arrays.py

This file was deleted.

104 changes: 92 additions & 12 deletions artiq/py2llvm/ast_body.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

import llvmlite.ir as ll

from artiq.py2llvm import values, base_types, fractions, arrays, iterators
from artiq.py2llvm import values, base_types, fractions, lists, iterators
from artiq.py2llvm.tools import is_terminated


@@ -177,14 +177,6 @@ def _visit_expr_Call(self, node):
denominator = self.visit_expression(node.args[1])
r.set_value_nd(self.builder, numerator, denominator)
return r
elif fn == "array":
element = self.visit_expression(node.args[0])
if (isinstance(node.args[1], ast.Num)
and isinstance(node.args[1].n, int)):
count = node.args[1].n
else:
raise ValueError("Array size must be integer and constant")
return arrays.VArray(element, count)
elif fn == "range":
return iterators.IRange(
self.builder,
@@ -201,6 +193,56 @@ def _visit_expr_Attribute(self, node):
value = self.visit_expression(node.value)
return value.o_getattr(node.attr, self.builder)

def _visit_expr_List(self, node):
    """Build a VList value from a list display such as ``[a, b, c]``.

    The element type starts as a fresh copy of the first element's type and
    is merged with the type of every other element. An empty display gets a
    VNone element type. The visited element values are attached as ``elts``
    so the assignment code can store them into the target list.
    """
    elts = [self.visit_expression(elt) for elt in node.elts]
    if elts:
        el_type = elts[0].new()
        for elt in elts[1:]:
            el_type.merge(elt)
    else:
        # The module imports base_types as a module, not VNone by name, so
        # the bare name would raise NameError here.
        # NOTE(review): assumes VNone is defined in base_types -- confirm.
        el_type = base_types.VNone()
    r = lists.VList(el_type, len(elts))
    r.elts = elts
    return r

def _visit_expr_ListComp(self, node):
    """Build a VList value from a list comprehension.

    Only ``[elt for name in range(N)]`` is supported:
    - exactly one generator;
    - no ``if`` filters;
    - a plain name as the target;
    - ``N`` an integer literal.
    Any other form raises NotImplementedError.

    The element expression is visited once, with the generator target
    temporarily removed from the local namespace. An outer variable with the
    same name therefore cannot leak into it. The visited element is attached
    as ``elt`` for the assignment code to replicate.
    """
    if len(node.generators) != 1:
        raise NotImplementedError
    generator = node.generators[0]
    if not isinstance(generator, ast.comprehension):
        raise NotImplementedError
    # Filters make the element count data-dependent. Ignoring them would
    # silently produce a list of the wrong length, so reject them.
    if generator.ifs:
        raise NotImplementedError
    if not isinstance(generator.target, ast.Name):
        raise NotImplementedError
    target = generator.target.id

    iterator = generator.iter
    if not isinstance(iterator, ast.Call):
        raise NotImplementedError
    if not isinstance(iterator.func, ast.Name):
        raise NotImplementedError
    if iterator.func.id != "range":
        raise NotImplementedError
    if len(iterator.args) != 1:
        raise NotImplementedError
    if not isinstance(iterator.args[0], ast.Num):
        raise NotImplementedError
    # The count sizes a static LLVM array, so it must be an integer.
    if not isinstance(iterator.args[0].n, int):
        raise NotImplementedError
    count = iterator.args[0].n

    # Prevent incorrect use of the generator target, if it is defined in
    # the local function namespace. Restore it even if visiting fails.
    if target in self.ns:
        old_target_val = self.ns[target]
        del self.ns[target]
    else:
        old_target_val = None
    try:
        elt = self.visit_expression(node.elt)
    finally:
        if old_target_val is not None:
            self.ns[target] = old_target_val

    r = lists.VList(elt.new(), count)
    r.elt = elt
    return r

def _visit_expr_Subscript(self, node):
value = self.visit_expression(node.value)
if isinstance(node.slice, ast.Index):
@@ -227,9 +269,47 @@ def _bb_terminated(self):

def _visit_stmt_Assign(self, node):
    """Compile an assignment statement.

    A list display or list comprehension on the right-hand side is stored
    element by element into a single target list, after setting its count.
    Multiple targets raise NotImplementedError in that case. Any other value
    is assigned to every target with ``set_value``.
    """
    val = self.visit_expression(node.value)
    if isinstance(node.value, ast.List):
        if len(node.targets) > 1:
            raise NotImplementedError
        target = self.visit_expression(node.targets[0])
        target.set_count(self.builder, val.alloc_count)
        for i, elt in enumerate(val.elts):
            idx = base_types.VInt()
            idx.set_const_value(self.builder, i)
            target.o_subscript(idx, self.builder).set_value(self.builder,
                                                            elt)
    elif isinstance(node.value, ast.ListComp):
        if len(node.targets) > 1:
            raise NotImplementedError
        target = self.visit_expression(node.targets[0])
        target.set_count(self.builder, val.alloc_count)

        # The copy loop below tests its exit condition only after storing
        # an element (do-while shape). For an empty comprehension such as
        # ``[e for x in range(0)]`` it would still write index 0 of a
        # zero-length array. The count is a compile-time constant, so emit
        # no loop at all in that case.
        if val.alloc_count > 0:
            i = base_types.VInt()
            i.alloca(self.builder)
            i.auto_store(self.builder, ll.Constant(ll.IntType(32), 0))

            function = self.builder.basic_block.function
            copy_block = function.append_basic_block("ai_copy")
            end_block = function.append_basic_block("ai_end")
            self.builder.branch(copy_block)

            self.builder.position_at_end(copy_block)
            target.o_subscript(i, self.builder).set_value(self.builder,
                                                          val.elt)
            i.auto_store(self.builder, self.builder.add(
                i.auto_load(self.builder),
                ll.Constant(ll.IntType(32), 1)))
            cont = self.builder.icmp_signed(
                "<", i.auto_load(self.builder),
                ll.Constant(ll.IntType(32), val.alloc_count))
            self.builder.cbranch(cont, copy_block, end_block)

            self.builder.position_at_end(end_block)
    else:
        for target in node.targets:
            target = self.visit_expression(target)
            target.set_value(self.builder, val)

def _visit_stmt_AugAssign(self, node):
target = self.visit_expression(node.target)
52 changes: 52 additions & 0 deletions artiq/py2llvm/lists.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import llvmlite.ir as ll

from artiq.py2llvm.values import VGeneric


class VList(VGeneric):
    """List value with a fixed storage capacity.

    In LLVM IR the value is a literal struct ``{i32, [capacity x T]}``:
    - field 0 holds the element count (written by ``set_count``);
    - field 1 holds the element storage (addressed by ``o_subscript``).

    ``el_type`` is a value object describing the element type.
    ``alloc_count`` is the capacity, or ``None`` when unknown; a zero-length
    storage array is emitted in that case.
    """

    def __init__(self, el_type, alloc_count):
        VGeneric.__init__(self)
        self.el_type = el_type
        self.alloc_count = alloc_count

    def get_llvm_type(self):
        """Return the ``{i32, [capacity x T]}`` struct type."""
        capacity = self.alloc_count if self.alloc_count is not None else 0
        count_field = ll.IntType(32)
        storage_field = ll.ArrayType(self.el_type.get_llvm_type(), capacity)
        return ll.LiteralStructType([count_field, storage_field])

    def __repr__(self):
        el_repr = repr(self.el_type)
        count_repr = "?" if self.alloc_count is None else self.alloc_count
        return "<VList:{} x{}>".format(el_repr, count_repr)

    def same_type(self, other):
        """Lists match when their element types match (capacity ignored)."""
        if not isinstance(other, VList):
            return False
        return self.el_type.same_type(other.el_type)

    def merge(self, other):
        """Merge another list's element type into ours; reject non-lists."""
        if not isinstance(other, VList):
            raise TypeError("Incompatible types: {} and {}"
                            .format(repr(self), repr(other)))
        self.el_type.merge(other.el_type)

    def merge_subscript(self, other):
        """Merge the type of ``other`` into the element type."""
        self.el_type.merge(other)

    def set_count(self, builder, count):
        """Emit a store of the constant ``count`` into the count field."""
        i32 = ll.IntType(32)
        count_ptr = builder.gep(self.llvm_value,
                                [ll.Constant(i32, 0), ll.Constant(i32, 0)])
        builder.store(ll.Constant(i32, count), count_ptr)

    def o_subscript(self, index, builder):
        """Return a new element value bound to slot ``index`` of the storage.

        With ``builder`` set to None, no IR is emitted and only an unbound
        element value of the right type is returned.
        """
        element = self.el_type.new()
        if builder is None:
            return element
        ssa_index = index.o_int(builder).auto_load(builder)
        i32 = ll.IntType(32)
        slot_ptr = builder.gep(self.llvm_value, [
            ll.Constant(i32, 0), ll.Constant(i32, 1), ssa_index])
        element.auto_store(builder, slot_ptr)
        return element
2 changes: 1 addition & 1 deletion artiq/py2llvm/values.py
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ def auto_store(self, builder, llvm_value):
raise RuntimeError(
"Attempted to set LLVM SSA value multiple times")

def alloca(self, builder, name):
def alloca(self, builder, name=""):
if self.llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
self.llvm_value = builder.alloca(self.get_llvm_type(), name=name)
51 changes: 25 additions & 26 deletions artiq/test/py2llvm.py
Original file line number Diff line number Diff line change
@@ -7,9 +7,9 @@

import llvmlite.binding as llvm

from artiq.language.core import int64, array
from artiq.language.core import int64
from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import base_types, arrays
from artiq.py2llvm import base_types, lists
from artiq.py2llvm.module import Module


@@ -71,22 +71,22 @@ def test_return(self):
self.assertEqual(self.ns["return"].nbits, 64)


# Pre-commit fixture header and first line, removed by this diff; its
# remaining lines are shared with test_list_types below.
def test_array_types():
a = array(0, 5)
# Type-inference fixture: a 5-element list literal of int 0s, two elements of
# which are then assigned int64 values. FunctionListTypesCase checks that
# inference widens the element type to a 64-bit VInt, alloc_count is 5, and
# the loop index i is a 32-bit VInt.
def test_list_types():
a = [0, 0, 0, 0, 0]
for i in range(2):
a[i] = int64(8)
return a


# Old class header, removed by this commit (renamed on the next line).
class FunctionArrayTypesCase(unittest.TestCase):
# Checks the types that inference assigns inside test_list_types.
class FunctionListTypesCase(unittest.TestCase):
def setUp(self):
# Removed line: built types for the deleted test_array_types fixture.
self.ns = _build_function_types(test_array_types)
self.ns = _build_function_types(test_list_types)

# Removed test: asserted on arrays.VArray fields (el_init, count); the
# arrays module is deleted by this commit.
def test_array_types(self):
self.assertIsInstance(self.ns["a"], arrays.VArray)
self.assertIsInstance(self.ns["a"].el_init, base_types.VInt)
self.assertEqual(self.ns["a"].el_init.nbits, 64)
self.assertEqual(self.ns["a"].count, 5)
# "a" must be a VList whose element type was widened to a 64-bit VInt by
# the int64 stores, with capacity 5; the loop index "i" stays a 32-bit VInt.
def test_list_types(self):
self.assertIsInstance(self.ns["a"], lists.VList)
self.assertIsInstance(self.ns["a"].el_type, base_types.VInt)
self.assertEqual(self.ns["a"].el_type.nbits, 64)
self.assertEqual(self.ns["a"].alloc_count, 5)
self.assertIsInstance(self.ns["i"], base_types.VInt)
self.assertEqual(self.ns["i"].nbits, 32)

@@ -212,20 +212,19 @@ def frac_arith_float_rev(op, a, b, x):
return x / Fraction(a, b)


# Old fixture, removed by this commit: nested-array version of this test.
def array_test():
a = array(array(2, 5), 5)
a[3][2] = 11
a[4][1] = 42
a[0][0] += 6
# Fixture that test_list runs both as plain Python and compiled, then
# compares. It covers:
# - a list comprehension whose target name x is also a local (x = 80);
# - a list display;
# - subscript store and augmented subscript assignment;
# - subscript loads inside arithmetic.
def list_test():
x = 80
a = [3 for x in range(7)]
b = [1, 2, 4, 5, 4, 0, 5]
a[3] = x
a[0] += 6
a[1] = b[1] + b[2]

acc = 0
# Old nested loop over the 5x5 array, removed by this commit.
for i in range(5):
for j in range(5):
if i + j == 2 or i + j == 1:
continue
if i and j and a[i][j]:
acc += 1
acc += a[i][j]
# New loop: sum all elements, plus 1 for each nonzero element at i > 0.
for i in range(7):
if i and a[i]:
acc += 1
acc += a[i]
return acc


@@ -364,9 +363,9 @@ def test_frac_div_float(self):
self._test_frac_arith_float(3, False)
self._test_frac_arith_float(3, True)

# Old test, removed by this commit: exercised the deleted array_test fixture.
def test_array(self):
array_test_c = CompiledFunction(array_test, dict())
self.assertEqual(array_test_c(), array_test())
# Wrap list_test in CompiledFunction (no extra globals) and check that calling
# it returns the same value as calling list_test directly.
def test_list(self):
list_test_c = CompiledFunction(list_test, dict())
self.assertEqual(list_test_c(), list_test())

def test_corner_cases(self):
corner_cases_c = CompiledFunction(corner_cases, dict())
2 changes: 1 addition & 1 deletion artiq/transforms/tools.py
Original file line number Diff line number Diff line change
@@ -10,7 +10,7 @@
core_language.time_to_cycles, core_language.cycles_to_time,
core_language.syscall,
range, bool, int, float, round,
core_language.int64, core_language.round64, core_language.array,
core_language.int64, core_language.round64,
Fraction, units.Quantity, units.check_unit, core_language.EncodedException
)
embeddable_func_names = {func.__name__ for func in embeddable_funcs}
Loading

0 comments on commit f3b727b

Please sign in to comment.