Skip to content

Commit

Permalink
compiler: explicitly represent loops in IR.
Browse files (browse the repository at this point in the history)
whitequark committed Dec 16, 2015
1 parent 3386082 commit f8eaeaa
Showing 8 changed files with 109 additions and 9 deletions.
2 changes: 1 addition & 1 deletion artiq/compiler/analyses/devirtualization.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@ def visit_Assign(self, node):
self.visit(node.value)
self.visit_in_assign(node.targets)

def visit_For(self, node):
def visit_ForT(self, node):
self.visit(node.iter)
self.visit_in_assign(node.target)
self.visit(node.body)
6 changes: 6 additions & 0 deletions artiq/compiler/asttyped.py
Original file line number Diff line number Diff line change
@@ -36,6 +36,12 @@ class ExceptHandlerT(ast.ExceptHandler):
_fields = ("filter", "name", "body") # rename ast.ExceptHandler.type to filter
_types = ("name_type",)

class ForT(ast.For):
    """
    A typed ``for`` statement, carrying I/O delay metadata.

    :ivar trip_count: (:class:`iodelay.Expr`)
        expression for the number of loop iterations; ``None`` until the
        I/O delay estimator has computed it
    :ivar trip_interval: (:class:`iodelay.Expr`)
        expression for the delay accumulated by a single iteration of the
        body (the whole loop contributes ``trip_interval * trip_count``);
        ``None`` until the I/O delay estimator has computed it
    """

class SliceT(ast.Slice, commontyped):
    """Typed variant of :class:`ast.Slice` (adds the ``commontyped`` mix-in)."""

64 changes: 63 additions & 1 deletion artiq/compiler/ir.py
Original file line number Diff line number Diff line change
@@ -1296,7 +1296,7 @@ class Delay(Terminator):
:ivar interval: (:class:`iodelay.Expr`) expression
:ivar var_names: (list of string)
iodelay variable names corresponding to operands
iodelay variable names corresponding to SSA values
"""

"""
@@ -1354,6 +1354,68 @@ def _operands_as_string(self, type_printer):
def opcode(self):
return "delay({})".format(self.interval)

class Loop(Terminator):
    """
    A terminator for loop headers that carries metadata useful
    for unrolling. It includes an :class:`iodelay.Expr` specifying
    the trip count, tied to SSA values so that inlining could lead
    to the expression folding to a constant.

    Control-flow-wise it acts as a conditional branch: it transfers
    control to ``if_true`` when the condition holds and to ``if_false``
    otherwise.

    Operand layout: ``[cond, if_true, if_false, *subst_values]``, where
    ``subst_values[i]`` is the SSA value bound to ``var_names[i]``.

    :ivar trip_count: (:class:`iodelay.Expr`)
        expression for trip count
    :ivar var_names: (list of string)
        iodelay variable names corresponding to ``trip_count`` operands
    """

    """
    :param trip_count: (:class:`iodelay.Expr`) expression
    :param substs: (dict of str to :class:`Value`)
        SSA values corresponding to iodelay variable names
    :param cond: (:class:`Value`) branch condition
    :param if_true: (:class:`BasicBlock`) branch target if condition is truthful
    :param if_false: (:class:`BasicBlock`) branch target if condition is falseful
    """
    def __init__(self, trip_count, substs, cond, if_true, if_false, name=""):
        for var_name, value in substs.items():
            assert isinstance(var_name, str)
            assert isinstance(value, Value)
        assert isinstance(cond, Value)
        assert builtins.is_bool(cond.type)
        assert isinstance(if_true, BasicBlock)
        assert isinstance(if_false, BasicBlock)

        # Fix one ordering of the variable names and derive the operand values
        # from it, so names and values are paired explicitly rather than by
        # relying on keys() and values() iterating in matching order.
        var_names = list(substs)
        subst_values = [substs[var_name] for var_name in var_names]
        # The subst values are stored as ordinary operands so that they are
        # rewritten together with the other operands of the instruction.
        super().__init__([cond, if_true, if_false] + subst_values, builtins.TNone(), name)
        self.trip_count = trip_count
        self.var_names = var_names

    def copy(self, mapper):
        """Copy the instruction, duplicating the trip count metadata."""
        self_copy = super().copy(mapper)
        self_copy.trip_count = self.trip_count
        self_copy.var_names = list(self.var_names)
        return self_copy

    def condition(self):
        """Return the branch condition (operand 0)."""
        return self.operands[0]

    def if_true(self):
        """Return the block branched to when the condition holds (operand 1)."""
        return self.operands[1]

    def if_false(self):
        """Return the block branched to when the condition fails (operand 2)."""
        return self.operands[2]

    def substs(self):
        """Return a dict mapping each iodelay variable name to its SSA value."""
        return dict(zip(self.var_names, self.operands[3:]))

    def _operands_as_string(self, type_printer):
        # Render subst values as operand references, the same way the
        # condition and branch targets are rendered, rather than via str().
        substs_as_strings = ["{} = {}".format(var_name, value.as_operand(type_printer))
                             for var_name, value in self.substs().items()]
        result = "[{}]".format(", ".join(substs_as_strings))
        result += ", {}, {}, {}".format(*[operand.as_operand(type_printer)
                                          for operand in self.operands[0:3]])
        return result

    def opcode(self):
        return "loop({} times)".format(self.trip_count)

class Parallel(Terminator):
"""
An instruction that schedules several threads of execution
19 changes: 18 additions & 1 deletion artiq/compiler/testbench/inferencer.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
from pythonparser import source, diagnostic, algorithm, parse_buffer
from .. import prelude, types
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
from ..transforms import IODelayEstimator

class Printer(algorithm.Visitor):
"""
@@ -32,7 +33,15 @@ def visit_ExceptHandlerT(self, node):

if node.name_loc:
self.rewriter.insert_after(node.name_loc,
":{}".format(self.type_printer.name(node.name_type)))
":{}".format(self.type_printer.name(node.name_type)))

def visit_ForT(self, node):
    """
    Print a ``for`` loop, annotating it with its folded trip count and
    per-iteration interval as ``[<count> x <interval> mu]`` right after
    the ``for`` keyword. Loops lacking either value get no annotation.
    """
    super().generic_visit(node)

    count, interval = node.trip_count, node.trip_interval
    if count is None or interval is None:
        return

    annotation = "[{} x {} mu]".format(count.fold(), interval.fold())
    self.rewriter.insert_after(node.keyword_loc, annotation)

def generic_visit(self, node):
super().generic_visit(node)
@@ -48,6 +57,12 @@ def main():
else:
monomorphize = False

if len(sys.argv) > 1 and sys.argv[1] == "+iodelay":
del sys.argv[1]
iodelay = True
else:
iodelay = False

if len(sys.argv) > 1 and sys.argv[1] == "+diag":
del sys.argv[1]
def process_diagnostic(diag):
@@ -71,6 +86,8 @@ def process_diagnostic(diag):
if monomorphize:
IntMonomorphizer(engine=engine).visit(typed)
Inferencer(engine=engine).visit(typed)
if iodelay:
IODelayEstimator(engine=engine, ref_period=1e6).visit_fixpoint(typed)

printer = Printer(buf)
printer.visit(typed)
9 changes: 7 additions & 2 deletions artiq/compiler/transforms/artiq_ir_generator.py
Original file line number Diff line number Diff line change
@@ -472,7 +472,7 @@ def iterable_get(self, value, index):
else:
assert False

def visit_For(self, node):
def visit_ForT(self, node):
try:
iterable = self.visit(node.iter)
length = self.iterable_len(iterable)
@@ -522,7 +522,12 @@ def visit_For(self, node):
else:
else_tail = tail

head.append(ir.BranchIf(cond, body, else_tail))
if node.trip_count is not None:
substs = {var_name: self.current_args[var_name]
for var_name in node.trip_count.free_vars()}
head.append(ir.Loop(node.trip_count, substs, cond, body, else_tail))
else:
head.append(ir.BranchIf(cond, body, else_tail))
if not post_body.is_terminated():
post_body.append(ir.Branch(continue_block))
break_block.append(ir.Branch(tail))
9 changes: 9 additions & 0 deletions artiq/compiler/transforms/asttyped_rewriter.py
Original file line number Diff line number Diff line change
@@ -471,6 +471,15 @@ def visit_Raise(self, node):
self.engine.process(diag)
return node

def visit_For(self, node):
    """
    Replace an untyped ``for`` statement with :class:`asttyped.ForT`.

    Child nodes are rewritten first. ``trip_count`` and ``trip_interval``
    start out as ``None``; the I/O delay estimator may fill them in later.
    Every location attribute of the original node is carried over.
    """
    node = self.generic_visit(node)
    node = asttyped.ForT(
        target=node.target, iter=node.iter, body=node.body, orelse=node.orelse,
        trip_count=None, trip_interval=None,
        keyword_loc=node.keyword_loc, in_loc=node.in_loc, for_colon_loc=node.for_colon_loc,
        else_loc=node.else_loc, else_colon_loc=node.else_colon_loc,
        # The new node is built purely from these keyword arguments. Without
        # loc it would presumably lack the ``loc`` attribute that later passes
        # read when reporting diagnostics.
        loc=node.loc)
    return node

# Unsupported visitors
#
def visit_unsupported(self, node):
2 changes: 1 addition & 1 deletion artiq/compiler/transforms/inferencer.py
Original file line number Diff line number Diff line change
@@ -916,7 +916,7 @@ def visit_AugAssign(self, node):

node.value = self._coerce_one(value_type, node.value, other_node=node.target)

def visit_For(self, node):
def visit_ForT(self, node):
old_in_loop, self.in_loop = self.in_loop, True
self.generic_visit(node)
self.in_loop = old_in_loop
7 changes: 4 additions & 3 deletions artiq/compiler/transforms/iodelay_estimator.py
Original file line number Diff line number Diff line change
@@ -167,7 +167,7 @@ def evaluate(node):
node.loc)
abort([note])

def visit_For(self, node):
def visit_ForT(self, node):
self.visit(node.iter)

old_goto, self.current_goto = self.current_goto, None
@@ -180,8 +180,9 @@ def visit_For(self, node):
self.abort("loop trip count is indeterminate because of control flow",
self.current_goto.loc)

trip_count = self.get_iterable_length(node.iter)
self.current_delay = old_delay + self.current_delay * trip_count
node.trip_count = self.get_iterable_length(node.iter).fold()
node.trip_interval = self.current_delay.fold()
self.current_delay = old_delay + node.trip_interval * node.trip_count
self.current_goto = old_goto

self.visit(node.orelse)

0 comments on commit f8eaeaa

Please sign in to comment.