Skip to content

Commit

Permalink
Showing 26 changed files with 1,022 additions and 91 deletions.
218 changes: 218 additions & 0 deletions spec/compiler/codegen/automatic_cast.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
require "../../spec_helper"

# Codegen specs for automatic casts: each example compiles and runs a small
# program in which a number or symbol literal is passed/assigned where a
# different numeric type or an enum is expected, and checks the runtime value.
# NOTE(review): this file is named `automatic_cast.cr`, while the other spec
# files touched by the same change end in `_spec.cr` (`def_spec.cr`,
# `uninitialized_spec.cr`). If the spec runner only globs `*_spec.cr`, these
# examples never run — confirm and consider renaming to `automatic_cast_spec.cr`.
describe "Code gen: automatic cast" do
it "casts literal integer (Int32 -> Int64)" do
run(%(
def foo(x : Int64)
x
end
foo(12345)
)).to_i.should eq(12345)
end

it "casts literal integer (Int64 -> Int32, ok)" do
run(%(
def foo(x : Int32)
x
end
foo(2147483647_i64)
)).to_i.should eq(2147483647)
end

# Float results are converted back with `to_i` inside the program so the
# runner's `to_i` comparison works on an exact integer.
it "casts literal integer (Int32 -> Float32)" do
run(%(
def foo(x : Float32)
x
end
foo(12345).to_i
)).to_i.should eq(12345)
end

it "casts literal integer (Int32 -> Float64)" do
run(%(
def foo(x : Float64)
x
end
foo(12345).to_i
)).to_i.should eq(12345)
end

it "casts literal float (Float32 -> Float64)" do
run(%(
def foo(x : Float64)
x
end
foo(12345.0_f32).to_i
)).to_i.should eq(12345)
end

it "casts literal float (Float64 -> Float32)" do
run(%(
def foo(x : Float32)
x
end
foo(12345.0).to_i
)).to_i.should eq(12345)
end

# The leading `:four` in the program registers an extra symbol. Presumably this
# keeps the symbol-table index of `:three` from coinciding with the enum value
# (2), so the example would fail if codegen used the raw symbol index instead
# of mapping the symbol name to the enum member — confirm intent.
it "casts symbol literal to enum" do
run(%(
:four
enum Foo
One
Two
Three
end
def foo(x : Foo)
x
end
foo(:three)
)).to_i.should eq(2)
end

it "casts Int32 to Int64 in ivar assignment" do
run(%(
class Foo
@x : Int64
def initialize
@x = 10
end
def x
@x
end
end
Foo.new.x
)).to_i.should eq(10)
end

it "casts Symbol to Enum in ivar assignment" do
run(%(
enum E
One
Two
Three
end
class Foo
@x : E
def initialize
@x = :three
end
def x
@x
end
end
Foo.new.x
)).to_i.should eq(2)
end

it "casts Int32 to Int64 in cvar assignment" do
run(%(
class Foo
@@x : Int64 = 0_i64
def self.x
@@x = 10
@@x
end
end
Foo.x
)).to_i.should eq(10)
end

it "casts Int32 to Int64 in lvar assignment" do
run(%(
x : Int64
x = 123
x
)).to_i.should eq(123)
end

it "casts Int32 to Int64 in ivar type declaration" do
run(%(
class Foo
@x : Int64 = 10
def x
@x
end
end
Foo.new.x
)).to_i.should eq(10)
end

it "casts Symbol to Enum in ivar type declaration" do
run(%(
enum Color
Red
Green
Blue
end
class Foo
@x : Color = :blue
def x
@x
end
end
Foo.new.x
)).to_i.should eq(2)
end

it "casts Int32 to Int64 in cvar type declaration" do
run(%(
class Foo
@@x : Int64 = 10
def self.x
@@x
end
end
Foo.x
)).to_i.should eq(10)
end

# Default argument values are expanded into restricted assignments, so the
# Int32 default must be cast to the Int64 restriction.
it "casts Int32 -> Int64 in arg restriction" do
run(%(
def foo(x : Int64 = 123)
x
end
foo
)).to_i.should eq(123)
end

it "casts Int32 to Int64 in ivar type declaration in generic" do
run(%(
class Foo(T)
@x : T = 10
def x
@x
end
end
Foo(Int64).new.x
)).to_i.should eq(10)
end
end
12 changes: 9 additions & 3 deletions spec/compiler/normalize/def_spec.cr
Original file line number Diff line number Diff line change
@@ -33,16 +33,22 @@ describe "Normalize: def" do
a_def = parse("def foo(x, y : Int32 = 1, z : Int64 = 2i64); x + y + z; end").as(Def)
actual = a_def.expand_default_arguments(Program.new, 1)
expected = parse("def foo(x); y = 1; z = 2i64; x + y + z; end").as(Def)
expected.body.as(Expressions).expressions.insert 1, TypeRestriction.new Var.new("y"), Path.new(["Int32"])
expected.body.as(Expressions).expressions.insert 3, TypeRestriction.new Var.new("z"), Path.new(["Int64"])

exps = expected.body.as(Expressions).expressions
exps[0] = AssignWithRestriction.new(exps[0].as(Assign), Path.new("Int32"))
exps[1] = AssignWithRestriction.new(exps[1].as(Assign), Path.new("Int64"))

actual.should eq(expected)
end

# Expanding `foo` for 2 given args leaves only `z` defaulted: its default value
# becomes an assignment in the body, tagged with the argument's restriction.
it "expands a def on request with default arguments and type restrictions (2)" do
a_def = parse("def foo(x, y : Int32 = 1, z : Int64 = 2i64); x + y + z; end").as(Def)
actual = a_def.expand_default_arguments(Program.new, 2)
expected = parse("def foo(x, y : Int32); z = 2i64; x + y + z; end").as(Def)
# NOTE(review): inserting a separate TypeRestriction node (next line) and
# wrapping the assignment in AssignWithRestriction (below) are two alternative
# expectations; the AssignWithRestriction form matches default_arguments.cr.
expected.body.as(Expressions).expressions.insert 1, TypeRestriction.new Var.new("z"), Path.new(["Int64"])

# The `z = 2i64` assignment is expected to carry the `Int64` restriction.
exps = expected.body.as(Expressions).expressions
exps[0] = AssignWithRestriction.new(exps[0].as(Assign), Path.new("Int64"))

actual.should eq(expected)
end

370 changes: 370 additions & 0 deletions spec/compiler/semantic/automatic_cast.cr
Original file line number Diff line number Diff line change
@@ -0,0 +1,370 @@
require "../../spec_helper"

# Semantic specs for automatic casts: a number literal may be typed as another
# integer/float type (integers only when the value fits), and a symbol literal
# may be typed as an enum whose member it names. Ambiguous matches are errors.
# NOTE(review): this file is named `automatic_cast.cr`, while the other spec
# files touched by the same change end in `_spec.cr`; if the spec runner only
# globs `*_spec.cr`, these examples never run — consider renaming.
describe "Semantic: automatic cast" do
it "casts literal integer (Int32 -> no restriction)" do
assert_type(%(
def foo(x)
x + 1
end
foo(12345)
), inject_primitives: true) { int32 }
end

it "casts literal integer (Int32 -> Int64)" do
assert_type(%(
def foo(x : Int64)
x
end
foo(12345)
)) { int64 }
end

it "casts literal integer (Int64 -> Int32, ok)" do
assert_type(%(
def foo(x : Int32)
x
end
foo(2147483647_i64)
)) { int32 }
end

# One past Int32::MAX: the value does not fit, so no autocast applies.
it "casts literal integer (Int64 -> Int32, too big)" do
assert_error %(
def foo(x : Int32)
x
end
foo(2147483648_i64)
),
"no overload matches"
end

it "casts literal integer (Int32 -> Float32)" do
assert_type(%(
def foo(x : Float32)
x
end
foo(12345)
)) { float32 }
end

it "casts literal integer (Int32 -> Float64)" do
assert_type(%(
def foo(x : Float64)
x
end
foo(12345)
)) { float64 }
end

it "casts literal float (Float32 -> Float64)" do
assert_type(%(
def foo(x : Float64)
x
end
foo(1.23_f32)
)) { float64 }
end

it "casts literal float (Float64 -> Float32)" do
assert_type(%(
def foo(x : Float32)
x
end
foo(1.23)
)) { float32 }
end

# An exact type match must win over an autocast candidate.
it "matches correct overload" do
assert_type(%(
def foo(x : Int32)
x
end
def foo(x : Int64)
x
end
foo(1_i64)
)) { int64 }
end

it "casts literal integer through alias with union" do
assert_type(%(
alias A = Int64 | String
def foo(x : A)
x
end
foo(12345)
)) { int64 }
end

it "says ambiguous call for integer" do
assert_error %(
def foo(x : Int8)
x
end
def foo(x : UInt8)
x
end
foo(1)
),
"ambiguous"
end

it "says ambiguous call for integer (2)" do
assert_error %(
def foo(x : Int8 | UInt8)
x
end
foo(1)
),
"ambiguous"
end

it "casts symbol literal to enum" do
assert_type(%(
enum Foo
One
Two
Three
end
def foo(x : Foo)
x
end
foo(:one)
)) { types["Foo"] }
end

# Symbol counterpart of the integer alias/union example above.
it "casts symbol literal through alias with union" do
assert_type(%(
enum Foo
One
Two
end
alias A = Foo | String
def foo(x : A)
x
end
foo(:two)
)) { types["Foo"] }
end

it "errors if symbol name doesn't match enum member" do
assert_error %(
enum Foo
One
Two
Three
end
def foo(x : Foo)
x
end
foo(:four)
),
"no overload matches"
end

it "says ambiguous call for symbol" do
assert_error %(
enum Foo
One
Two
Three
end
enum Foo2
One
Two
Three
end
def foo(x : Foo)
x
end
def foo(x : Foo2)
x
end
foo(:one)
),
"ambiguous"
end

it "casts Int32 to Int64 in ivar assignment" do
assert_type(%(
class Foo
@x : Int64
def initialize
@x = 10
end
def x
@x
end
end
Foo.new.x
)) { int64 }
end

it "casts Symbol to Enum in ivar assignment" do
assert_type(%(
enum E
One
Two
Three
end
class Foo
@x : E
def initialize
@x = :two
end
def x
@x
end
end
Foo.new.x
)) { types["E"] }
end

it "casts Int32 to Int64 in cvar assignment" do
assert_type(%(
class Foo
@@x : Int64 = 0_i64
def self.x
@@x = 10
@@x
end
end
Foo.x
)) { int64 }
end

it "casts Int32 to Int64 in lvar assignment" do
assert_type(%(
x : Int64
x = 123
x
)) { int64 }
end

it "casts Int32 to Int64 in ivar type declaration" do
assert_type(%(
class Foo
@x : Int64 = 10
def x
@x
end
end
Foo.new.x
)) { int64 }
end

it "casts Symbol to Enum in ivar type declaration" do
assert_type(%(
enum Color
Red
Green
Blue
end
class Foo
@x : Color = :red
def x
@x
end
end
Foo.new.x
)) { types["Color"] }
end

it "casts Int32 to Int64 in cvar type declaration" do
assert_type(%(
class Foo
@@x : Int64 = 10
def self.x
@@x
end
end
Foo.x
)) { int64 }
end

it "casts Symbol to Enum in cvar type declaration" do
assert_type(%(
enum Color
Red
Green
Blue
end
class Foo
@@x : Color = :red
def self.x
@@x
end
end
Foo.x
)) { types["Color"] }
end

it "casts Int32 -> Int64 in arg restriction" do
assert_type(%(
def foo(x : Int64 = 0)
x
end
foo
)) { int64 }
end

it "casts Int32 to Int64 in ivar type declaration in generic" do
assert_type(%(
class Foo(T)
@x : T = 10
def x
@x
end
end
Foo(Int64).new.x
)) { int64 }
end
end
4 changes: 2 additions & 2 deletions spec/compiler/semantic/def_spec.cr
Original file line number Diff line number Diff line change
@@ -173,12 +173,12 @@ describe "Semantic: def" do

# With automatic casts an Int32 literal default (`1`) now fits an `Int64`
# restriction, so an incompatible default needs a non-numeric literal (a Char).
# NOTE(review): the two `def foo` lines and the two expected messages are the
# old and new versions of this example; only the Char pair belongs together.
it "errors when default value is incompatible with type restriction" do
assert_error "
def foo(x : Int64 = 1)
def foo(x : Int64 = 'a')
end
foo
",
"can't restrict Int32 to Int64"
"can't restrict Char to Int64"
end

it "types call with global scope" do
4 changes: 2 additions & 2 deletions spec/compiler/semantic/uninitialized_spec.cr
Original file line number Diff line number Diff line change
@@ -58,9 +58,9 @@ describe "Semantic: uninitialized" do
# Assigning a differently-typed value to a variable declared via `uninitialized`
# must fail. A Char is used presumably because an integer literal such as
# `1_i64` could now be autocast to the declared Int32 — confirm.
# NOTE(review): the `1_i64`/`'a'` lines and the two expected messages are the old
# and new versions of this example; only the Char pair belongs together.
it "errors if declares var and then assigns other type" do
assert_error %(
x = uninitialized Int32
x = 1_i64
x = 'a'
),
"type must be Int32, not (Int32 | Int64)"
"type must be Int32, not (Char | Int32)"
end

it "errors if declaring variable multiple times with different types (#917)" do
33 changes: 33 additions & 0 deletions src/compiler/crystal/codegen/cast.cr
Original file line number Diff line number Diff line change
@@ -488,6 +488,39 @@ class Crystal::CodeGenVisitor
target_pointer
end

# Automatic cast between two integer types (e.g. an Int32 literal used where an
# Int64 is expected): emitted as a regular numeric conversion.
def downcast_distinct(value, to_type : IntegerType, from_type : IntegerType)
  codegen_cast(from_type, to_type, value)
end

# Automatic cast from an integer type to a float type.
def downcast_distinct(value, to_type : FloatType, from_type : IntegerType)
  codegen_cast(from_type, to_type, value)
end

# Automatic cast between two float types (Float32 <-> Float64).
def downcast_distinct(value, to_type : FloatType, from_type : FloatType)
  codegen_cast(from_type, to_type, value)
end

# Automatic cast from a symbol to an enum member.
#
# *value* holds the symbol's index in the program's symbol table (read with
# `const_int_get_sext_value`, so it is expected to be a constant). The index is
# mapped back to the symbol's name via `@symbols_by_index`, and the enum member
# whose name matches it (both compared in underscore form) is generated
# instead; its value is returned.
#
# Raises if no member matches, which the message treats as a compiler bug.
def downcast_distinct(value, to_type : EnumType, from_type : SymbolType)
  index = value.const_int_get_sext_value
  symbol = @symbols_by_index[index].underscore

  # Block variables are named distinctly so they don't shadow the *value*
  # argument above.
  to_type.types.each do |member_name, member|
    if member_name.underscore == symbol
      accept(member.as(Const).value)
      return @last
    end
  end

  raise "Bug: expected to find enum member of #{to_type} matching symbol #{symbol}"
end

# Fallback for any pairing the specific overloads above don't handle; reaching
# it is reported as a compiler bug.
def downcast_distinct(value, to_type : Type, from_type : Type)
  raise "BUG: trying to downcast #{to_type} <- #{from_type}"
end
2 changes: 2 additions & 0 deletions src/compiler/crystal/codegen/codegen.cr
Original file line number Diff line number Diff line change
@@ -179,9 +179,11 @@ module Crystal
@in_lib = false
@strings = {} of StringKey => LLVM::Value
@symbols = {} of String => Int32
@symbols_by_index = [] of String
@symbol_table_values = [] of LLVM::Value
program.symbols.each_with_index do |sym, index|
@symbols[sym] = index
@symbols_by_index << sym
@symbol_table_values << build_string_constant(sym, sym)
end

33 changes: 24 additions & 9 deletions src/compiler/crystal/semantic/ast.cr
Original file line number Diff line number Diff line change
@@ -77,21 +77,20 @@ module Crystal
def_equals_and_hash type
end

# Fictitious node to represent a type restriction
#
# It is used for type restriction of method arguments.
class TypeRestriction < ASTNode
getter obj
getter to
# Fictitious node to represent an assignment with a type restriction. It is
# created when a method argument's default value is expanded into an
# assignment in the def body; *restriction* is that argument's restriction.
class AssignWithRestriction < ASTNode
# The `var = default_value` assignment being typed.
property assign
# The argument's type restriction the assigned value must satisfy.
property restriction

def initialize(@obj : ASTNode, @to : ASTNode)
def initialize(@assign : Assign, @restriction : ASTNode)
end

def clone_without_location
TypeRestriction.new @obj.clone, @to.clone
AssignWithRestriction.new @assign.clone, @restriction.clone
end

def_equals_and_hash obj, to
def_equals_and_hash assign, restriction
end

class Arg
@@ -734,4 +733,20 @@ module Crystal
self
end
end

class NumberLiteral
  # Tells whether this literal may be implicitly converted to *other_type*:
  #
  # * integer literal -> integer type: only if the literal's value lies within
  #   the target type's range
  # * integer or float literal -> float type: always
  # * any other combination: never
  def can_be_autocast_to?(other_type)
    literal_type = self.type

    if other_type.is_a?(IntegerType)
      return false unless literal_type.is_a?(IntegerType)

      min, max = other_type.range
      number = integer_value
      min <= number && number <= max
    elsif other_type.is_a?(FloatType)
      literal_type.is_a?(IntegerType) || literal_type.is_a?(FloatType)
    else
      false
    end
  end
end
end
15 changes: 15 additions & 0 deletions src/compiler/crystal/semantic/bindings.cr
Original file line number Diff line number Diff line change
@@ -17,6 +17,21 @@ module Crystal
@type
end

# Returns this node's type. When *with_literals* is true and the node is a
# number or symbol literal, the type is instead wrapped in a
# `NumberLiteralType` / `SymbolLiteralType` built from the literal node itself,
# so method lookup can apply the literal autocast rules.
def type(*, with_literals = false)
  node_type = self.type
  return node_type unless with_literals

  case self
  when NumberLiteral
    NumberLiteralType.new(node_type.program, self)
  when SymbolLiteral
    SymbolLiteralType.new(node_type.program, self)
  else
    node_type
  end
end

def set_type(type : Type)
type = type.remove_alias_if_simple
if !type.no_return? && (freeze_type = @freeze_type) && !type.implements?(freeze_type)
91 changes: 53 additions & 38 deletions src/compiler/crystal/semantic/call.cr
Original file line number Diff line number Diff line change
@@ -13,6 +13,9 @@ class Crystal::Call
property? uses_with_scope = false
getter? raises = false

# Signals that a lookup done without the literal autocast rule failed while the
# call has number/symbol literal arguments; `lookup_matches` rescues it and
# repeats the lookup with literal-aware argument types.
class RetryLookupWithLiterals < ::Exception; end

# The program this call belongs to (taken from the call's scope).
def program; scope.program; end
@@ -93,16 +96,22 @@ class Crystal::Call
end

# Looks up the defs matching this call. The first attempt uses the plain
# argument types; if it fails and the call has number/symbol literal arguments,
# `raise_matches_not_found` raises `RetryLookupWithLiterals` and the lookup is
# redone with literal-aware argument types.
# NOTE(review): `raise_matches_not_found` defaults `with_literals` to false, so a
# caller that does not forward the flag could raise `RetryLookupWithLiterals`
# during the second attempt, which is outside the rescue — confirm every caller
# forwards it. (`ex` is unused.)
def lookup_matches
lookup_matches(with_literals: false)
rescue ex : RetryLookupWithLiterals
lookup_matches(with_literals: true)
end

# Computes argument types (wrapped as literal types when *with_literals* is
# true) and dispatches to the splat or non-splat lookup.
def lookup_matches(*, with_literals = false)
if args.any? { |arg| arg.is_a?(Splat) || arg.is_a?(DoubleSplat) }
lookup_matches_with_splat
lookup_matches_with_splat(with_literals)
else
arg_types = args.map(&.type)
named_args_types = NamedArgumentType.from_args(named_args)
lookup_matches_without_splat arg_types, named_args_types
arg_types = args.map(&.type(with_literals: with_literals))
named_args_types = NamedArgumentType.from_args(named_args, with_literals)
lookup_matches_without_splat arg_types, named_args_types, with_literals
end
end

def lookup_matches_with_splat
def lookup_matches_with_splat(with_literals)
# Check if all splat are of tuples
arg_types = Array(Type).new(args.size * 2)
named_args_types = nil
@@ -133,7 +142,7 @@ class Crystal::Call
arg.raise "argument to double splat must be a named tuple, not #{arg_type}"
end
else
arg_types << arg.type
arg_types << arg.type(with_literals: with_literals)
end
end

@@ -143,81 +152,84 @@ class Crystal::Call
named_args_types ||= [] of NamedArgumentType
named_args.each do |named_arg|
raise "duplicate key: #{named_arg.name}" if named_args_types.any? &.name.==(named_arg.name)
named_args_types << NamedArgumentType.new(named_arg.name, named_arg.value.type)
named_args_types << NamedArgumentType.new(
named_arg.name,
named_arg.value.type(with_literals: with_literals),
)
end
end

lookup_matches_without_splat arg_types, named_args_types
lookup_matches_without_splat arg_types, named_args_types, with_literals: with_literals
end

# Routes the lookup according to the call's shape: explicit receiver (`@obj`),
# `super`, `previous_def`, a `with` scope (`@with_scope`), or the current
# scope. *with_literals* is forwarded to every route.
def lookup_matches_without_splat(arg_types, named_args_types)
def lookup_matches_without_splat(arg_types, named_args_types, with_literals)
if obj = @obj
lookup_matches_in(obj.type, arg_types, named_args_types)
lookup_matches_in(obj.type, arg_types, named_args_types, with_literals: with_literals)
elsif name == "super"
lookup_super_matches(arg_types, named_args_types)
lookup_super_matches(arg_types, named_args_types, with_literals: with_literals)
elsif name == "previous_def"
lookup_previous_def_matches(arg_types, named_args_types)
lookup_previous_def_matches(arg_types, named_args_types, with_literals: with_literals)
elsif with_scope = @with_scope
lookup_matches_with_scope_in with_scope, arg_types, named_args_types
lookup_matches_with_scope_in with_scope, arg_types, named_args_types, with_literals: with_literals
else
lookup_matches_in scope, arg_types, named_args_types
lookup_matches_in scope, arg_types, named_args_types, with_literals: with_literals
end
end

# `lookup_matches_in` dispatches on the owner's type. Every overload forwards
# *with_literals* so the literal autocast rule survives alias, union, file and
# module expansion.

# Aliases: look up in the aliased type.
def lookup_matches_in(owner : AliasType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true)
lookup_matches_in(owner.remove_alias, arg_types, named_args_types, search_in_parents: search_in_parents)
def lookup_matches_in(owner : AliasType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false)
lookup_matches_in(owner.remove_alias, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals)
end

# Unions: concatenate the matches found in each member type.
def lookup_matches_in(owner : UnionType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true)
owner.union_types.flat_map { |type| lookup_matches_in(type, arg_types, named_args_types, search_in_parents: search_in_parents) }
def lookup_matches_in(owner : UnionType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false)
owner.union_types.flat_map { |type| lookup_matches_in(type, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals) }
end

# The program (top level): regular type lookup.
def lookup_matches_in(owner : Program, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true)
lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents)
def lookup_matches_in(owner : Program, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false)
lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents: search_in_parents, with_literals: with_literals)
end

# File modules: look up in the program.
def lookup_matches_in(owner : FileModule, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true)
lookup_matches_in program, arg_types, named_args_types, search_in_parents: search_in_parents
def lookup_matches_in(owner : FileModule, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false)
lookup_matches_in program, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals
end

# Modules and generic types: look up in the types that include them (no matches
# when there are none), observing subclass changes on the owner.
def lookup_matches_in(owner : NonGenericModuleType | GenericModuleInstanceType | GenericType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true)
def lookup_matches_in(owner : NonGenericModuleType | GenericModuleInstanceType | GenericType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false)
attach_subclass_observer owner

including_types = owner.including_types
if including_types
lookup_matches_in(including_types, arg_types, named_args_types, search_in_parents: search_in_parents)
lookup_matches_in(including_types, arg_types, named_args_types, search_in_parents: search_in_parents, with_literals: with_literals)
else
[] of Def
end
end

# Lib functions cannot be dispatched.
def lookup_matches_in(owner : LibType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true)
def lookup_matches_in(owner : LibType, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false)
raise "lib fun call is not supported in dispatch"
end

# Any other type: regular type lookup.
def lookup_matches_in(owner : Type, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true)
lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents)
def lookup_matches_in(owner : Type, arg_types, named_args_types, self_type = nil, def_name = self.name, search_in_parents = true, with_literals = false)
lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents: search_in_parents, with_literals: with_literals)
end

# Looks the call up in the `with` scope object first (also trying the virtual
# type when the owner is an abstract class). If nothing matches there, falls
# back to the regular scope. `@uses_with_scope` records which one was used.
def lookup_matches_with_scope_in(owner, arg_types, named_args_types)
def lookup_matches_with_scope_in(owner, arg_types, named_args_types, with_literals = false)
signature = CallSignature.new(name, arg_types, block, named_args_types)

matches = lookup_matches_checking_expansion(owner, signature)
matches = lookup_matches_checking_expansion(owner, signature, with_literals: with_literals)

if matches.empty? && owner.class? && owner.abstract?
matches = owner.virtual_type.lookup_matches(signature)
end

if matches.empty?
@uses_with_scope = false
return lookup_matches_in scope, arg_types, named_args_types
return lookup_matches_in scope, arg_types, named_args_types, with_literals: with_literals
end

@uses_with_scope = true
instantiate matches, owner, self_type: nil
end

def lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents, search_in_toplevel = true)
def lookup_matches_in_type(owner, arg_types, named_args_types, self_type, def_name, search_in_parents, search_in_toplevel = true, with_literals = false)
signature = CallSignature.new(def_name, arg_types, block, named_args_types)

matches = check_tuple_indexer(owner, def_name, args, arg_types)
@@ -254,7 +266,7 @@ class Crystal::Call
# compile errors, which will anyway appear once you add concrete
# subclasses and instances.
if def_name == "new" || !(!owner.metaclass? && owner.abstract? && (owner.leaf? || owner.is_a?(GenericClassInstanceType)))
raise_matches_not_found(matches.owner || owner, def_name, arg_types, named_args_types, matches)
raise_matches_not_found(matches.owner || owner, def_name, arg_types, named_args_types, matches, with_literals: with_literals)
end
end

@@ -271,7 +283,7 @@ class Crystal::Call
instantiate matches, owner, self_type
end

def lookup_matches_checking_expansion(owner, signature, search_in_parents = true)
def lookup_matches_checking_expansion(owner, signature, search_in_parents = true, with_literals = false)
# If this call is an expansion (because of default or named args) we must
# resolve the call in the type that defined the original method, without
# triggering a virtual lookup. But the context of lookup must be preseved.
@@ -460,6 +472,7 @@ class Crystal::Call
in_bounds = (0 <= index < instance_type.size)
if nilable || in_bounds
indexer_def = yield instance_type, (in_bounds ? index : -1)
arg_types.map!(&.remove_literal)
indexer_match = Match.new(indexer_def, arg_types, MatchContext.new(owner, owner))
return Matches.new([indexer_match] of Match, true)
elsif instance_type.size == 0
@@ -483,6 +496,7 @@ class Crystal::Call
index = instance_type.name_index(name)
if index || nilable
indexer_def = yield instance_type, (index || -1)
arg_types.map!(&.remove_literal)
indexer_match = Match.new(indexer_def, arg_types, MatchContext.new(owner, owner))
return Matches.new([indexer_match] of Match, true)
else
@@ -554,7 +568,7 @@ class Crystal::Call
end
end

def lookup_super_matches(arg_types, named_args_types)
def lookup_super_matches(arg_types, named_args_types, with_literals)
if scope.is_a?(Program)
raise "there's no superclass in this scope"
end
@@ -592,16 +606,16 @@ class Crystal::Call
if parents && parents.size > 0
parents.each_with_index do |parent, i|
if parent.lookup_first_def(enclosing_def.name, block)
return lookup_matches_in_type(parent, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false)
return lookup_matches_in_type(parent, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false, with_literals: with_literals)
end
end
lookup_matches_in_type(parents.last, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false)
lookup_matches_in_type(parents.last, arg_types, named_args_types, scope, enclosing_def.name, !in_initialize, search_in_toplevel: false, with_literals: with_literals)
else
raise "there's no superclass in this scope"
end
end

def lookup_previous_def_matches(arg_types, named_args_types)
def lookup_previous_def_matches(arg_types, named_args_types, with_literals)
enclosing_def = enclosing_def("previous_def")

previous_item = enclosing_def.previous
@@ -613,11 +627,12 @@ class Crystal::Call

signature = CallSignature.new(previous.name, arg_types, block, named_args_types)
context = MatchContext.new(scope, scope, def_free_vars: previous.free_vars)
arg_types.map!(&.remove_literal)
match = Match.new(previous, arg_types, context, named_args_types)
matches = Matches.new([match] of Match, true)

unless signature.match(previous_item, context)
raise_matches_not_found scope, previous.name, arg_types, named_args_types, matches
raise_matches_not_found scope, previous.name, arg_types, named_args_types, matches, with_literals: with_literals
end

unless scope.is_a?(Program)
9 changes: 8 additions & 1 deletion src/compiler/crystal/semantic/call_error.cr
Original file line number Diff line number Diff line change
@@ -37,7 +37,7 @@ class Crystal::Path
end

class Crystal::Call
def raise_matches_not_found(owner, def_name, arg_types, named_args_types, matches = nil)
def raise_matches_not_found(owner, def_name, arg_types, named_args_types, matches = nil, with_literals = false)
# Special case: Foo+:Class#new
if owner.is_a?(VirtualMetaclassType) && def_name == "new"
raise_matches_not_found_for_virtual_metaclass_new owner
@@ -212,6 +212,13 @@ class Crystal::Call
end
end

# If we made a lookup without the special rule for literals,
# and we have literals in the call, try again with that special rule.
if with_literals == false && (args.any? { |arg| arg.is_a?(NumberLiteral) || arg.is_a?(SymbolLiteral) } ||
named_args.try &.any? { |arg| arg.value.is_a?(NumberLiteral) || arg.value.is_a?(SymbolLiteral) })
::raise RetryLookupWithLiterals.new
end

if args.size == 1 && args.first.type.includes_type?(program.nil)
owner_trace = args.first.find_owner_trace(program, program.nil)
end
12 changes: 11 additions & 1 deletion src/compiler/crystal/semantic/class_vars_initializer_visitor.cr
Original file line number Diff line number Diff line change
@@ -68,7 +68,17 @@ module Crystal
end

main_visitor.pushing_type(owner.as(ModuleType)) do
node.accept main_visitor
# Check if we can autocast
if (node.is_a?(NumberLiteral) || node.is_a?(SymbolLiteral)) &&
(class_var_type = class_var.type?)
cloned_node = node.clone
cloned_node.accept MainVisitor.new(self)
if casted_value = MainVisitor.check_automatic_cast(cloned_node, class_var_type)
node = initializer.node = casted_value
end
end

node.accept main_visitor unless node.type?
end

unless had_class_var
4 changes: 4 additions & 0 deletions src/compiler/crystal/semantic/cleanup_transformer.cr
Original file line number Diff line number Diff line change
@@ -753,6 +753,10 @@ module Crystal
node
end

# Cleanup drops the AssignWithRestriction wrapper and keeps only its inner
# assignment, transformed like any other Assign.
def transform(node : AssignWithRestriction)
  transform node.assign
end

@false_literal : BoolLiteral?

def false_literal
8 changes: 8 additions & 0 deletions src/compiler/crystal/semantic/cover.cr
Original file line number Diff line number Diff line change
@@ -276,4 +276,12 @@ module Crystal
# An alias covers exactly what its aliased type covers.
class AliasType
delegate cover, cover_size, to: aliased_type
end

# A number literal type covers what the type it was matched against during
# restriction (`@matched_type`) covers, or else the literal's own type.
class NumberLiteralType
delegate cover, cover_size, to: (@matched_type || literal.type)
end

# Same rule as NumberLiteralType, for symbol literals (e.g. a matched enum).
class SymbolLiteralType
delegate cover, cover_size, to: (@matched_type || literal.type)
end
end
6 changes: 4 additions & 2 deletions src/compiler/crystal/semantic/default_arguments.cr
Original file line number Diff line number Diff line change
@@ -121,11 +121,13 @@ class Crystal::Def
if default_value.is_a?(MagicConstant)
expansion.args.push arg.clone
else
new_body << Assign.new(Var.new(arg.name).at(arg), default_value).at(arg)
assign = Assign.new(Var.new(arg.name).at(arg), default_value).at(arg)

if restriction = arg.restriction
new_body << TypeRestriction.new(Var.new(arg.name).at(arg), restriction).at(arg)
assign = AssignWithRestriction.new(assign, restriction)
end

new_body << assign
end
end
end
27 changes: 21 additions & 6 deletions src/compiler/crystal/semantic/instance_vars_initializer_visitor.cr
Original file line number Diff line number Diff line change
@@ -70,6 +70,8 @@ class Crystal::InstanceVarsInitializerVisitor < Crystal::SemanticVisitor
end

def finish
scope_initializers = [] of InstanceVarInitializerContainer::InstanceVarInitializer?

# First declare them, so when we type all of them we will have
# the info of which instance vars have initializers (so they are not nil)
initializers.each do |i|
@@ -78,18 +80,31 @@ class Crystal::InstanceVarsInitializerVisitor < Crystal::SemanticVisitor
program.undefined_instance_variable(i.target, scope, nil)
end

scope.add_instance_var_initializer(i.target.name, i.value, scope.is_a?(GenericType) ? nil : i.meta_vars)
scope_initializers <<
scope.add_instance_var_initializer(i.target.name, i.value, scope.is_a?(GenericType) ? nil : i.meta_vars)
end

# Now type them
initializers.each do |i|
initializers.each_with_index do |i, index|
scope = i.scope
value = i.value

unless scope.is_a?(GenericType)
ivar_visitor = MainVisitor.new(program, meta_vars: i.meta_vars)
ivar_visitor.scope = scope
i.value.accept ivar_visitor
next if scope.is_a?(GenericType)

# Check if we can autocast
if (value.is_a?(NumberLiteral) || value.is_a?(SymbolLiteral)) &&
(scope_initializer = scope_initializers[index])
cloned_value = value.clone
cloned_value.accept MainVisitor.new(program)
if casted_value = MainVisitor.check_automatic_cast(cloned_value, scope.lookup_instance_var(i.target.name).type)
scope_initializer.value = casted_value
next
end
end

ivar_visitor = MainVisitor.new(program, meta_vars: i.meta_vars)
ivar_visitor.scope = scope
value.accept ivar_visitor
end
end
end
89 changes: 68 additions & 21 deletions src/compiler/crystal/semantic/main_visitor.cr
Original file line number Diff line number Diff line change
@@ -744,19 +744,40 @@ module Crystal
false
end

def type_assign(target : Var, value, node)
def type_assign(target : Var, value, node, restriction = nil)
value.accept self

var_name = target.name
meta_var = (@meta_vars[var_name] ||= new_meta_var(var_name))

if freeze_type = meta_var.freeze_type
if casted_value = check_automatic_cast(value, freeze_type, node)
value = casted_value
end
end

# If this assign comes from a AssignWithRestriction node, check the restriction

if restriction && (value_type = value.type?)
if value_type.restrict(restriction, match_context.not_nil!)
# OK
else
# Check autocast too
restriction_type = scope.lookup_type(restriction, free_vars: free_vars)
if casted_value = check_automatic_cast(value, restriction_type, node)
value = casted_value
else
node.raise "can't restrict #{value.type} to #{restriction}"
end
end
end

target.bind_to value
node.bind_to value

var_name = target.name

value_type_filters = @type_filters
@type_filters = nil

meta_var = (@meta_vars[var_name] ||= new_meta_var(var_name))

# Save variable assignment location for debugging output
meta_var.location ||= target.location

@@ -820,6 +841,9 @@ module Crystal
value.accept self

var = lookup_instance_var target
if casted_value = check_automatic_cast(value, var.type, node)
value = casted_value
end

target.bind_to var
node.bind_to value
@@ -914,6 +938,10 @@ module Crystal
var = lookup_class_var(target)
check_class_var_is_thread_local(target, var, attributes)

if casted_value = check_automatic_cast(value, var.type, node)
value = casted_value
end

target.bind_to var

node.bind_to value
@@ -934,6 +962,34 @@ module Crystal
raise "BUG: unknown assign target in MainVisitor: #{target}"
end

# Tries to automatically cast *value* so that it fits *var_type*. Only
# literals qualify:
#
# * a `NumberLiteral` whose type differs from a numeric (integer or float)
#   *var_type* is retyped in place (type and kind) when `can_be_autocast_to?`
#   allows it;
# * a `SymbolLiteral` naming a member of an enum *var_type* is replaced by a
#   `Path` that targets that member.
#
# When a cast happens and *assign* is given, the assignment's value is set to
# the casted node. Returns the casted node, or `nil` if no cast applies.
def check_automatic_cast(value, var_type, assign = nil)
  MainVisitor.check_automatic_cast(value, var_type, assign)
end

# Class-level variant of `check_automatic_cast`, callable without a visitor
# instance.
def self.check_automatic_cast(value, var_type, assign = nil)
  casted = nil

  case value
  when NumberLiteral
    if value.type != var_type &&
       (var_type.is_a?(IntegerType) || var_type.is_a?(FloatType)) &&
       value.can_be_autocast_to?(var_type)
      value.type = var_type
      value.kind = var_type.kind
      casted = value
    end
  when SymbolLiteral
    if var_type.is_a?(EnumType) && (member = var_type.find_member(value.value))
      path = Path.new(member.name)
      path.target_const = member
      path.type = var_type
      casted = path
    end
  end

  assign.value = casted if casted && assign
  casted
end

def visit(node : Yield)
call = @call
unless call
@@ -2941,22 +2997,13 @@ module Crystal
false
end

# Types a TypeRestriction node: the object's type is restricted to *to* in the
# current match context, raising if there is no context or the restriction
# fails.
def visit(node : TypeRestriction)
obj = node.obj
to = node.to

obj.accept self

unless context = match_context
node.raise "BUG: there is no match context"
end

if type = obj.type.restrict(to, context)
node.type = type
else
node.raise "can't restrict #{obj.type} to #{to}"
end

# Types a default-argument assignment carrying a type restriction:
# `type_assign` checks the assigned value against the restriction (falling back
# to an automatic cast of literals) and this node takes the assignment's type.
# Returns false so the children are not visited again.
def visit(node : AssignWithRestriction)
type_assign(
node.assign.target.as(Var),
node.assign.value,
node.assign,
restriction: node.restriction)
node.bind_to(node.assign)
false
end

1 change: 1 addition & 0 deletions src/compiler/crystal/semantic/match.cr
Original file line number Diff line number Diff line change
@@ -60,6 +60,7 @@ module Crystal

# Records `type` as the value of the free variable `name`, lazily creating
# the free-variable table on first use.
#
# Types are stored with any literal wrapper removed (`remove_literal`).
# Non-`Type` values are stored unchanged. Returns the stored value.
def set_free_var(name, type)
  normalized = type.is_a?(Type) ? type.remove_literal : type
  (@free_vars ||= {} of String => TypeVar)[name] = normalized
end

6 changes: 4 additions & 2 deletions src/compiler/crystal/semantic/method_lookup.cr
Original file line number Diff line number Diff line change
@@ -2,8 +2,10 @@ require "../types"

module Crystal
# A (name, type) pair describing one named argument at a call site.
#
# NOTE(review): diff residue. The first `def self.from_args` and its one-line
# body appear to be the version removed by this commit; the second
# definition (with the `with_literals` flag) appears to be its replacement.
# As shown, this span is not valid Crystal.
record NamedArgumentType, name : String, type : Type do
def self.from_args(named_args : Array(NamedArgument)?)
named_args.try &.map { |named_arg| new(named_arg.name, named_arg.value.type) }
# Builds one NamedArgumentType per named argument, taking each argument
# value's type. `with_literals` is forwarded to `type(with_literals:)`.
# Returns nil when `named_args` is nil.
def self.from_args(named_args : Array(NamedArgument)?, with_literals = false)
named_args.try &.map do |named_arg|
new(named_arg.name, named_arg.value.type(with_literals: with_literals))
end
end
end

48 changes: 48 additions & 0 deletions src/compiler/crystal/semantic/restrictions.cr
Original file line number Diff line number Diff line change
@@ -1117,6 +1117,54 @@ module Crystal
true
end
end

# Restriction rules for number literal types: a number literal may match any
# integer or float type that its value can be automatically cast to.
class NumberLiteralType
# Restricts this literal type against `other`.
#
# * If `other` is an integer or float type and the literal can be autocast to
#   it, `other` is recorded as the matched type and returned. If a different
#   type was matched earlier, an "ambiguous call" error is raised on the
#   literal node.
# * If `other` is an integer or float type the literal cannot be cast to,
#   the literal's own type is restricted against `other` instead.
# * Otherwise the superclass `restrict` is tried first, then the literal's
#   own type. A result equal to this literal type itself is replaced by the
#   previously matched type, or by the literal's own type if none matched.
#
# NOTE(review): `@matched_type` lives on this instance and is never reset
# here, so the ambiguity check spans every restriction made with the same
# literal type (presumably one overload-resolution pass; confirm).
def restrict(other, context)
if other.is_a?(IntegerType) || other.is_a?(FloatType)
if literal.can_be_autocast_to?(other)
if @matched_type && @matched_type != other
literal.raise "ambiguous call matches both #{@matched_type} and #{other}"
end

@matched_type = other
other
else
literal.type.restrict(other, context)
end
else
type = super(other, context) ||
literal.type.restrict(other, context)
# Presumably so the literal wrapper type itself is never handed back to
# the caller: substitute the concrete type it stands for.
if type == self
type = @matched_type || literal.type
end
type
end
end
end

# Restriction rules for symbol literal types: a symbol literal may match an
# enum type that has a member with a matching name.
class SymbolLiteralType
# Restricts this literal type against `other`.
#
# * If `other` is an enum type with a member matching the symbol
#   (`find_member`), `other` is recorded as the matched type and returned.
#   If a different type was matched earlier, an "ambiguous call" error is
#   raised on the literal node.
# * If `other` is an enum type without such a member, the literal's own type
#   is restricted against `other` instead.
# * Otherwise the superclass `restrict` is tried first, then the literal's
#   own type. A result equal to this literal type itself is replaced by the
#   previously matched type, or by the literal's own type if none matched.
#
# NOTE(review): `@matched_type` lives on this instance and is never reset
# here, so the ambiguity check spans every restriction made with the same
# literal type (presumably one overload-resolution pass; confirm).
def restrict(other, context)
if other.is_a?(EnumType)
if other.find_member(literal.value)
if @matched_type && @matched_type != other
literal.raise "ambiguous call matches both #{@matched_type} and #{other}"
end

@matched_type = other
other
else
literal.type.restrict(other, context)
end
else
type = super(other, context) ||
literal.type.restrict(other, context)
# Presumably so the literal wrapper type itself is never handed back to
# the caller: substitute the concrete type it stands for.
if type == self
type = @matched_type || literal.type
end
type
end
end
end
end

private def get_generic_type(node, context)
8 changes: 5 additions & 3 deletions src/compiler/crystal/semantic/to_s.cr
Original file line number Diff line number Diff line change
@@ -48,11 +48,13 @@ module Crystal
false
end

# NOTE(review): diff residue. In each pair of near-identical lines below, the
# first (`visit(node : TypeRestriction)`, `node.obj`, `node.to`) appears to be
# the version removed by this commit and the second its replacement. As
# shown, this span is not valid Crystal.
#
# Prints an AssignWithRestriction as
# `# type restriction: <target> : <restriction> = <value>` and returns false.
def visit(node : TypeRestriction)
def visit(node : AssignWithRestriction)
@str << "# type restriction: "
node.obj.accept self
node.assign.target.accept self
@str << " : "
node.to.accept self
node.restriction.accept self
@str << " = "
node.assign.value.accept self
false
end

2 changes: 1 addition & 1 deletion src/compiler/crystal/semantic/transformer.cr
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ require "../syntax/transformer"

module Crystal
class Transformer
def transform(node : MetaVar | MetaMacroVar | Primitive | TypeFilteredNode | TupleIndexer | TypeNode | TypeRestriction | YieldBlockBinder | MacroId)
def transform(node : MetaVar | MetaMacroVar | Primitive | TypeFilteredNode | TupleIndexer | TypeNode | AssignWithRestriction | YieldBlockBinder | MacroId)
node
end

15 changes: 15 additions & 0 deletions src/compiler/crystal/syntax/ast.cr
Original file line number Diff line number Diff line change
@@ -222,6 +222,21 @@ module Crystal
@value[0] == '+' || @value[0] == '-'
end

# Converts this literal's textual `value` into a Crystal integer whose width
# and signedness are given by `kind` (`:i8`, `:i16`, `:i32`, `:i64`, `:u8`,
# `:u16`, `:u32` or `:u64`).
#
# Raises (via the node's `raise`) for any other kind, such as float literals.
def integer_value
  case kind
  when :i8
    value.to_i8
  when :i16
    value.to_i16
  when :i32
    value.to_i32
  when :i64
    value.to_i64
  when :u8
    value.to_u8
  when :u16
    value.to_u16
  when :u32
    value.to_u32
  when :u64
    value.to_u64
  else
    raise "Bug: called 'integer_value' for non-integer literal"
  end
end

# Returns a new NumberLiteral with the same value and kind as this one;
# only those two fields are carried over, so no location is copied.
def clone_without_location
  NumberLiteral.new(value, kind)
end
85 changes: 85 additions & 0 deletions src/compiler/crystal/types.cr
Original file line number Diff line number Diff line change
@@ -554,6 +554,10 @@ module Crystal
self
end

# Returns this type with any literal wrapper stripped. A plain type is not a
# literal wrapper, so the default implementation returns `self`; literal
# types override this to return the literal's concrete type.
def remove_literal
self
end

# Default generic nesting level: 0.
# NOTE(review): the exact meaning of "generic nest" is not shown in this
# chunk (presumably the depth of generic-type nesting); confirm.
def generic_nest
0
end
@@ -1176,6 +1180,29 @@ module Crystal
# Derives a normalized rank from the internal `@rank` as `(@rank - 1) / 2`.
# NOTE(review): this relies on `/` being integer division, as it was when
# this code was written. Newer Crystal versions make `Int#/` return a float
# (integer division is `//`); confirm before upgrading.
def normal_rank
(@rank - 1) / 2
end

# Returns the inclusive bounds of this integer type as a `{MIN, MAX}` tuple
# of the matching Crystal integer type, e.g. `{Int8::MIN, Int8::MAX}` for
# `:i8`.
#
# Raises if `kind` is not one of the eight integer kinds.
def range
  case kind
  when :i8  then {Int8::MIN, Int8::MAX}
  when :i16 then {Int16::MIN, Int16::MAX}
  when :i32 then {Int32::MIN, Int32::MAX}
  when :i64 then {Int64::MIN, Int64::MAX}
  when :u8  then {UInt8::MIN, UInt8::MAX}
  when :u16 then {UInt16::MIN, UInt16::MAX}
  when :u32 then {UInt32::MIN, UInt32::MAX}
  when :u64 then {UInt64::MIN, UInt64::MAX}
  else
    raise "Bug: called 'range' for non-integer literal"
  end
end
end

class FloatType < PrimitiveType
@@ -1206,6 +1233,45 @@ module Crystal
class VoidType < NamedType
end

# Literal-aware type for a number literal.
#
# It prints as, and strips down to, the literal's own concrete type, but
# keeps a reference to the `NumberLiteral` node. That lets restriction also
# accept other integer and float types when the literal's value fits in them.
class NumberLiteralType < Type
  # Type matched during restriction; used to detect ambiguous matches.
  @matched_type : Type?

  getter literal : NumberLiteral

  def initialize(program, @literal)
    super(program)
  end

  # The plain (non-literal) type of the wrapped number literal.
  def remove_literal
    @literal.type
  end

  # Prints the wrapped literal's own type; formatting options are ignored.
  def to_s_with_options(io : IO, skip_union_parens : Bool = false, generic_args : Bool = true, codegen = false)
    io << literal.type
  end
end

# Literal-aware type for a symbol literal.
#
# It prints as, and strips down to, the literal's own type (the symbol type),
# but keeps a reference to the `SymbolLiteral` node. That lets restriction
# also accept an enum type that has a member named like the symbol.
class SymbolLiteralType < Type
  # Type matched during restriction; used to detect ambiguous matches.
  @matched_type : Type?

  getter literal : SymbolLiteral

  def initialize(program, @literal)
    super(program)
  end

  # The plain (non-literal) type of the wrapped symbol literal.
  def remove_literal
    @literal.type
  end

  # Prints the wrapped literal's own type; formatting options are ignored.
  def to_s_with_options(io : IO, skip_union_parens : Bool = false, generic_args : Bool = true, codegen = false)
    io << literal.type
  end
end

# Any thing that can be passed as a generic type variable.
#
# For example, in:
@@ -1360,6 +1426,15 @@ module Crystal
value = initializer.value.clone
value.accept visitor
instance_var = instance.lookup_instance_var(initializer.name)

# Check if automatic cast can be done
if instance_var.type != value.type &&
(value.is_a?(NumberLiteral) || value.is_a?(SymbolLiteral))
if casted_value = MainVisitor.check_automatic_cast(value, instance_var.type)
value = casted_value
end
end

instance_var.bind_to(value)
instance.add_instance_var_initializer(initializer.name, value, meta_vars)
end
@@ -2451,6 +2526,16 @@ module Crystal
true
end

# Finds the enum member whose name equals `name` after both are passed
# through `String#underscore`. For example, `"three"` finds a member
# declared as `Three`.
#
# Returns the first matching member as a `Const`, or nil if none matches.
def find_member(name)
  wanted = name.underscore
  if entry = types.find { |member_name, _| member_name.underscore == wanted }
    entry[1].as(Const)
  end
end

def type_desc
"enum"
end
3 changes: 3 additions & 0 deletions src/llvm/lib_llvm.cr
Original file line number Diff line number Diff line change
@@ -352,4 +352,7 @@ lib LibLLVM
fun create_builder_in_context = LLVMCreateBuilderInContext(c : ContextRef) : BuilderRef

fun get_type_context = LLVMGetTypeContext(TypeRef) : ContextRef

# Bindings to LLVMConstIntGetSExtValue / LLVMConstIntGetZExtValue. They read
# a constant integer's value as an Int64 (sign-extended) or as a UInt64
# (zero-extended).
fun const_int_get_sext_value = LLVMConstIntGetSExtValue(ValueRef) : Int64
fun const_int_get_zext_value = LLVMConstIntGetZExtValue(ValueRef) : UInt64
end
8 changes: 8 additions & 0 deletions src/llvm/value_methods.cr
Original file line number Diff line number Diff line change
@@ -87,6 +87,14 @@ module LLVM::ValueMethods
LibLLVM.set_alignment(self, bytes)
end

# Returns this value as a sign-extended Int64 by delegating to
# `LibLLVM.const_int_get_sext_value`.
# NOTE(review): assumes `self` is a constant integer value; confirm at call sites.
def const_int_get_sext_value
LibLLVM.const_int_get_sext_value(self)
end

# Returns this value as a zero-extended UInt64 by delegating to
# `LibLLVM.const_int_get_zext_value`.
# NOTE(review): assumes `self` is a constant integer value; confirm at call sites.
def const_int_get_zext_value
LibLLVM.const_int_get_zext_value(self)
end

# Wraps this value's raw handle (`unwrap`) in a new `LLVM::Value`.
def to_value
LLVM::Value.new unwrap
end

0 comments on commit 5d3e16d

Please sign in to comment.