Skip to content

Commit 68de724

Browse files
author
whitequark
committed Dec 2, 2016
compiler: monomorphize int64(round(x)) to not lose precision.
This applies to any expression with an indeterminate integer type cast to int64(), not just round().
1 parent 696db32 commit 68de724

File tree

8 files changed

+40
-21
lines changed

8 files changed

+40
-21
lines changed
 

‎artiq/compiler/module.py

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False
4848
self.globals = src.globals
4949

5050
int_monomorphizer = transforms.IntMonomorphizer(engine=self.engine)
51+
cast_monomorphizer = transforms.CastMonomorphizer(engine=self.engine)
5152
inferencer = transforms.Inferencer(engine=self.engine)
5253
monomorphism_validator = validators.MonomorphismValidator(engine=self.engine)
5354
escape_validator = validators.EscapeValidator(engine=self.engine)
@@ -63,6 +64,7 @@ def __init__(self, src, ref_period=1e-6, attribute_writeback=True, remarks=False
6364
interleaver = transforms.Interleaver(engine=self.engine)
6465
invariant_detection = analyses.InvariantDetection(engine=self.engine)
6566

67+
cast_monomorphizer.visit(src.typedtree)
6668
int_monomorphizer.visit(src.typedtree)
6769
inferencer.visit(src.typedtree)
6870
monomorphism_validator.visit(src.typedtree)

‎artiq/compiler/testbench/inferencer.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import sys, fileinput, os
22
from pythonparser import source, diagnostic, algorithm, parse_buffer
33
from .. import prelude, types
4-
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
4+
from ..transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer, CastMonomorphizer
55
from ..transforms import IODelayEstimator
66

77
class Printer(algorithm.Visitor):
@@ -84,6 +84,7 @@ def process_diagnostic(diag):
8484
typed = ASTTypedRewriter(engine=engine, prelude=prelude.globals()).visit(parsed)
8585
Inferencer(engine=engine).visit(typed)
8686
if monomorphize:
87+
CastMonomorphizer(engine=engine).visit(typed)
8788
IntMonomorphizer(engine=engine).visit(typed)
8889
Inferencer(engine=engine).visit(typed)
8990
if iodelay:

‎artiq/compiler/transforms/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .asttyped_rewriter import ASTTypedRewriter
22
from .inferencer import Inferencer
33
from .int_monomorphizer import IntMonomorphizer
4+
from .cast_monomorphizer import CastMonomorphizer
45
from .iodelay_estimator import IODelayEstimator
56
from .artiq_ir_generator import ARTIQIRGenerator
67
from .dead_code_eliminator import DeadCodeEliminator
artiq/compiler/transforms/cast_monomorphizer.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
:class:`CastMonomorphizer` uses explicit casts to monomorphize
3+
expressions of undetermined integer type to either 32 or 64 bits.
4+
"""
5+
6+
from pythonparser import algorithm, diagnostic
7+
from .. import types, builtins
8+
9+
class CastMonomorphizer(algorithm.Visitor):
    """
    Push the width of an explicit integer cast down into its argument.

    When a call to the ``int``, ``int32`` or ``int64`` builtin has a result
    type of known width and its first argument is an integer whose width is
    still a type variable, the argument's type is unified with the result
    type. The argument is then computed directly at the cast's width
    instead of being defaulted to a width elsewhere and converted afterwards.

    :param engine: diagnostic engine; stored on the instance but not used by
        the visitor method in this class.
    """

    def __init__(self, engine):
        self.engine = engine

    def visit_CallT(self, node):
        """
        Monomorphize the argument of an integer cast call, if applicable.

        :param node: typed call node; its ``func.type``, ``type`` and
            ``args`` attributes are read.
        """
        # Process children first so nested casts are handled innermost-first.
        self.generic_visit(node)

        is_int_cast = (types.is_builtin(node.func.type, "int") or
                       types.is_builtin(node.func.type, "int32") or
                       types.is_builtin(node.func.type, "int64"))
        if not is_int_cast:
            return

        typ = node.type.find()
        if types.is_var(typ["width"]):
            # The cast itself has no determined width; nothing to propagate.
            return

        if not node.args:
            # A cast written without positional arguments (e.g. ``int64()``)
            # has no operand to monomorphize; indexing ``args[0]`` below
            # would otherwise raise IndexError.
            # NOTE(review): presumably the inferencer accepts the
            # zero-argument form -- confirm against its valid_forms.
            return

        arg = node.args[0]
        if (builtins.is_int(arg.type) and
                types.is_var(arg.type.find()["width"])):
            arg.type.unify(typ)
24+

‎artiq/compiler/transforms/inferencer.py

-14
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,6 @@ def makenotes(printer, typea, typeb, loca, locb):
780780
elif types.is_builtin(typ, "round"):
781781
valid_forms = lambda: [
782782
valid_form("round(x:float) -> numpy.int?"),
783-
valid_form("round(x:float, width=?) -> numpy.int?")
784783
]
785784

786785
self._unify(node.type, builtins.TInt(),
@@ -791,19 +790,6 @@ def makenotes(printer, typea, typeb, loca, locb):
791790

792791
self._unify(arg.type, builtins.TFloat(),
793792
arg.loc, None)
794-
elif len(node.args) == 1 and len(node.keywords) == 1 and \
795-
builtins.is_numeric(node.args[0].type) and \
796-
node.keywords[0].arg == 'width':
797-
width = node.keywords[0].value
798-
if not (isinstance(width, asttyped.NumT) and isinstance(width.n, int)):
799-
diag = diagnostic.Diagnostic("error",
800-
"the width argument of round() must be an integer literal", {},
801-
node.keywords[0].loc)
802-
self.engine.process(diag)
803-
return
804-
805-
self._unify(node.type, builtins.TInt(types.TValue(width.n)),
806-
node.loc, None)
807793
else:
808794
diagnose(valid_forms())
809795
elif types.is_builtin(typ, "min") or types.is_builtin(typ, "max"):

‎artiq/test/lit/inferencer/builtin_calls.py

-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,3 @@
3030

3131
# CHECK-L: round:<function round>(1.0:float):numpy.int?
3232
round(1.0)
33-
34-
# CHECK-L: round:<function round>(1.0:float, width=64:numpy.int?):numpy.int64
35-
round(1.0, width=64)

‎artiq/test/lit/monomorphism/integers.py

-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,3 @@
66

77
y = int(1)
88
# CHECK-L: y: numpy.int32
9-
10-
z = round(1.0)
11-
# CHECK-L: z: numpy.int32

‎artiq/test/lit/monomorphism/round.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# RUN: %python -m artiq.compiler.testbench.inferencer +mono %s >%t
2+
# RUN: OutputCheck %s --file-to-check=%t
3+
4+
# CHECK-L: round:<function round>(1.0:float):numpy.int32
5+
round(1.0)
6+
7+
# CHECK-L: round:<function round>(2.0:float):numpy.int32
8+
int32(round(2.0))
9+
10+
# CHECK-L: round:<function round>(3.0:float):numpy.int64
11+
int64(round(3.0))

0 commit comments

Comments
 (0)
Please sign in to comment.