@@ -283,7 +283,6 @@ def visit_unsupported(self, node):
283
283
284
284
# expr
285
285
visit_Call = visit_unsupported
286
- visit_Compare = visit_unsupported
287
286
visit_Dict = visit_unsupported
288
287
visit_DictComp = visit_unsupported
289
288
visit_Ellipsis = visit_unsupported
@@ -393,11 +392,14 @@ def visit_AttributeT(self, node):
393
392
node .attr_loc , [node .value .loc ])
394
393
self .engine .process (diag )
395
394
395
def _unify_collection(self, element, collection):
    """Constrain ``collection`` to be a list of ``element``'s type.

    Builds ``builtins.TList(element.type)`` and hands it to ``self._unify``
    together with ``collection.type`` and both nodes' source locations
    (presumably used for diagnostics on mismatch — see ``_unify``).
    """
    # TODO: support more than just lists
    list_type = builtins.TList(element.type)
    self._unify(list_type, collection.type,
                element.loc, collection.loc)
399
+
396
400
def visit_SubscriptT(self, node):
    """Type a subscript expression ``value[...]``.

    Children are visited first; then the subscripted value is constrained,
    via ``_unify_collection``, to be a list whose element type is the type of
    the whole subscript expression.
    """
    self.generic_visit(node)
    indexed = node.value
    self._unify_collection(collection=indexed, element=node)
401
403
402
404
def visit_IfExpT (self , node ):
403
405
self .generic_visit (node )
@@ -455,43 +457,39 @@ def visit_CoerceT(self, node):
455
457
def _coerce_one(self, typ, coerced_node, other_node):
    """Return ``coerced_node`` converted to type ``typ``.

    - If the node's type already equals ``typ`` (compared after ``find()``),
      the node is returned unchanged.
    - If the node is already an ``asttyped.CoerceT``, that wrapper is
      retargeted in place (its ``type`` and ``other_expr`` are overwritten)
      instead of wrapping it in a second coercion.
    - Otherwise a new ``asttyped.CoerceT`` wrapping the node is created.

    In the last two cases the resulting node is visited before it is returned.

    :param typ: the target type.
    :param coerced_node: the typed expression node to convert.
    :param other_node: stored as ``other_expr`` on the coercion node; callers
        pass the other operand of the expression being typed.
    :return: the (possibly new) node to put in place of ``coerced_node``.
    """
    if coerced_node.type.find() == typ.find():
        return coerced_node
    elif isinstance(coerced_node, asttyped.CoerceT):
        # Reuse the existing coercion rather than nesting CoerceT(CoerceT(...)).
        # `node` must be bound here: it is visited and returned below.
        node = coerced_node
        node.type, node.other_expr = typ, other_node
    else:
        node = asttyped.CoerceT(type=typ, expr=coerced_node, other_expr=other_node,
                                loc=coerced_node.loc)
    self.visit(node)
    return node
463
467
464
def _coerce_numeric(self, nodes, map_return=lambda typ: typ):
    """Compute the common numeric type of ``nodes`` (the coercion protocol).

    See https://docs.python.org/3/library/stdtypes.html#numeric-types-int-float-complex.

    Resolution order:
    - any operand type still a type variable -> ``None`` (not enough info yet);
    - any operand not numeric -> an error diagnostic is emitted, ``None``;
    - any float operand -> ``TFloat()``;
    - any int operand -> ``TInt`` of the widest known width, or ``TInt()``
      with an unknown width if some operand's width is not known.

    :param nodes: iterable of typed expression nodes.
    :param map_return: applied to the resulting type before returning it;
        the binary-operator callers use it to build a ``(result, left, right)``
        triple.
    :return: ``map_return(typ)``, or ``None`` as described above.
    """
    # Iterated more than once below, so accept any iterable.
    nodes = list(nodes)
    node_types = [node.type for node in nodes]
    if any(map(types.is_var, node_types)): # not enough info yet
        return
    elif not all(map(builtins.is_numeric, node_types)):
        err_node = next(node for node in nodes if not builtins.is_numeric(node.type))
        diag = diagnostic.Diagnostic("error",
            "cannot coerce {type} to a numeric type",
            {"type": types.TypePrinter().name(err_node.type)},
            err_node.loc, [])
        self.engine.process(diag)
        return
    elif any(map(builtins.is_float, node_types)):
        typ = builtins.TFloat()
    elif any(map(builtins.is_int, node_types)):
        # Materialize the widths: a bare map() iterator would be exhausted by
        # all() and leave max() with an empty sequence (ValueError).
        widths = [builtins.get_int_width(node_type) for node_type in node_types]
        if all(widths):
            typ = builtins.TInt(types.TValue(max(widths)))
        else:
            typ = builtins.TInt()
    else:
        # NOTE(review): assumed unreachable if is_numeric only accepts int and
        # float types — TODO confirm against builtins.is_numeric.
        assert False

    return map_return(typ)
495
493
496
494
def _order_by_pred (self , pred , left , right ):
497
495
if pred (left .type ):
@@ -503,7 +501,7 @@ def _order_by_pred(self, pred, left, right):
503
501
504
502
def _coerce_binop (self , op , left , right ):
505
503
if isinstance (op , (ast .BitAnd , ast .BitOr , ast .BitXor ,
506
- ast .LShift , ast .RShift )):
504
+ ast .LShift , ast .RShift )):
507
505
# bitwise operators require integers
508
506
for operand in (left , right ):
509
507
if not types .is_var (operand .type ) and not builtins .is_int (operand .type ):
@@ -515,7 +513,7 @@ def _coerce_binop(self, op, left, right):
515
513
self .engine .process (diag )
516
514
return
517
515
518
- return self ._coerce_numeric (left , right )
516
+ return self ._coerce_numeric (( left , right ), lambda typ : ( typ , typ , typ ) )
519
517
elif isinstance (op , ast .Add ):
520
518
# add works on numbers and also collections
521
519
if builtins .is_collection (left .type ) or builtins .is_collection (right .type ):
@@ -554,7 +552,7 @@ def _coerce_binop(self, op, left, right):
554
552
left .loc , right .loc )
555
553
return left .type , left .type , right .type
556
554
else :
557
- return self ._coerce_numeric (left , right )
555
+ return self ._coerce_numeric (( left , right ), lambda typ : ( typ , typ , typ ) )
558
556
elif isinstance (op , ast .Mult ):
559
557
# mult works on numbers and also number & collection
560
558
if types .is_tuple (left .type ) or types .is_tuple (right .type ):
@@ -585,10 +583,10 @@ def _coerce_binop(self, op, left, right):
585
583
586
584
return list_ .type , left .type , right .type
587
585
else :
588
- return self ._coerce_numeric (left , right )
586
+ return self ._coerce_numeric (( left , right ), lambda typ : ( typ , typ , typ ) )
589
587
elif isinstance (op , (ast .Div , ast .FloorDiv , ast .Mod , ast .Pow , ast .Sub )):
590
588
# numeric operators work on any kind of number
591
- return self ._coerce_numeric (left , right )
589
+ return self ._coerce_numeric (( left , right ), lambda typ : ( typ , typ , typ ) )
592
590
else : # MatMult
593
591
diag = diagnostic .Diagnostic ("error" ,
594
592
"operator '{op}' is not supported" , {"op" : op .loc .source ()},
@@ -605,6 +603,42 @@ def visit_BinOpT(self, node):
605
603
node .right = self ._coerce_one (right_type , node .right , other_node = node .left )
606
604
node .type .unify (return_type ) # should never fail
607
605
606
def visit_CompareT(self, node):
    """Type a (possibly chained) comparison ``left op0 c0 op1 c1 ...``.

    - All ops are ``is`` / ``is not``: each adjacent operand pair is unified.
    - All ops are ``in`` / ``not in``: each right operand is constrained to be
      a list of the left operand's type (via ``_unify_collection``).
    - Otherwise (``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``): if any operand
      is a collection, adjacent operands are unified; else every operand is
      coerced to the common numeric type computed by ``_coerce_numeric``.

    In all cases the comparison expression itself is unified with ``bool``.
    """
    self.generic_visit(node)
    operands = [node.left] + node.comparators
    # Adjacent operand pairs: (left, c0), (c0, c1), ...
    pairs = zip(operands, node.comparators)
    if all(isinstance(op, (ast.Is, ast.IsNot)) for op in node.ops):
        for left, right in pairs:
            self._unify(left.type, right.type,
                        left.loc, right.loc)
    elif all(isinstance(op, (ast.In, ast.NotIn)) for op in node.ops):
        for left, right in pairs:
            self._unify_collection(element=left, collection=right)
    else: # Eq, NotEq, Lt, LtE, Gt, GtE
        operand_types = [operand.type for operand in operands]
        if any(map(builtins.is_collection, operand_types)):
            for left, right in pairs:
                self._unify(left.type, right.type,
                            left.loc, right.loc)
        else:
            typ = self._coerce_numeric(operands)
            if typ is not None:
                try:
                    other_node = next(operand for operand in operands
                                      if operand.type.find() == typ.find())
                except StopIteration:
                    # No operand has exactly `typ`, so the result is more generic
                    # than every input: its kind is known but its width is not.
                    # Fall back to a monomorphic operand of the same kind.
                    other_node = next(operand for operand in operands
                                      if types.is_mono(operand.type) and
                                      operand.type.find().name == typ.find().name)
                coerced = [self._coerce_one(typ, operand, other_node)
                           for operand in operands]
                node.left, node.comparators = coerced[0], coerced[1:]
    node.type.unify(builtins.TBool())
641
+
608
642
def visit_Assign (self , node ):
609
643
self .generic_visit (node )
610
644
if len (node .targets ) > 1 :
0 commit comments