Skip to content

Commit

Permalink
Showing 10 changed files with 106 additions and 20 deletions.
11 changes: 11 additions & 0 deletions spec/compiler/codegen/def_spec.cr
Original file line number Diff line number Diff line change
@@ -519,4 +519,15 @@ describe "Code gen: def" do
foo
)).to_i.should eq(123 * 2 + 456)
end

# Codegen check for #1203 (per the example name): the def binds N from the
# StaticArray's size argument and returns it; running the program must yield 10.
it "can match N type argument of static array (#1203)" do
  program = %(
def fn(a : StaticArray(T, N))
N
end
n = uninitialized StaticArray(Int32, 10)
fn(n)
)
  result = run(program)
  result.to_i.should eq(10)
end
end
10 changes: 10 additions & 0 deletions spec/compiler/type_inference/double_splat_spec.cr
Original file line number Diff line number Diff line change
@@ -171,6 +171,16 @@ describe "Type inference: double splat" do
)) { named_tuple_of({"x" => int32, "y" => char}).metaclass }
end

# Calling a def whose double splat is restricted with `**T` and passing no
# named arguments must still match; T's type is asserted to be the metaclass
# of an empty named tuple.
it "uses double splat restriction, matches empty" do
  program = %(
def foo(**options : **T)
T
end
foo
)
  # The block stays inline: the spec helper evaluates it in the program's context.
  assert_type(program) { named_tuple_of({} of String => Type).metaclass }
end

it "uses double splat restriction with concrete type" do
assert_error %(
struct NamedTuple(T)
10 changes: 10 additions & 0 deletions spec/compiler/type_inference/splat_spec.cr
Original file line number Diff line number Diff line change
@@ -395,6 +395,16 @@ describe "Type inference: splat" do
)) { tuple_of([int32, char, bool]).metaclass }
end

# Calling a def whose splat is restricted with `*T` and passing no positional
# arguments must still match; T's type is asserted to be the metaclass of an
# empty tuple.
it "uses splat restriction, matches empty" do
  program = %(
def foo(*args : *T)
T
end
foo
)
  # The block stays inline: the spec helper evaluates it in the program's context.
  assert_type(program) { tuple_of([] of Type).metaclass }
end

it "uses splat restriction with concrete type" do
assert_error %(
struct Tuple(T)
34 changes: 34 additions & 0 deletions spec/compiler/type_inference/static_array_spec.cr
Original file line number Diff line number Diff line change
@@ -94,4 +94,38 @@ describe "Type inference: static array" do
),
"can't instantiate StaticArray(T, N) with N = Int32 (N must be an integer)"
end

# Type-inference check for #1203 (per the example name): N bound from the
# StaticArray size argument can be returned, and its type is Int32.
it "can match N type argument of static array (#1203)" do
  program = %(
def fn(a : StaticArray(T, N))
N
end
n = uninitialized StaticArray(Int32, 10)
fn(n)
)
  assert_type(program) { int32 }
end

# A literal size (10) in the restriction must match a StaticArray whose size
# is also 10; the def body returns 10, so the call's type is Int32.
it "can match number type argument of static array (#1203)" do
  program = %(
def fn(a : StaticArray(T, 10))
10
end
n = uninitialized StaticArray(Int32, 10)
fn(n)
)
  assert_type(program) { int32 }
end

# A literal size in the restriction (11) that differs from the argument's
# StaticArray size (10) must not match, so compilation reports no overload.
it "doesn't match other number type argument of static array (#1203)" do
  program = %(
def fn(a : StaticArray(T, 11))
10
end
n = uninitialized StaticArray(Int32, 10)
fn(n)
)
  expected_message = "no overload matches"
  assert_error program, expected_message
end
end
2 changes: 1 addition & 1 deletion src/compiler/crystal/macros/macros.cr
Original file line number Diff line number Diff line change
@@ -103,7 +103,7 @@ module Crystal
class MacroVisitor < Visitor
getter last : ASTNode
getter yields : Hash(String, ASTNode)?
property free_vars : Hash(String, Type)?
property free_vars : Hash(String, TypeVar)?

def self.new(expander, mod, scope : Type, type_lookup : Type, a_macro : Macro, call)
vars = {} of String => ASTNode
14 changes: 9 additions & 5 deletions src/compiler/crystal/semantic/base_type_visitor.cr
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ module Crystal
getter mod : Program
property types : Array(Type)

@free_vars : Hash(String, Type)?
@free_vars : Hash(String, TypeVar)?
@type_lookup : Type?
@scope : Type?
@typed_def : Def?
@@ -395,10 +395,14 @@ module Crystal

def resolve_ident?(node : Path, create_modules_if_missing = false)
free_vars = @free_vars
if free_vars && !node.global && (type = free_vars[node.names.first]?)
target_type = type
if node.names.size > 1
target_type = lookup_type target_type, node.names[1..-1], node
if free_vars && !node.global && (type_var = free_vars[node.names.first]?)
if type_var.is_a?(Type)
target_type = type_var
if node.names.size > 1
target_type = lookup_type target_type, node.names[1..-1], node
end
else
target_type = type_var
end
else
base_lookup = node.global ? mod : (@type_lookup || @scope || @types.last)
2 changes: 1 addition & 1 deletion src/compiler/crystal/semantic/call.cr
Original file line number Diff line number Diff line change
@@ -916,7 +916,7 @@ class Crystal::Call

def visit(node : Path)
if node.names.size == 1 && @context.free_vars
if type = @context.get_free_var(node.names.first)
if (type = @context.get_free_var(node.names.first)).is_a?(Type)
@type = type
return
end
4 changes: 2 additions & 2 deletions src/compiler/crystal/semantic/match.cr
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ module Crystal
class MatchContext
property owner : Type
property type_lookup : Type
getter free_vars : Hash(String, Type)?
getter free_vars : Hash(String, TypeVar)?
getter? strict : Bool

def initialize(@owner, @type_lookup, @free_vars = nil, @strict = false)
@@ -13,7 +13,7 @@ module Crystal
end

def set_free_var(name, type)
free_vars = @free_vars ||= {} of String => Type
free_vars = @free_vars ||= {} of String => TypeVar
free_vars[name] = type
end

25 changes: 15 additions & 10 deletions src/compiler/crystal/semantic/method_lookup.cr
Original file line number Diff line number Diff line change
@@ -153,19 +153,21 @@ module Crystal

matched_arg_types = nil

# If there's a restriction on a splat, zero splatted args don't match
if splat_index &&
a_def.args[splat_index].restriction &&
# If there's a restriction on a splat (that's not a splat restriction),
# zero splatted args don't match
if splat_index && splat_restriction &&
!splat_restriction.is_a?(Splat) &&
Splat.size(a_def, arg_types) == 0
return nil
end

splat_arg_types = nil
if splat_restriction.is_a?(Splat)
splat_arg_types = [] of Type
end

a_def.match(arg_types) do |arg, arg_index, arg_type, arg_type_index|
# Don't match argument against splat restriction
if arg_index == splat_index && splat_restriction.is_a?(Splat)
splat_arg_types ||= [] of Type
if arg_index == splat_index && splat_arg_types
splat_arg_types << arg_type
next
end
@@ -190,7 +192,10 @@ module Crystal
end

found_unmatched_named_arg = false
double_splat_entries = nil

if double_splat_restriction.is_a?(DoubleSplat)
double_splat_entries = [] of NamedArgumentType
end

# Check named args
if named_args
@@ -223,8 +228,7 @@ module Crystal
if a_def.double_splat
# If there's a restriction on the double splat, check that it matches
if double_splat_restriction
if double_splat_restriction.is_a?(DoubleSplat)
double_splat_entries ||= [] of NamedArgumentType
if double_splat_entries
double_splat_entries << named_arg
else
unless match_arg(named_arg.type, double_splat_restriction, context)
@@ -262,7 +266,8 @@ module Crystal
end

# If there's a restriction on a double splat (that's not a double splat restriction), zero matching named arguments don't match
if double_splat && double_splat.restriction && !found_unmatched_named_arg
if double_splat && double_splat_restriction &&
!double_splat_restriction.is_a?(DoubleSplat) && !found_unmatched_named_arg
return nil
end

14 changes: 13 additions & 1 deletion src/compiler/crystal/semantic/restrictions.cr
Original file line number Diff line number Diff line change
@@ -441,7 +441,19 @@ module Crystal
end

def restrict_type_var(type_var, other_type_var, context)
unless type_var.is_a?(NumberLiteral)
if type_var.is_a?(NumberLiteral)
case other_type_var
when NumberLiteral
if type_var == other_type_var
return type_var
end
when Path
if other_type_var.names.size == 1
context.set_free_var(other_type_var.names.first, type_var)
return type_var
end
end
else
type_var = type_var.type? || type_var
end

0 comments on commit 6b926e6

Please sign in to comment.