Skip to content

Commit

Permalink
Showing 14 changed files with 246 additions and 49 deletions.
9 changes: 9 additions & 0 deletions spec/compiler/parser/parser_spec.cr
Original file line number Diff line number Diff line change
@@ -200,6 +200,8 @@ describe "Parser" do
it_parses "def foo(@@var = 1); 1; end", Def.new("foo", [Arg.new("var", 1.int32)], [Assign.new("@@var".class_var, "var".var), 1.int32] of ASTNode)
it_parses "def foo(&@block); end", Def.new("foo", body: Assign.new("@block".instance_var, "block".var), block_arg: Arg.new("block"), yields: 0)

it_parses "def foo(a, &block : *Int -> ); end", Def.new("foo", [Arg.new("a")], block_arg: Arg.new("block", restriction: ProcNotation.new(["Int".path.splat] of ASTNode)), yields: 1)

it_parses "def foo(x, *args, y = 2); 1; end", Def.new("foo", args: ["x".arg, "args".arg, Arg.new("y", default_value: 2.int32)], body: 1.int32, splat_index: 1)
it_parses "def foo(x, *args, y = 2, w, z = 3); 1; end", Def.new("foo", args: ["x".arg, "args".arg, Arg.new("y", default_value: 2.int32), "w".arg, Arg.new("z", default_value: 3.int32)], body: 1.int32, splat_index: 1)
it_parses "def foo(x, *, y); 1; end", Def.new("foo", args: ["x".arg, "".arg, "y".arg], body: 1.int32, splat_index: 1)
@@ -400,6 +402,11 @@ describe "Parser" do
it_parses "abstract class Foo; end", ClassDef.new("Foo".path, abstract: true)
it_parses "abstract struct Foo; end", ClassDef.new("Foo".path, abstract: true, struct: true)

it_parses "module Foo(*T); end", ModuleDef.new("Foo".path, type_vars: ["T"], variadic: true)
it_parses "class Foo(*T); end", ClassDef.new("Foo".path, type_vars: ["T"], variadic: true)

assert_syntax_error "class Foo(*T, U)", "only one type variable is valid for variadic generic types"

it_parses "struct Foo; end", ClassDef.new("Foo".path, struct: true)

it_parses "Foo(T)", Generic.new("Foo".path, ["T".path] of ASTNode)
@@ -427,6 +434,8 @@ describe "Parser" do
it_parses "Foo({x: X, y: Y})", Generic.new("Foo".path, [Generic.new(Path.global("NamedTuple"), [] of ASTNode, named_args: [NamedArgument.new("x", "X".path), NamedArgument.new("y", "Y".path)])] of ASTNode)
assert_syntax_error "Foo({x: X, x: Y})", "duplicated key: x"

it_parses "Foo(*T)", Generic.new("Foo".path, ["T".path.splat] of ASTNode)

it_parses "module Foo; end", ModuleDef.new("Foo".path)
it_parses "module Foo\ndef foo; end; end", ModuleDef.new("Foo".path, [Def.new("foo")] of ASTNode)
it_parses "module Foo(T); end", ModuleDef.new("Foo".path, type_vars: ["T"])
36 changes: 36 additions & 0 deletions spec/compiler/type_inference/generic_class_spec.cr
Original file line number Diff line number Diff line change
@@ -736,4 +736,40 @@ describe "Type inference: generic class" do
),
"instance variable '@value' of Bar must be PluginContainer(Plugin), not PluginContainer(Foo)"
end

it "instantiates generic variadic class, accesses T from class method" do
assert_type(%(
class Foo(*T)
def self.t
T
end
end
Foo(Int32, Char).t
)) { tuple_of([int32, char]).metaclass }
end

it "instantiates generic variadic class, accesses T from instance method" do
assert_type(%(
class Foo(*T)
def t
T
end
end
Foo(Int32, Char).new.t
)) { tuple_of([int32, char]).metaclass }
end

it "splats generic type var" do
assert_type(%(
class Foo(X, Y)
def self.vars
{X, Y}
end
end
Foo(*{Int32, Char}).vars
)) { tuple_of([int32.metaclass, char.metaclass]) }
end
end
32 changes: 32 additions & 0 deletions spec/compiler/type_inference/module_spec.cr
Original file line number Diff line number Diff line change
@@ -899,4 +899,36 @@ describe "Type inference: module" do
Baz.new.x
)) { nilable int32 }
end

it "declares and includes generic module" do
assert_type(%(
module Moo(*T)
def t
T
end
end
class Foo
include Moo(Int32, Char)
end
Foo.new.t
)) { tuple_of([int32, char]).metaclass }
end

it "includes module with Union(T*)" do
assert_type(%(
module Foo(U)
def u
U
end
end
struct Tuple
include Foo(Union(*T))
end
{1, 'a'}.u
)) { union_of(int32, char).metaclass }
end
end
13 changes: 13 additions & 0 deletions spec/compiler/type_inference/proc_spec.cr
Original file line number Diff line number Diff line change
@@ -819,4 +819,17 @@ describe "Type inference: proc" do
->(x : Int32) { 'a' }.t
)) { tuple_of([int32, char]).metaclass }
end

it "can match *T in block argument" do
assert_type(%(
struct Tuple
def foo(&block : *T -> U)
yield self[0], self[1]
U
end
end
{1, 'a'}.foo { |x, y| true }
)) { bool.metaclass }
end
end
11 changes: 9 additions & 2 deletions src/compiler/crystal/semantic/ast.cr
Original file line number Diff line number Diff line change
@@ -663,14 +663,21 @@ module Crystal

generic_type = instance_type.instantiate_named_args(entries)
else
type_vars_types = type_vars.map do |node|
type_vars_types = Array(TypeVar).new(type_vars.size + 1)
type_vars.each do |node|
if node.is_a?(Path) && (syntax_replacement = node.syntax_replacement)
node = syntax_replacement
end

case node
when NumberLiteral
type_var = node
when Splat
type = node.type?
return unless type.is_a?(TupleInstanceType)

type_vars_types.concat(type.tuple_types)
next
else
node_type = node.type?
return unless node_type
@@ -698,7 +705,7 @@ module Crystal
end
end

type_var.as(TypeVar)
type_vars_types << type_var
end

begin
20 changes: 18 additions & 2 deletions src/compiler/crystal/semantic/base_type_visitor.cr
Original file line number Diff line number Diff line change
@@ -188,8 +188,24 @@ module Crystal
node.raise "can only use named arguments with NamedTuple"
end

if instance_type.type_vars.size != node.type_vars.size
node.wrong_number_of "type vars", instance_type, node.type_vars.size, instance_type.type_vars.size
# Need to count type vars because there might be splats
type_vars_count = 0
knows_count = true
node.type_vars.each do |type_var|
if type_var.is_a?(Splat)
if type_var.type?
type_vars_count += type_var.type.as(TupleInstanceType).size
else
knows_count = false
break
end
else
type_vars_count += 1
end
end

if knows_count && instance_type.type_vars.size != type_vars_count
node.wrong_number_of "type vars", instance_type, type_vars_count, instance_type.type_vars.size
end
end

23 changes: 19 additions & 4 deletions src/compiler/crystal/semantic/call.cr
Original file line number Diff line number Diff line change
@@ -705,11 +705,26 @@ class Crystal::Call
if block_arg_restriction.is_a?(ProcNotation)
# If there are input types, solve them and creating the yield vars
if inputs = block_arg_restriction.inputs
yield_vars = inputs.map_with_index do |input, i|
arg_type = ident_lookup.lookup_node_type(input)
MainVisitor.check_type_allowed_as_proc_argument(input, arg_type)
yield_vars = Array(Var).new(inputs.size + 1)
i = 0
inputs.each do |input|
if input.is_a?(Splat)
tuple_type = ident_lookup.lookup_node_type(input.exp)
unless tuple_type.is_a?(TupleInstanceType)
input.raise "expected type to be a tuple type, not #{tuple_type}"
end
tuple_type.tuple_types.each do |arg_type|
MainVisitor.check_type_allowed_as_proc_argument(input, arg_type)
yield_vars << Var.new("var#{i}", arg_type.virtual_type)
i += 1
end
else
arg_type = ident_lookup.lookup_node_type(input)
MainVisitor.check_type_allowed_as_proc_argument(input, arg_type)

Var.new("var#{i}", arg_type.virtual_type)
yield_vars << Var.new("var#{i}", arg_type.virtual_type)
i += 1
end
end
end
output = block_arg_restriction.output
6 changes: 5 additions & 1 deletion src/compiler/crystal/semantic/new.cr
Original file line number Diff line number Diff line change
@@ -131,7 +131,11 @@ module Crystal

def fill_body_from_initialize(instance_type)
if instance_type.is_a?(GenericClassType)
generic_type_args = instance_type.type_vars.map { |type_var| Path.new(type_var).as(ASTNode) }
generic_type_args = instance_type.type_vars.map do |type_var|
arg = Path.new(type_var).as(ASTNode)
arg = Splat.new(arg) if instance_type.variadic
arg
end
new_generic = Generic.new(Path.new(instance_type.name), generic_type_args)
alloc = Call.new(new_generic, "allocate")
else
11 changes: 9 additions & 2 deletions src/compiler/crystal/semantic/top_level_visitor.cr
Original file line number Diff line number Diff line change
@@ -202,6 +202,7 @@ module Crystal
created_new_type = true
if type_vars = node.type_vars
type = GenericClassType.new @mod, scope, name, superclass, type_vars, false
type.variadic = node.variadic?
else
type = NonGenericClassType.new @mod, scope, name, superclass, false
end
@@ -253,6 +254,7 @@ module Crystal
else
if type_vars = node.type_vars
type = GenericModuleType.new @mod, scope, name, type_vars
type.variadic = true if node.variadic?
else
type = NonGenericModuleType.new @mod, scope, name
end
@@ -937,11 +939,16 @@ module Crystal
node_name.raise "#{type} is not a generic module"
end

if type.type_vars.size != node_name.type_vars.size
if !type.variadic && type.type_vars.size != node_name.type_vars.size
node_name.wrong_number_of "type vars", type, node_name.type_vars.size, type.type_vars.size
end

mapping = Hash.zip(type.type_vars, node_name.type_vars)
node_name_type_vars = node_name.type_vars
if type.variadic
node_name_type_vars = [TupleLiteral.new(node_name_type_vars)] of ASTNode
end

mapping = Hash.zip(type.type_vars, node_name_type_vars)
module_to_include = IncludedGenericModule.new(@mod, type, current_type, mapping)

type.add_inherited(current_type)
36 changes: 30 additions & 6 deletions src/compiler/crystal/semantic/type_lookup.cr
Original file line number Diff line number Diff line change
@@ -120,15 +120,31 @@ module Crystal
end
end

type_vars = node.type_vars.map do |type_var|
if type_var.is_a?(NumberLiteral)
type_var
type_vars = Array(TypeVar).new(node.type_vars.size + 1)
node.type_vars.each do |type_var|
case type_var
when NumberLiteral
type_vars << type_var
when Splat
@type = nil
type_var.exp.accept self
return false if !@raise && !@type

splat_type = type
if splat_type.is_a?(TupleInstanceType)
type_vars.concat splat_type.tuple_types
else
return false if !@raise

type_var.raise "can only splat tuple type, not #{splat_type}"
end
else
# Check the case of T resolving to a number
if type_var.is_a?(Path) && type_var.names.size == 1
the_type = @root.lookup_type(type_var)
if the_type.is_a?(ASTNode)
next the_type.as(TypeVar)
type_vars << the_type
next
end
end

@@ -138,8 +154,8 @@ module Crystal

Crystal.check_type_allowed_in_generics(type_var, type, "can't use #{type} as a generic type argument")

type.virtual_type
end.as(TypeVar)
type_vars << type.virtual_type
end
end

begin
@@ -313,6 +329,14 @@ module Crystal
class IncludedGenericModule
def lookup_type(names : Array, already_looked_up = ObjectIdSet.new, lookup_in_container = true)
if (names.size == 1) && (m = @mapping[names[0]]?)
# Case of a variadic tuple
if m.is_a?(TupleLiteral)
types = m.elements.map do |element|
TypeLookup.lookup(@including_class, element).as(Type)
end
return program.tuple_of(types)
end

case @including_class
when GenericClassType, GenericModuleType
# skip
14 changes: 8 additions & 6 deletions src/compiler/crystal/syntax/ast.cr
Original file line number Diff line number Diff line change
@@ -1324,8 +1324,9 @@ module Crystal
property name_column_number : Int32
property attributes : Array(Attribute)?
property doc : String?
property? variadic : Bool

def initialize(@name, body = nil, @superclass = nil, @type_vars = nil, @abstract = false, @struct = false, @name_column_number = 0)
def initialize(@name, body = nil, @superclass = nil, @type_vars = nil, @abstract = false, @struct = false, @name_column_number = 0, @variadic = false)
@body = Expressions.from body
end

@@ -1335,10 +1336,10 @@ module Crystal
end

def clone_without_location
ClassDef.new(@name, @body.clone, @superclass.clone, @type_vars.clone, @abstract, @struct, @name_column_number)
ClassDef.new(@name, @body.clone, @superclass.clone, @type_vars.clone, @abstract, @struct, @name_column_number, @variadic)
end

def_equals_and_hash @name, @body, @superclass, @type_vars, @abstract, @struct
def_equals_and_hash @name, @body, @superclass, @type_vars, @abstract, @struct, @variadic
end

# Module definition:
@@ -1351,10 +1352,11 @@ module Crystal
property name : Path
property body : ASTNode
property type_vars : Array(String)?
property? variadic : Bool
property name_column_number : Int32
property doc : String?

def initialize(@name, body = nil, @type_vars = nil, @name_column_number = 0)
def initialize(@name, body = nil, @type_vars = nil, @name_column_number = 0, @variadic = false)
@body = Expressions.from body
end

@@ -1363,10 +1365,10 @@ module Crystal
end

def clone_without_location
ModuleDef.new(@name, @body.clone, @type_vars.clone, @name_column_number)
ModuleDef.new(@name, @body.clone, @type_vars.clone, @name_column_number, @variadic)
end

def_equals_and_hash @name, @body, @type_vars
def_equals_and_hash @name, @body, @type_vars, @variadic
end

# While expression.
67 changes: 44 additions & 23 deletions src/compiler/crystal/syntax/parser.cr
Original file line number Diff line number Diff line change
@@ -1491,7 +1491,7 @@ module Crystal
name = parse_ident allow_type_vars: false
skip_space

type_vars = parse_type_vars
type_vars, variadic = parse_type_vars

superclass = nil

@@ -1511,27 +1511,40 @@ module Crystal

@type_nest -= 1

class_def = ClassDef.new name, body, superclass, type_vars, is_abstract, is_struct, name_column_number
class_def = ClassDef.new name, body, superclass, type_vars, is_abstract, is_struct, name_column_number, variadic: variadic
class_def.doc = doc
class_def.end_location = end_location
class_def
end

def parse_type_vars
type_vars = nil
variadic = false
if @token.type == :"("
type_vars = [] of String

next_token_skip_space_or_newline

if @token.type == :"*"
variadic = true
next_token
end

while @token.type != :")"
type_var_name = check_const

if variadic && !type_vars.empty?
raise "only one type variable is valid for variadic generic types"
end

unless Parser.free_var_name?(type_var_name)
raise "type variables can only be single letters optionally followed by a digit", @token
end

if type_vars.includes? type_var_name
raise "duplicated type var name: #{type_var_name}", @token
end

type_vars.push type_var_name

next_token_skip_space_or_newline
@@ -1546,7 +1559,7 @@ module Crystal

next_token_skip_space
end
type_vars
{type_vars, variadic}
end

def parse_module_def
@@ -1561,7 +1574,7 @@ module Crystal
name = parse_ident allow_type_vars: false
skip_space

type_vars = parse_type_vars
type_vars, variadic = parse_type_vars
skip_statement_end

body = parse_expressions
@@ -1574,7 +1587,7 @@ module Crystal

@type_nest -= 1

module_def = ModuleDef.new name, body, type_vars, name_column_number
module_def = ModuleDef.new name, body, type_vars, name_column_number, variadic: variadic
module_def.doc = doc
module_def.end_location = end_location
module_def
@@ -3244,7 +3257,7 @@ module Crystal

location = @token.location

type_spec = parse_single_type
type_spec = parse_single_type(allow_splat: true)
end

block_arg = Arg.new(arg_name, restriction: type_spec).at(name_location)
@@ -3999,7 +4012,7 @@ module Crystal
types = [] of ASTNode
named_args = parse_type_named_args(:")")
else
types = parse_types allow_primitives: true
types = parse_types allow_primitives: true, allow_splat: true
if types.empty?
raise "must specify at least one type var"
end
@@ -4058,8 +4071,8 @@ module Crystal
named_args
end

def parse_types(allow_primitives = false)
type = parse_type(allow_primitives)
def parse_types(allow_primitives = false, allow_splat = false)
type = parse_type(allow_primitives: allow_primitives, allow_splat: allow_splat)
case type
when Array
type
@@ -4070,9 +4083,9 @@ module Crystal
end
end

def parse_single_type(allow_primitives = false, allow_commas = true)
def parse_single_type(allow_primitives = false, allow_commas = true, allow_splat = false)
location = @token.location
type = parse_type(allow_primitives, allow_commas: allow_commas)
type = parse_type(allow_primitives: allow_primitives, allow_commas: allow_commas, allow_splat: allow_splat)
case type
when Array
raise "unexpected ',' in type (use parentheses to disambiguate)", location
@@ -4083,13 +4096,13 @@ module Crystal
end
end

def parse_type(allow_primitives, allow_commas = true)
def parse_type(allow_primitives, allow_commas = true, allow_splat = false)
location = @token.location

if @token.type == :"->"
input_types = nil
else
input_types = parse_type_union(allow_primitives)
input_types = parse_type_union(allow_primitives, allow_splat)
input_types = [input_types] unless input_types.is_a?(Array)
while allow_commas && @token.type == :"," && ((allow_primitives && next_comes_type_or_int) || (!allow_primitives && next_comes_type))
next_token_skip_space_or_newline
@@ -4103,7 +4116,7 @@ module Crystal
end
next
else
type_union = parse_type_union(allow_primitives)
type_union = parse_type_union(allow_primitives, allow_splat)
if type_union.is_a?(Array)
input_types.concat type_union
else
@@ -4119,7 +4132,7 @@ module Crystal
when :"=", :",", :")", :"}", :";", :NEWLINE
return_type = nil
else
type_union = parse_type_union(allow_primitives)
type_union = parse_type_union(allow_primitives, allow_splat)
if type_union.is_a?(Array)
raise "can't return more than more type", location.line_number, location.column_number
else
@@ -4137,13 +4150,13 @@ module Crystal
end
end

def parse_type_union(allow_primitives)
def parse_type_union(allow_primitives, allow_splat)
types = [] of ASTNode
parse_type_with_suffix(types, allow_primitives)
parse_type_with_suffix(types, allow_primitives, allow_splat)
if @token.type == :"|"
while @token.type == :"|"
next_token_skip_space_or_newline
parse_type_with_suffix(types, allow_primitives)
parse_type_with_suffix(types, allow_primitives, false)
end

if types.size == 1
@@ -4158,16 +4171,23 @@ module Crystal
end
end

def parse_type_with_suffix(types, allow_primitives)
def parse_type_with_suffix(types, allow_primitives, allow_splat)
splat = false
if allow_splat && @token.type == :"*"
splat = true
next_token
end

location = @token.location

if @token.type == :IDENT && @token.value == "self?"
type = Self.new.at(@token.location)
type = Union.new([type, Path.global("Nil")] of ASTNode).at(@token.location)
type = Self.new.at(location)
type = Union.new([type, Path.global("Nil")] of ASTNode).at(location)
next_token_skip_space
elsif @token.keyword?(:self)
type = Self.new.at(@token.location)
type = Self.new.at(location)
next_token_skip_space
else
location = @token.location
case @token.type
when :"{"
next_token_skip_space_or_newline
@@ -4222,6 +4242,7 @@ module Crystal
end
end

type = Splat.new(type).at(location) if splat
types << parse_type_suffix(type)
end

2 changes: 2 additions & 0 deletions src/compiler/crystal/syntax/to_s.cr
Original file line number Diff line number Diff line change
@@ -210,6 +210,7 @@ module Crystal
node.name.accept self
if type_vars = node.type_vars
@str << "("
@str << "*" if node.variadic?
type_vars.each_with_index do |type_var, i|
@str << ", " if i > 0
@str << type_var.to_s
@@ -234,6 +235,7 @@ module Crystal
node.name.accept self
if type_vars = node.type_vars
@str << "("
@str << "*" if node.variadic?
type_vars.each_with_index do |type_var, i|
@str << ", " if i > 0
@str << type_var
15 changes: 12 additions & 3 deletions src/compiler/crystal/types.cr
Original file line number Diff line number Diff line change
@@ -1352,8 +1352,9 @@ module Crystal
index.upto(type_vars.size - 1) do |second_index|
types << type_vars[second_index]
end
tuple_type = program.tuple.instantiate(types).as(TupleInstanceType)
instance_type_vars[name] = tuple_type.var
var = Var.new(name, program.tuple_of(types))
var.bind_to(var)
instance_type_vars[name] = var
else
type_var = type_vars[index]
case type_var
@@ -1734,7 +1735,15 @@ module Crystal
type_vars.each_value do |type_var|
io << ", " if i > 0
if type_var.is_a?(Var)
type_var.type.to_s_with_options(io, skip_union_parens: true)
if self.variadic
tuple = type_var.type.as(TupleInstanceType)
tuple.tuple_types.each_with_index do |tuple_type, j|
io << ", " if j > 0
tuple_type.to_s(io)
end
else
type_var.type.to_s_with_options(io, skip_union_parens: true)
end
else
type_var.to_s(io)
end

0 comments on commit 1749358

Please sign in to comment.