Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also compare across forks or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: m-labs/artiq
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: a647e1104dfc
Choose a base ref
...
head repository: m-labs/artiq
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: 6deaf7b81a02
Choose a head ref
  • 3 commits
  • 7 files changed
  • 1 contributor

Commits on Sep 6, 2014

  1. Copy the full SHA
    2ef187b View commit details
  2. Copy the full SHA
    64c29bc View commit details
  3. Copy the full SHA
    6deaf7b View commit details
Showing with 205 additions and 70 deletions.
  1. +5 −21 artiq/py2llvm/__init__.py
  2. +28 −10 artiq/py2llvm/ast_body.py
  3. +31 −0 artiq/py2llvm/functions.py
  4. +29 −35 artiq/py2llvm/infer_types.py
  5. +11 −0 artiq/py2llvm/tools.py
  6. +7 −4 artiq/py2llvm/values.py
  7. +94 −0 test/py2llvm.py
26 changes: 5 additions & 21 deletions artiq/py2llvm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,20 @@
from llvm import core as lc
from llvm import passes as lp

from artiq.py2llvm import infer_types, ast_body, values


def _compile_function(module, env, funcdef):
    # Emit `funcdef` into `module` as an LLVM function with no parameters
    # and a void return type: allocate a slot for every inferred name, emit
    # the body, then terminate the current block with `ret void`.
    # NOTE(review): this appears to be the implementation superseded by
    # artiq.py2llvm.functions.compile_function in this change.
    void_fn_type = lc.Type.function(lc.Type.void(), [])
    llvm_fn = module.add_function(void_fn_type, funcdef.name)
    builder = lc.Builder.new(llvm_fn.append_basic_block("entry"))

    namespace = infer_types.infer_types(env, funcdef)
    for name, slot in namespace.items():
        slot.alloca(builder, name)

    ast_body.Visitor(env, namespace, builder).visit_statements(funcdef.body)
    builder.ret_void()
from artiq.py2llvm import values
from artiq.py2llvm.functions import compile_function
from artiq.py2llvm.tools import add_common_passes


# Compile `funcdef` into a fresh LLVM module named "main", run the
# optimization passes over it, and return env.emit_object().
# NOTE(review): diff +/- markers were lost when this page was extracted. The
# `_compile_function(...)` line and the five explicit `pass_manager.add(...)`
# lines appear to be the removed versions, replaced by `compile_function(...)`
# and `add_common_passes(...)` respectively.
def get_runtime_binary(env, funcdef):
module = lc.Module.new("main")
# Let the environment and the value layer initialise the module before any
# code is emitted (their exact effects are defined elsewhere).
env.init_module(module)
values.init_module(module)

_compile_function(module, env, funcdef)
# An empty param_types dict is passed: the compiled function's arguments,
# if any, receive no externally supplied types.
compile_function(module, env, funcdef, dict())

pass_manager = lp.PassManager.new()
pass_manager.add(lp.PASS_MEM2REG)
pass_manager.add(lp.PASS_INSTCOMBINE)
pass_manager.add(lp.PASS_REASSOCIATE)
pass_manager.add(lp.PASS_GVN)
pass_manager.add(lp.PASS_SIMPLIFYCFG)
add_common_passes(pass_manager)
pass_manager.run(module)

# NOTE(review): `module` is not passed here; presumably env kept a reference
# to it in init_module() — TODO confirm.
return env.emit_object()
38 changes: 28 additions & 10 deletions artiq/py2llvm/ast_body.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import ast

from artiq.py2llvm import values
from artiq.py2llvm.tools import is_terminated


class Visitor:
@@ -131,13 +132,16 @@ def _visit_expr_Call(self, node):

# Compile a list of statement nodes in order. Each node is dispatched to the
# matching `_visit_stmt_<NodeType>` method; node types without one raise
# NotImplementedError.
# NOTE(review): diff +/- markers were lost in extraction. The first
# `method = ...` line and the first `.format(...)` line appear to be the
# removed versions of the lines that follow them.
def visit_statements(self, stmts):
for node in stmts:
method = "_visit_stmt_" + node.__class__.__name__
node_type = node.__class__.__name__
method = "_visit_stmt_" + node_type
try:
visitor = getattr(self, method)
except AttributeError:
raise NotImplementedError("Unsupported node '{}' in statement"
.format(node.__class__.__name__))
.format(node_type))
visitor(node)
# A `return` emits the block's terminator (ret/ret void). Statements after
# it are unreachable, so they are not compiled into the terminated block.
if node_type == "Return":
break

def _visit_stmt_Assign(self, node):
val = self.visit_expression(node.value)
@@ -165,17 +169,19 @@ def _visit_stmt_If(self, node):
merge_block = function.append_basic_block("i_merge")

condition = values.operators.bool(self.visit_expression(node.test),
self.builder)
self.builder)
self.builder.cbranch(condition.get_ssa_value(self.builder),
then_block, else_block)

self.builder.position_at_end(then_block)
self.visit_statements(node.body)
self.builder.branch(merge_block)
if not is_terminated(self.builder.basic_block):
self.builder.branch(merge_block)

self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
self.builder.branch(merge_block)
if not is_terminated(self.builder.basic_block):
self.builder.branch(merge_block)

self.builder.position_at_end(merge_block)

@@ -192,13 +198,25 @@ def _visit_stmt_While(self, node):

self.builder.position_at_end(body_block)
self.visit_statements(node.body)
condition = values.operators.bool(
self.visit_expression(node.test), self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, merge_block)
if not is_terminated(self.builder.basic_block):
condition = values.operators.bool(
self.visit_expression(node.test), self.builder)
self.builder.cbranch(
condition.get_ssa_value(self.builder), body_block, merge_block)

self.builder.position_at_end(else_block)
self.visit_statements(node.orelse)
self.builder.branch(merge_block)
if not is_terminated(self.builder.basic_block):
self.builder.branch(merge_block)

self.builder.position_at_end(merge_block)

def _visit_stmt_Return(self, node):
    # Lower a `return` statement to the LLVM terminator. A bare `return` and
    # any expression whose value is VNone produce `ret void`; every other
    # value is returned via its SSA value.
    result = (values.VNone() if node.value is None
              else self.visit_expression(node.value))
    if not isinstance(result, values.VNone):
        self.builder.ret(result.get_ssa_value(self.builder))
        return
    self.builder.ret_void()
31 changes: 31 additions & 0 deletions artiq/py2llvm/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from llvm import core as lc

from artiq.py2llvm import infer_types, ast_body, values, tools

def compile_function(module, env, funcdef, param_types):
    """Compile the Python function `funcdef` into a new LLVM function in `module`.

    Types come from infer_function_types, seeded with `param_types`
    (parameter name -> values object). Every name in the inferred
    namespace, including the special "return" entry, gets an alloca in the
    entry block. The incoming LLVM arguments are stored into their
    parameters' slots before the body is emitted.

    Returns a ``(function, retval)`` pair: the LLVM function and the values
    object describing its return type.
    """
    ns = infer_types.infer_function_types(env, funcdef, param_types)
    retval = ns["return"]
    params = funcdef.args.args

    ret_llvm_type = retval.get_llvm_type()
    arg_llvm_types = [ns[param.arg].get_llvm_type() for param in params]
    function = module.add_function(
        lc.Type.function(ret_llvm_type, arg_llvm_types), funcdef.name)
    builder = lc.Builder.new(function.append_basic_block("entry"))

    # Keep this order: name the arguments, then create the allocas, then
    # store the arguments. Arguments and allocas share the function's symbol
    # table, so reordering could change which one LLVM renames on a clash.
    for param, llvm_arg in zip(params, function.args):
        llvm_arg.name = param.arg
    for name, val in ns.items():
        val.alloca(builder, name)
    for param, llvm_arg in zip(params, function.args):
        ns[param.arg].set_ssa_value(builder, llvm_arg)

    ast_body.Visitor(env, ns, builder).visit_statements(funcdef.body)

    # Control reached the end of the body without a `return`: add the
    # terminator here.
    # NOTE(review): for a non-None return type this returns the contents of
    # the "return" slot, which no code visible here ever stores to. An
    # implicit `return None` is not reflected in the inferred type.
    if not tools.is_terminated(builder.basic_block):
        if isinstance(retval, values.VNone):
            builder.ret_void()
        else:
            builder.ret(retval.get_ssa_value(builder))

    return function, retval
64 changes: 29 additions & 35 deletions artiq/py2llvm/infer_types.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,55 @@
import ast
from operator import itemgetter
from copy import deepcopy

from artiq.py2llvm.ast_body import Visitor
from artiq.py2llvm import values


# Type-inference pass: walks a function's AST and records, in the shared
# namespace dict `ns`, a py2llvm `values` object for every assigned name. When
# a name is assigned more than once, the types are merged (promoted). The
# merged type of all `return` statements is stored under the key "return".
# Expressions are typed by an ast_body.Visitor built without a builder
# (presumably builder=None), so they are only typed, not emitted — TODO confirm.
# NOTE(review): diff +/- markers were lost in extraction. The inline
# isinstance/merge chain in visit_Assign and the `ns = ...` / `target =
# node.target` chain inside visit_Return look like removed lines (ast.Return
# has no `target` field). The new code routes both through _update_target.
class _TypeScanner(ast.NodeVisitor):
def __init__(self, env, ns):
self.exprv = Visitor(env, ns)

# Record `val` as the type of `target`. An existing entry is merged in
# place; a new one is stored as a deep copy, so later merge() calls on the
# namespace entry never mutate the object the expression visitor returned.
# Only plain names are supported as targets.
def _update_target(self, target, val):
ns = self.exprv.ns
if isinstance(target, ast.Name):
if target.id in ns:
ns[target.id].merge(val)
else:
ns[target.id] = deepcopy(val)
else:
raise NotImplementedError

# Type the right-hand side once and apply it to every target
# (e.g. `a = b = expr`).
def visit_Assign(self, node):
val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns
for target in node.targets:
if isinstance(target, ast.Name):
if target.id in ns:
ns[target.id].merge(val)
else:
ns[target.id] = val
else:
raise NotImplementedError
self._update_target(target, val)

# Type `x op= y` as the expression `x op y` and merge the result into x.
def visit_AugAssign(self, node):
val = self.exprv.visit_expression(ast.BinOp(
op=node.op, left=node.target, right=node.value))
self._update_target(node.target, val)

# Merge the type of each returned value into ns["return"]. A bare `return`
# counts as VNone.
def visit_Return(self, node):
if node.value is None:
val = values.VNone()
else:
val = self.exprv.visit_expression(node.value)
ns = self.exprv.ns
target = node.target
if isinstance(target, ast.Name):
if target.id in ns:
ns[target.id].merge(val)
else:
ns[target.id] = val
if "return" in ns:
ns["return"].merge(val)
else:
raise NotImplementedError
ns["return"] = deepcopy(val)


# Infer py2llvm types for every name in `node` (a function's AST).
# - The namespace is seeded with a deep copy of `param_types` (parameter
#   name -> values object), so the caller's dict is not modified.
# - The scan is repeated until a pass changes no types.
# - The resulting dict is returned, with the function's return type under
#   "return" (VNone when no `return` statement was seen).
# NOTE(review): diff +/- markers were lost in extraction. `def infer_types`,
# `ns = dict()` and the `if prev_ns and ...` condition appear to be the
# removed versions of the lines that follow them.
def infer_types(env, node):
ns = dict()
def infer_function_types(env, node, param_types):
ns = deepcopy(param_types)
# Initial scan: afterwards `ns` holds every name the body assigns, so the
# `prev_ns[k]` lookups below always find their key.
ts = _TypeScanner(env, ns)
ts.visit(node)
while True:
# Snapshot, rescan, and stop once no entry's type changed.
# NOTE(review): termination presumably relies on merge() only ever
# widening types — TODO confirm.
prev_ns = deepcopy(ns)
ts = _TypeScanner(env, ns)
ts.visit(node)
if prev_ns and all(v.same_type(prev_ns[k]) for k, v in ns.items()):
if all(v.same_type(prev_ns[k]) for k, v in ns.items()):
# no more promotions - completed
# NOTE(review): falling off the end of the function (an implicit
# `return None`) is not merged into "return" when the body also has
# explicit `return <value>` statements.
if "return" not in ns:
ns["return"] = values.VNone()
return ns

if __name__ == "__main__":
    # Ad-hoc demo: print the inferred type of every name in a small snippet,
    # sorted by name. The entry point is infer_function_types(env, node,
    # param_types). The snippet is module-level code with no parameters, so
    # it is seeded with an empty dict. The output also includes the
    # "return" entry, which is VNone here.
    testcode = """
a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted
d = 4 # stays int32
x = int64(7)
a += x # promotes a to int64
foo = True
bar = None
"""
    ns = infer_function_types(None, ast.parse(testcode), dict())
    for k, v in sorted(ns.items(), key=itemgetter(0)):
        print("{:10}--> {}".format(k, str(v)))
11 changes: 11 additions & 0 deletions artiq/py2llvm/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from llvm import passes as lp

def is_terminated(basic_block):
    """Return True if `basic_block` already ends with a terminator instruction.

    Callers use this to avoid emitting a second terminator (branch/ret) into
    a block that a `return` has already closed. An empty block is not
    terminated. The result is always a real bool, never the (possibly
    empty) instruction list itself.
    """
    # Read the instruction list once: in llvmpy this property builds a new
    # list on each access.
    instructions = basic_block.instructions
    return bool(instructions and instructions[-1].is_terminator)

def add_common_passes(pass_manager):
    # Register the shared optimization pipeline on `pass_manager`, in this
    # order: promote allocas to registers, combine instructions, reassociate,
    # global value numbering, then simplify the control-flow graph.
    pipeline = (
        lp.PASS_MEM2REG,
        lp.PASS_INSTCOMBINE,
        lp.PASS_REASSOCIATE,
        lp.PASS_GVN,
        lp.PASS_SIMPLIFYCFG,
    )
    for llvm_pass in pipeline:
        pass_manager.add(llvm_pass)
11 changes: 7 additions & 4 deletions artiq/py2llvm/values.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ def set_ssa_value(self, builder, value):

# Create storage for this value with an LLVM `alloca` of get_llvm_type(),
# named `name`, and keep it in self._llvm_value. A value object may be
# allocated only once; a second call raises RuntimeError.
# NOTE(review): diff markers were lost in extraction. The first `raise` line
# appears to be the removed version; the new message appends the name.
def alloca(self, builder, name):
if self._llvm_value is not None:
raise RuntimeError("Attempted to alloca existing LLVM value")
raise RuntimeError("Attempted to alloca existing LLVM value "+name)
self._llvm_value = builder.alloca(self.get_llvm_type(), name=name)

def o_int(self, builder):
@@ -96,16 +96,19 @@ def set_value(self, builder, n):
def set_const_value(self, builder, n):
    # Make the compile-time integer `n` this value's SSA value, built as an
    # LLVM integer constant of this value's own type.
    constant = lc.Constant.int(self.get_llvm_type(), n)
    self.set_ssa_value(builder, constant)

# Truth test for an integer value: returns a VBool for `self != 0`, or for
# `self == 0` when `inv` is True (o_not uses this to implement `not`). When
# builder is None only the result type is produced and no IR is emitted
# (presumably the type-inference pass, which runs without a builder).
# NOTE(review): diff markers were lost in extraction. The first signature and
# the bare `lc.ICMP_NE,` line appear to be the removed versions.
def o_bool(self, builder):
def o_bool(self, builder, inv=False):
r = VBool()
if builder is not None:
r.set_ssa_value(
builder, builder.icmp(
lc.ICMP_NE,
lc.ICMP_EQ if inv else lc.ICMP_NE,
self.get_ssa_value(builder),
lc.Constant.int(self.get_llvm_type(), 0)))
return r

def o_not(self, builder):
    # `not x` is the inverted truth test: delegate to o_bool with inv=True,
    # so the comparison against zero uses equality instead of inequality.
    return self.o_bool(builder, inv=True)

def o_intx(self, target_bits, builder):
r = VInt(target_bits)
if builder is not None:
@@ -432,7 +435,7 @@ def _make_operators():
for op_name in ("bool", "int", "int64", "round", "round64",
"inv", "pos", "neg"):
d[op_name] = _make_unary_operator(op_name)
d["not_"] = _make_binary_operator("not")
d["not_"] = _make_unary_operator("not")
for op_name in ("add", "sub", "mul",
"truediv", "floordiv", "mod",
"pow", "lshift", "rshift", "xor",
94 changes: 94 additions & 0 deletions test/py2llvm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import unittest
import ast
import inspect

from llvm import core as lc
from llvm import passes as lp
from llvm import ee as le

from artiq.py2llvm.infer_types import infer_function_types
from artiq.py2llvm import values
from artiq.py2llvm import compile_function
from artiq.py2llvm.tools import add_common_passes


# Fixture for FunctionTypesCase. In this file it is only used through
# inspect.getsource() and infer_function_types, so its text is test data:
# change the body only together with FunctionTypesCase's assertions.
# `int64` is not defined in this module; it is only interpreted by the
# py2llvm type inference.
# NOTE(review): the `test_` prefix may make a runner that collects
# module-level functions treat this as a test; consider renaming it together
# with its use in FunctionTypesCase.
def test_types(choice):
a = 2 # promoted later to int64
b = a + 1 # initially int32, becomes int64 after a is promoted
c = b//2 # initially int32, becomes int64 after b is promoted
d = 4 # stays int32
x = int64(7)
a += x # promotes a to int64
foo = True
bar = None

# `return 3` (int32) and `return x` (int64) merge into a 64-bit "return".
if choice:
return 3
else:
return x

class FunctionTypesCase(unittest.TestCase):
    """Check the types infer_function_types assigns to test_types' locals."""

    def setUp(self):
        # Re-run inference for every test so each one sees a fresh namespace.
        source = inspect.getsource(test_types)
        self.ns = infer_function_types(None, ast.parse(source), dict())

    def _assert_int(self, name, nbits):
        # Assert that `name` was inferred as a VInt of width `nbits`.
        self.assertIsInstance(self.ns[name], values.VInt)
        self.assertEqual(self.ns[name].nbits, nbits)

    def test_base_types(self):
        self.assertIsInstance(self.ns["foo"], values.VBool)
        self.assertIsInstance(self.ns["bar"], values.VNone)
        self._assert_int("d", 32)
        self._assert_int("x", 64)

    def test_promotion(self):
        # a, b and c start out as int32 and are promoted by `a += x`.
        for name in ("a", "b", "c"):
            self._assert_int(name, 64)

    def test_return(self):
        # `return 3` and `return x` merge into a 64-bit integer.
        self._assert_int("return", 64)


class CompiledFunction:
    """JIT-compile a plain Python function with py2llvm so tests can call it.

    The function's source (from inspect.getsource) is parsed and compiled
    into a fresh LLVM module with the given parameter types. The module is
    optimized with add_common_passes and then run through an llvmpy
    ExecutionEngine.

    :param function: module-level Python function whose source py2llvm can
        compile.
    :param param_types: dict mapping each parameter name to a py2llvm value
        object giving its type (e.g. ``values.VInt(32)``).
    """

    def __init__(self, function, param_types):
        module = lc.Module.new("main")
        values.init_module(module)

        funcdef = ast.parse(inspect.getsource(function)).body[0]
        self.function, self.retval = compile_function(
            module, None, funcdef, param_types)
        self.argval = [param_types[arg.arg] for arg in funcdef.args.args]

        # Optimize first, then hand the finished module to the JIT, so the
        # engine never holds a module that is still being rewritten.
        pass_manager = lp.PassManager.new()
        add_common_passes(pass_manager)
        pass_manager.run(module)
        self.executor = le.ExecutionEngine.new(module)

    @staticmethod
    def _to_generic(value_type, arg):
        # Wrap one Python argument as a GenericValue of the parameter's LLVM
        # type. Integers go through int_signed so negative values survive;
        # results are read back with as_int_signed for symmetry. VBool is
        # checked explicitly in case it derives from VInt.
        llvm_type = value_type.get_llvm_type()
        if (isinstance(value_type, values.VInt)
                and not isinstance(value_type, values.VBool)):
            return le.GenericValue.int_signed(llvm_type, arg)
        return le.GenericValue.int(llvm_type, arg)

    def __call__(self, *args):
        """Run the compiled function and convert its result to Python.

        :raises TypeError: if the number of arguments does not match the
            compiled function's parameters (they would otherwise be silently
            truncated by zip).
        :raises NotImplementedError: for return types other than VBool,
            VInt and VNone.
        """
        if len(args) != len(self.argval):
            raise TypeError("expected {} argument(s), got {}".format(
                len(self.argval), len(args)))
        args_llvm = [self._to_generic(av, a)
                     for av, a in zip(self.argval, args)]
        result = self.executor.run_function(self.function, args_llvm)
        if isinstance(self.retval, values.VBool):
            return bool(result.as_int())
        elif isinstance(self.retval, values.VInt):
            return result.as_int_signed()
        elif isinstance(self.retval, values.VNone):
            # A void function yields no meaningful GenericValue.
            return None
        else:
            raise NotImplementedError


# Reference function for CodeGenCase. Its source is compiled by py2llvm (via
# inspect.getsource), and the compiled result is compared with calling it
# directly in Python. It should therefore only use constructs the compiler
# handles — TODO confirm the supported subset before editing.
# NOTE(review): returns True for 0, 1 and negative x (the loop never runs).
# That is fine for a codegen comparison but is not a mathematical primality
# test.
def is_prime(x):
d = 2
while d*d <= x:
if not x % d:
return False
d += 1
return True

class CodeGenCase(unittest.TestCase):
    """Compare py2llvm-compiled code against the same function run in Python."""

    def test_is_prime(self):
        # Compile is_prime for a 32-bit integer argument, then require the
        # JIT-compiled version to agree with the interpreter on 0..199.
        compiled = CompiledFunction(is_prime, {"x": values.VInt(32)})
        for candidate in range(200):
            expected = is_prime(candidate)
            self.assertEqual(compiled(candidate), expected)