Skip to content

Commit 20b7a73

Browse files
author
whitequark
committed Jun 14, 2015
Add support for Compare.
1 parent fe69c5b commit 20b7a73

File tree

4 files changed

+87
-46
lines changed

4 files changed

+87
-46
lines changed
 

Diff for: artiq/py2llvm/types.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -163,13 +163,14 @@ def __ne__(self, other):
163163
def is_var(typ):
164164
return isinstance(typ.find(), TVar)
165165

166-
def is_mono(typ, name, **params):
166+
def is_mono(typ, name=None, **params):
167167
typ = typ.find()
168168
params_match = True
169169
for param in params:
170-
params_match = params_match and typ.params[param] == params[param]
170+
params_match = params_match and \
171+
typ.params[param].find() == params[param].find()
171172
return isinstance(typ, TMono) and \
172-
typ.name == name and params_match
173+
(name is None or (typ.name == name and params_match))
173174

174175
def is_tuple(typ, elts=None):
175176
typ = typ.find()

Diff for: artiq/py2llvm/typing.py

+70 -36
Original file line number | Diff line number | Diff line change
@@ -283,7 +283,6 @@ def visit_unsupported(self, node):
283283

284284
# expr
285285
visit_Call = visit_unsupported
286-
visit_Compare = visit_unsupported
287286
visit_Dict = visit_unsupported
288287
visit_DictComp = visit_unsupported
289288
visit_Ellipsis = visit_unsupported
@@ -393,11 +392,14 @@ def visit_AttributeT(self, node):
393392
node.attr_loc, [node.value.loc])
394393
self.engine.process(diag)
395394

395+
def _unify_collection(self, element, collection):
396+
# TODO: support more than just lists
397+
self._unify(builtins.TList(element.type), collection.type,
398+
element.loc, collection.loc)
399+
396400
def visit_SubscriptT(self, node):
397401
self.generic_visit(node)
398-
# TODO: support more than just lists
399-
self._unify(builtins.TList(node.type), node.value.type,
400-
node.loc, node.value.loc)
402+
self._unify_collection(element=node, collection=node.value)
401403

402404
def visit_IfExpT(self, node):
403405
self.generic_visit(node)
@@ -455,43 +457,39 @@ def visit_CoerceT(self, node):
455457
def _coerce_one(self, typ, coerced_node, other_node):
456458
if coerced_node.type.find() == typ.find():
457459
return coerced_node
460+
elif isinstance(coerced_node, asttyped.CoerceT):
461+
node.type, node.other_expr = typ, other_node
458462
else:
459463
node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
460464
loc=coerced_node.loc)
461-
self.visit(node)
462-
return node
465+
self.visit(node)
466+
return node
463467

464-
def _coerce_numeric(self, left, right):
465-
# Implements the coercion protocol.
468+
def _coerce_numeric(self, nodes, map_return=lambda typ: typ):
466469
# See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.
467-
if builtins.is_float(left.type) or builtins.is_float(right.type):
468-
typ = builtins.TFloat()
469-
elif builtins.is_int(left.type) or builtins.is_int(right.type):
470-
left_width, right_width = \
471-
builtins.get_int_width(left.type), builtins.get_int_width(left.type)
472-
if left_width and right_width:
473-
typ = builtins.TInt(types.TValue(max(left_width, right_width)))
474-
else:
475-
typ = builtins.TInt()
476-
elif types.is_var(left.type) or types.is_var(right.type): # not enough info yet
470+
node_types = [node.type for node in nodes]
471+
if any(map(types.is_var, node_types)): # not enough info yet
477472
return
478-
else: # conflicting types
479-
printer = types.TypePrinter()
480-
note1 = diagnostic.Diagnostic("note",
481-
"expression of type {typea}", {"typea": printer.name(left.type)},
482-
left.loc)
483-
note2 = diagnostic.Diagnostic("note",
484-
"expression of type {typeb}", {"typeb": printer.name(right.type)},
485-
right.loc)
473+
elif not all(map(builtins.is_numeric, node_types)):
474+
err_node = next(filter(lambda node: not builtins.is_numeric(node.type), nodes))
486475
diag = diagnostic.Diagnostic("error",
487-
"cannot coerce {typea} and {typeb} to a common numeric type",
488-
{"typea": printer.name(left.type), "typeb": printer.name(right.type)},
489-
left.loc, [right.loc],
490-
[note1, note2])
476+
"cannot coerce {type} to a numeric type",
477+
{"type": types.TypePrinter().name(err_node.type)},
478+
err_node.loc, [])
491479
self.engine.process(diag)
492480
return
481+
elif any(map(builtins.is_float, node_types)):
482+
typ = builtins.TFloat()
483+
elif any(map(builtins.is_int, node_types)):
484+
widths = map(builtins.get_int_width, node_types)
485+
if all(widths):
486+
typ = builtins.TInt(types.TValue(max(widths)))
487+
else:
488+
typ = builtins.TInt()
489+
else:
490+
assert False
493491

494-
return typ, typ, typ
492+
return map_return(typ)
495493

496494
def _order_by_pred(self, pred, left, right):
497495
if pred(left.type):
@@ -503,7 +501,7 @@ def _order_by_pred(self, pred, left, right):
503501

504502
def _coerce_binop(self, op, left, right):
505503
if isinstance(op, (ast.BitAnd, ast.BitOr, ast.BitXor,
506-
ast.LShift, ast.RShift)):
504+
ast.LShift, ast.RShift)):
507505
# bitwise operators require integers
508506
for operand in (left, right):
509507
if not types.is_var(operand.type) and not builtins.is_int(operand.type):
@@ -515,7 +513,7 @@ def _coerce_binop(self, op, left, right):
515513
self.engine.process(diag)
516514
return
517515

518-
return self._coerce_numeric(left, right)
516+
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
519517
elif isinstance(op, ast.Add):
520518
# add works on numbers and also collections
521519
if builtins.is_collection(left.type) or builtins.is_collection(right.type):
@@ -554,7 +552,7 @@ def _coerce_binop(self, op, left, right):
554552
left.loc, right.loc)
555553
return left.type, left.type, right.type
556554
else:
557-
return self._coerce_numeric(left, right)
555+
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
558556
elif isinstance(op, ast.Mult):
559557
# mult works on numbers and also number & collection
560558
if types.is_tuple(left.type) or types.is_tuple(right.type):
@@ -585,10 +583,10 @@ def _coerce_binop(self, op, left, right):
585583

586584
return list_.type, left.type, right.type
587585
else:
588-
return self._coerce_numeric(left, right)
586+
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
589587
elif isinstance(op, (ast.Div, ast.FloorDiv, ast.Mod, ast.Pow, ast.Sub)):
590588
# numeric operators work on any kind of number
591-
return self._coerce_numeric(left, right)
589+
return self._coerce_numeric((left, right), lambda typ: (typ, typ, typ))
592590
else: # MatMult
593591
diag = diagnostic.Diagnostic("error",
594592
"operator '{op}' is not supported", {"op": op.loc.source()},
@@ -605,6 +603,42 @@ def visit_BinOpT(self, node):
605603
node.right = self._coerce_one(right_type, node.right, other_node=node.left)
606604
node.type.unify(return_type) # should never fail
607605

606+
def visit_CompareT(self, node):
607+
self.generic_visit(node)
608+
pairs = zip([node.left] + node.comparators, node.comparators)
609+
if all(map(lambda op: isinstance(op, (ast.Is, ast.IsNot)), node.ops)):
610+
for left, right in pairs:
611+
self._unify(left.type, right.type,
612+
left.loc, right.loc)
613+
elif all(map(lambda op: isinstance(op, (ast.In, ast.NotIn)), node.ops)):
614+
for left, right in pairs:
615+
self._unify_collection(element=left, collection=right)
616+
else: # Eq, NotEq, Lt, LtE, Gt, GtE
617+
operands = [node.left] + node.comparators
618+
operand_types = [operand.type for operand in operands]
619+
if any(map(builtins.is_collection, operand_types)):
620+
for left, right in pairs:
621+
self._unify(left.type, right.type,
622+
left.loc, right.loc)
623+
else:
624+
typ = self._coerce_numeric(operands)
625+
if typ:
626+
try:
627+
other_node = next(filter(lambda operand: operand.type.find() == typ.find(),
628+
operands))
629+
except StopIteration:
630+
# can't find an argument with an exact type, meaning
631+
# the return value is more generic than any of the inputs, meaning
632+
# the type is known (typ is not None), but its width is not
633+
def wide_enough(opreand):
634+
return types.is_mono(opreand.type) and \
635+
opreand.type.find().name == typ.find().name
636+
other_node = next(filter(wide_enough, operands))
637+
print(typ, other_node)
638+
node.left, *node.comparators = \
639+
[self._coerce_one(typ, operand, other_node) for operand in operands]
640+
node.type.unify(builtins.TBool())
641+
608642
def visit_Assign(self, node):
609643
self.generic_visit(node)
610644
if len(node.targets) > 1:

Diff for: lit-test/py2llvm/typing/coerce.py

+12
Original file line number | Diff line number | Diff line change
@@ -27,3 +27,15 @@
2727

2828
a = []; a += [1]
2929
# CHECK-L: a:list(elt=int(width='r)) = []:list(elt=int(width='r)); a:list(elt=int(width='r)) += [1:int(width='r)]:list(elt=int(width='r))
30+
31+
[] is [1]
32+
# CHECK-L: []:list(elt=int(width='s)) is [1:int(width='s)]:list(elt=int(width='s)):bool
33+
34+
1 in [1]
35+
# CHECK-L: 1:int(width='t) in [1:int(width='t)]:list(elt=int(width='t)):bool
36+
37+
[] < [1]
38+
# CHECK-L: []:list(elt=int(width='u)) < [1:int(width='u)]:list(elt=int(width='u)):bool
39+
40+
1.0 < 1
41+
# CHECK-L: 1.0:float < 1:int(width='v):float:bool

Diff for: lit-test/py2llvm/typing/error_coerce.py

+1 -7
Original file line number | Diff line number | Diff line change
@@ -25,13 +25,7 @@
2525
# CHECK-L: ${LINE:+1}: note: operand of type list(elt='b), which is not a valid repetition amount
2626
[1] * []
2727

28-
# CHECK-L: ${LINE:+3}: error: cannot coerce list(elt='a) and NoneType to a common numeric type
29-
# CHECK-L: ${LINE:+2}: note: expression of type list(elt='a)
30-
# CHECK-L: ${LINE:+1}: note: expression of type NoneType
31-
[] - None
32-
33-
# CHECK-L: ${LINE:+2}: error: cannot coerce list(elt='a) to float
34-
# CHECK-L: ${LINE:+1}: note: expression that required coercion to float
28+
# CHECK-L: ${LINE:+1}: error: cannot coerce list(elt='a) to a numeric type
3529
[] - 1.0
3630

3731
# CHECK-L: ${LINE:+2}: error: expression of type int(width='a) has to be coerced to float, which makes assignment invalid

0 commit comments

Comments
 (0)
Please sign in to comment.