Skip to content

Commit

Permalink
compiler: monomorphize int64(round(x)) to not lose precision.
Browse files (browse the repository at this point in the history)
This applies to any expression of indeterminate integer width that is
explicitly cast to a fixed width (int64(), int32(), ...), not just round().
whitequark committed Dec 2, 2016
1 parent 696db32 commit 68de724
Showing 8 changed files with 40 additions and 21 deletions.
2 changes: 2 additions & 0 deletions artiq/compiler/module.py
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@ def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False
self.globals = src.globals

int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
cast_monomorphizer = transforms.CastMonomorphizer(engine=self.engine)
inferencer = transforms.Inferencer(engine=self.engine)
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
escape_validator = validators.EscapeValidator(engine=self.engine)
@@ -63,6 +64,7 @@ def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False
interleaver = transforms.Interleaver(engine=self.engine)
invariant_detection = analyses.InvariantDetection(engine=self.engine)

cast_monomorphizer.visit(src.typedtree)
int_monomorphizer.visit(src.typedtree)
inferencer.visit(src.typedtree)
monomorphism_validator.visit(src.typedtree)
3 changes: 2 additions & 1 deletion artiq/compiler/testbench/inferencer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys, fileinput, os
from pythonparser import source, diagnostic, algorithm, parse_buffer
from .. import prelude, types
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, CastMonomorphizer
from ..transforms import IODelayEstimator

class Printer(algorithm.Visitor):
@@ -84,6 +84,7 @@ def process_diagnostic(diag):
typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed)
Inferencer(engine=engine).visit(typed)
if monomorphize:
CastMonomorphizer(engine=engine).visit(typed)
IntMonomorphizer(engine=engine).visit(typed)
Inferencer(engine=engine).visit(typed)
if iodelay:
1 change: 1 addition & 0 deletions artiq/compiler/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .asttyped_rewriter import ASTTypedRewriter
from .inferencer import Inferencer
from .int_monomorphizer import IntMonomorphizer
from .cast_monomorphizer import CastMonomorphizer
from .iodelay_estimator import IODelayEstimator
from .artiq_ir_generator import ARTIQIRGenerator
from .dead_code_eliminator import DeadCodeEliminator
24 changes: 24 additions & 0 deletions artiq/compiler/transforms/cast_monomorphizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
:class:`CastMonomorphizer` uses explicit casts to monomorphize
expressions of undetermined integer type to either 32 or 64 bits.
"""

from pythonparser import algorithm, diagnostic
from .. import types, builtins

class CastMonomorphizer(algorithm.Visitor):
    """
    Propagate the width of an explicit integer cast into its argument.

    For a call such as ``int64(round(x))``, the argument ``round(x)`` has an
    integer type whose width is still a type variable. Unifying that argument
    type with the cast's already-fixed result type gives the argument the same
    width. Otherwise the argument would be left for a later pass to pick a
    default width, which may be narrower than the cast and lose precision.
    """

    def __init__(self, engine):
        # Diagnostic engine, accepted like the other transforms take it.
        # This pass does not currently report any diagnostics.
        self.engine = engine

    def visit_CallT(self, node):
        """
        Fix the width of an integer argument to an ``int``/``int32``/``int64``
        cast.

        This applies only when the call has exactly one positional argument,
        the cast's result width is known, and the argument is an integer whose
        width is still undetermined.
        """
        # Visit children first, so nested casts are processed before this one.
        self.generic_visit(node)

        is_int_cast = (types.is_builtin(node.func.type, "int") or
                       types.is_builtin(node.func.type, "int32") or
                       types.is_builtin(node.func.type, "int64"))
        # These builtins may also be called with no positional argument
        # (e.g. ``int64()``). Then there is no argument to propagate a width
        # into, and indexing node.args[0] would raise IndexError.
        if not is_int_cast or len(node.args) != 1:
            return

        typ = node.type.find()
        arg_typ = node.args[0].type
        if (not types.is_var(typ["width"]) and
                builtins.is_int(arg_typ) and
                types.is_var(arg_typ.find()["width"])):
            arg_typ.unify(typ)

14 changes: 0 additions & 14 deletions artiq/compiler/transforms/inferencer.py
Original file line number Diff line number Diff line change
@@ -780,7 +780,6 @@ def makenotes(printer, typea, typeb, loca, locb):
elif types.is_builtin(typ, "round"):
valid_forms = lambda: [
valid_form("round(x:float) -> numpy.int?"),
valid_form("round(x:float, width=?) -> numpy.int?")
]

self._unify(node.type, builtins.TInt(),
@@ -791,19 +790,6 @@ def makenotes(printer, typea, typeb, loca, locb):

self._unify(arg.type, builtins.TFloat(),
arg.loc, None)
elif len(node.args) == 1 and len(node.keywords) == 1 and \
builtins.is_numeric(node.args[0].type) and \
node.keywords[0].arg == 'width':
width = node.keywords[0].value
if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)):
diag = diagnostic.Diagnostic("error",
"the width argument of round() must be an integer literal", {},
node.keywords[0].loc)
self.engine.process(diag)
return

self._unify(node.type, builtins.TInt(types.TValue(width.n)),
node.loc, None)
else:
diagnose(valid_forms())
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):
3 changes: 0 additions & 3 deletions artiq/test/lit/inferencer/builtin_calls.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,3 @@

# CHECK-L: round:<function round>(1.0:float):numpy.int?
round(1.0)

# CHECK-L: round:<function round>(1.0:float, width=64:numpy.int?):numpy.int64
round(1.0, width=64)
3 changes: 0 additions & 3 deletions artiq/test/lit/monomorphism/integers.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,3 @@

y = int(1)
# CHECK-L: y: numpy.int32

z = round(1.0)
# CHECK-L: z: numpy.int32
11 changes: 11 additions & 0 deletions artiq/test/lit/monomorphism/round.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# RUN: %python -m artiq.compiler.testbench.inferencer +mono %s >%t
# RUN: OutputCheck %s --file-to-check=%t

# Verifies that an explicit int32()/int64() cast fixes the width of an
# otherwise width-polymorphic round() result after monomorphization.

# With no surrounding cast, the result width is defaulted to 32 bits.
# CHECK-L: round:<function round>(1.0:float):numpy.int32
round(1.0)

# An int32() cast yields the same 32-bit width.
# CHECK-L: round:<function round>(2.0:float):numpy.int32
int32(round(2.0))

# An int64() cast widens round() itself to 64 bits, so no precision is lost
# before the conversion.
# CHECK-L: round:<function round>(3.0:float):numpy.int64
int64(round(3.0))

0 comments on commit 68de724

Please sign in to comment.