Skip to content

Commit

Permalink
Showing 3 changed files with 205 additions and 92 deletions.
15 changes: 15 additions & 0 deletions spec/compiler/type_inference/struct_spec.cr
Original file line number Diff line number Diff line change
@@ -214,4 +214,19 @@ describe "Type inference: struct" do
Bar.new as Foo
)) { types["Foo"].virtual_type! }
end

it "detects recursive struct through module" do
  # Foo holds an instance variable of type Moo, and Foo is the only type
  # that includes Moo, so the ivar transitively embeds Foo inside itself.
  source = %(
module Moo
end
struct Foo
include Moo
def initialize(@moo : Moo)
end
end
)
  assert_error source,
    "recursive struct Foo detected: `@moo : Moo` -> `Moo` -> `Foo`"
end
end
262 changes: 172 additions & 90 deletions src/compiler/crystal/codegen/llvm_typer.cr
Original file line number Diff line number Diff line change
@@ -20,6 +20,38 @@ module Crystal
@struct_cache = TypeCache.new
@union_value_cache = TypeCache.new

# For union types we just need to know the maximum size of their types.
# It might happen that we have a recursive type, for example:
#
# ```
# struct Foo
# def initialize
# @x = uninitialized Pointer(Int32 | Foo)
# end
# end
# ```
#
# In that case, when we are computing the llvm type of Foo, we will
# need to compute the llvm type of `@x`. Its type is a pointer to
# a union. In order to compute the llvm type of a union we need
# to compute the size of each type. For this, we compute the llvm
# type of each type in the union and then get their size. The problem
# here is that we are computing `Foo`, so we can't know its size yet.
#
# To solve this, when computing the llvm type of the union types,
# we do it with a `wants_size` flag. In the case of pointers we
# can just return `LLVM::VoidPointer`, whose size is one word, instead
# of computing the llvm type of the pointer element. This avoids the
# recursion.
#
# We still need separate caches for the types computed this way
# (with `wants_size`), because there can be cycles.
@wants_size_cache = TypeCache.new
@wants_size_struct_cache = TypeCache.new
@wants_size_union_value_cache = TypeCache.new

@types_being_computed = Set(Type).new

machine = program.target_machine
@layout = machine.data_layout
@landing_pad_type = LLVM::Type.struct([LLVM::VoidPointer, LLVM::Int32], "landing_pad")
@@ -34,139 +66,160 @@ module Crystal
]
end

def llvm_type(type)
def llvm_type(type, wants_size = false)
type = type.remove_indirection
@cache[type] ||= create_llvm_type(type)

if wants_size
@wants_size_cache[type] ||= create_llvm_type(type, wants_size: true)
else
@cache[type] ||= create_llvm_type(type, wants_size)
end
end

private def create_llvm_type(type : NoReturnType)
private def create_llvm_type(type : NoReturnType, wants_size)
LLVM::Void
end

private def create_llvm_type(type : VoidType)
private def create_llvm_type(type : VoidType, wants_size)
LLVM::Void
end

private def create_llvm_type(type : NilType)
private def create_llvm_type(type : NilType, wants_size)
NIL_TYPE
end

private def create_llvm_type(type : BoolType)
private def create_llvm_type(type : BoolType, wants_size)
LLVM::Int1
end

private def create_llvm_type(type : CharType)
private def create_llvm_type(type : CharType, wants_size)
LLVM::Int32
end

private def create_llvm_type(type : IntegerType)
private def create_llvm_type(type : IntegerType, wants_size)
LLVM::Type.int(8 * type.bytes)
end

private def create_llvm_type(type : FloatType)
private def create_llvm_type(type : FloatType, wants_size)
type.bytes == 4 ? LLVM::Float : LLVM::Double
end

private def create_llvm_type(type : SymbolType)
private def create_llvm_type(type : SymbolType, wants_size)
LLVM::Int32
end

private def create_llvm_type(type : EnumType)
private def create_llvm_type(type : EnumType, wants_size)
llvm_type(type.base_type)
end

private def create_llvm_type(type : ProcInstanceType)
private def create_llvm_type(type : ProcInstanceType, wants_size)
PROC_TYPE
end

private def create_llvm_type(type : CStructType)
llvm_struct_type(type)
private def create_llvm_type(type : CStructType, wants_size)
llvm_struct_type(type, wants_size)
end

private def create_llvm_type(type : CUnionType)
llvm_struct_type(type)
private def create_llvm_type(type : CUnionType, wants_size)
llvm_struct_type(type, wants_size)
end

private def create_llvm_type(type : InstanceVarContainer)
final_type = llvm_struct_type(type)
private def create_llvm_type(type : InstanceVarContainer, wants_size)
final_type = llvm_struct_type(type, wants_size)
unless type.struct?
final_type = final_type.pointer
end
final_type
end

private def create_llvm_type(type : MetaclassType)
private def create_llvm_type(type : MetaclassType, wants_size)
LLVM::Int32
end

private def create_llvm_type(type : LibType)
private def create_llvm_type(type : LibType, wants_size)
LLVM::Int32
end

private def create_llvm_type(type : GenericClassInstanceMetaclassType)
private def create_llvm_type(type : GenericClassInstanceMetaclassType, wants_size)
LLVM::Int32
end

private def create_llvm_type(type : VirtualMetaclassType)
private def create_llvm_type(type : VirtualMetaclassType, wants_size)
LLVM::Int32
end

private def create_llvm_type(type : PointerInstanceType)
pointed_type = llvm_embedded_type type.element_type
private def create_llvm_type(type : PointerInstanceType, wants_size)
if wants_size
return LLVM::VoidPointer
end

pointed_type = llvm_embedded_type(type.element_type, wants_size)
pointed_type = LLVM::Int8 if pointed_type == LLVM::Void
pointed_type.pointer
end

private def create_llvm_type(type : StaticArrayInstanceType)
pointed_type = llvm_embedded_type type.element_type
private def create_llvm_type(type : StaticArrayInstanceType, wants_size)
pointed_type = llvm_embedded_type(type.element_type, wants_size)
pointed_type = LLVM::Int8 if pointed_type == LLVM::Void
pointed_type.array type.size.as(NumberLiteral).value.to_i
end

private def create_llvm_type(type : TupleInstanceType)
private def create_llvm_type(type : TupleInstanceType, wants_size)
LLVM::Type.struct(type.llvm_name) do |a_struct|
@cache[type] = a_struct
if wants_size
@wants_size_cache[type] = a_struct
else
@cache[type] = a_struct
end

type.tuple_types.map { |tuple_type| llvm_embedded_type(tuple_type).as(LLVM::Type) }
type.tuple_types.map { |tuple_type| llvm_embedded_type(tuple_type, wants_size).as(LLVM::Type) }
end
end

private def create_llvm_type(type : NamedTupleInstanceType)
private def create_llvm_type(type : NamedTupleInstanceType, wants_size)
LLVM::Type.struct(type.llvm_name) do |a_struct|
@cache[type] = a_struct
if wants_size
@wants_size_cache[type] = a_struct
else
@cache[type] = a_struct
end

type.entries.map { |entry| llvm_embedded_type(entry.type).as(LLVM::Type) }
type.entries.map { |entry| llvm_embedded_type(entry.type, wants_size).as(LLVM::Type) }
end
end

private def create_llvm_type(type : NilableType)
llvm_type type.not_nil_type
private def create_llvm_type(type : NilableType, wants_size)
llvm_type(type.not_nil_type, wants_size)
end

private def create_llvm_type(type : ReferenceUnionType)
private def create_llvm_type(type : ReferenceUnionType, wants_size)
TYPE_ID_POINTER
end

private def create_llvm_type(type : NilableReferenceUnionType)
private def create_llvm_type(type : NilableReferenceUnionType, wants_size)
TYPE_ID_POINTER
end

private def create_llvm_type(type : NilableProcType)
private def create_llvm_type(type : NilableProcType, wants_size)
PROC_TYPE
end

private def create_llvm_type(type : NilablePointerType)
llvm_type(type.pointer_type)
private def create_llvm_type(type : NilablePointerType, wants_size)
llvm_type(type.pointer_type, wants_size)
end

private def create_llvm_type(type : MixedUnionType)
private def create_llvm_type(type : MixedUnionType, wants_size)
LLVM::Type.struct(type.llvm_name) do |a_struct|
@cache[type] = a_struct
if wants_size
@wants_size_cache[type] = a_struct
else
@cache[type] = a_struct
end

max_size = 0
type.expand_union_types.each do |subtype|
unless subtype.void?
size = size_of(llvm_type(subtype))
size = size_of(llvm_type(subtype, wants_size: true))
max_size = size if size > max_size
end
end
@@ -177,60 +230,82 @@ module Crystal
max_size = 1 if max_size == 0

llvm_value_type = LLVM::SizeT.array(max_size)
@union_value_cache[type] = llvm_value_type

if wants_size
@wants_size_union_value_cache[type] = llvm_value_type
else
@union_value_cache[type] = llvm_value_type
end

[LLVM::Int32, llvm_value_type]
end
end

private def create_llvm_type(type : TypeDefType)
llvm_type type.typedef
private def create_llvm_type(type : TypeDefType, wants_size)
llvm_type(type.typedef, wants_size)
end

private def create_llvm_type(type : VirtualType)
private def create_llvm_type(type : VirtualType, wants_size)
TYPE_ID_POINTER
end

private def create_llvm_type(type : AliasType)
llvm_type(type.remove_alias)
private def create_llvm_type(type : AliasType, wants_size)
llvm_type(type.remove_alias, wants_size)
end

private def create_llvm_type(type : NonGenericModuleType | GenericClassType)
private def create_llvm_type(type : NonGenericModuleType | GenericClassType, wants_size)
# This can only be reached if the module or generic class don't have implementors
LLVM::Int1
end

private def create_llvm_type(type : Type)
private def create_llvm_type(type : Type, wants_size)
raise "Bug: called create_llvm_type for #{type}"
end

def llvm_struct_type(type)
def llvm_struct_type(type, wants_size = false)
type = type.remove_indirection
@struct_cache[type] ||= create_llvm_struct_type(type)

if wants_size
@wants_size_struct_cache[type] ||= create_llvm_struct_type(type, wants_size: true)
else
@struct_cache[type] ||= create_llvm_struct_type(type, wants_size)
end
end

private def create_llvm_struct_type(type : StaticArrayInstanceType)
llvm_type type
private def create_llvm_struct_type(type : StaticArrayInstanceType, wants_size)
llvm_type(type, wants_size)
end

private def create_llvm_struct_type(type : TupleInstanceType)
llvm_type type
private def create_llvm_struct_type(type : TupleInstanceType, wants_size)
llvm_type(type, wants_size)
end

private def create_llvm_struct_type(type : NamedTupleInstanceType)
llvm_type type
private def create_llvm_struct_type(type : NamedTupleInstanceType, wants_size)
llvm_type(type, wants_size)
end

private def create_llvm_struct_type(type : CStructType)
private def create_llvm_struct_type(type : CStructType, wants_size)
LLVM::Type.struct(type.llvm_name, type.packed) do |a_struct|
@struct_cache[type] = a_struct
type.vars.map { |name, var| llvm_embedded_c_type(var.type).as(LLVM::Type) }
if wants_size
@wants_size_struct_cache[type] = a_struct
else
@struct_cache[type] = a_struct
end

@types_being_computed.add(type)
types = type.vars.map { |name, var| llvm_embedded_c_type(var.type, wants_size).as(LLVM::Type) }
@types_being_computed.delete(type)
types
end
end

private def create_llvm_struct_type(type : CUnionType)
private def create_llvm_struct_type(type : CUnionType, wants_size)
LLVM::Type.struct(type.llvm_name) do |a_struct|
@struct_cache[type] = a_struct
if wants_size
@wants_size_struct_cache[type] = a_struct
else
@struct_cache[type] = a_struct
end

max_size = 0
max_align = 0
@@ -240,7 +315,7 @@ module Crystal
type.vars.each do |name, var|
var_type = var.type
unless var_type.void?
llvm_type = llvm_embedded_c_type(var_type)
llvm_type = llvm_embedded_c_type(var_type, wants_size: true)
size = size_of(llvm_type)
align = align_of(llvm_type)

@@ -266,9 +341,13 @@ module Crystal
end
end

private def create_llvm_struct_type(type : InstanceVarContainer)
private def create_llvm_struct_type(type : InstanceVarContainer, wants_size)
LLVM::Type.struct(type.llvm_name) do |a_struct|
@struct_cache[type] = a_struct
if wants_size
@wants_size_struct_cache[type] = a_struct
else
@struct_cache[type] = a_struct
end

ivars = type.all_instance_vars
ivars_size = ivars.size
@@ -283,19 +362,22 @@ module Crystal
element_types.push LLVM::Int32 # For the type id
end

@types_being_computed.add(type)
ivars.each do |name, ivar|
if ivar_type = ivar.type?
element_types.push llvm_embedded_type(ivar_type)
element_types.push llvm_embedded_type(ivar_type, wants_size)
else
# This is for untyped fields: we don't really care how to represent them in memory.
element_types.push LLVM::Int8
end
end
@types_being_computed.delete(type)

element_types
end
end

private def create_llvm_struct_type(type : Type)
private def create_llvm_struct_type(type : Type, wants_size)
raise "Bug: called llvm_struct_type for #{type}"
end

@@ -311,56 +393,56 @@ module Crystal
end
end

def llvm_embedded_type(type)
llvm_embedded_type_impl(type.remove_indirection)
def llvm_embedded_type(type, wants_size = false)
llvm_embedded_type_impl(type.remove_indirection, wants_size)
end

private def llvm_embedded_type_impl(type : CStructType)
llvm_struct_type type
private def llvm_embedded_type_impl(type : CStructType, wants_size)
llvm_struct_type(type, wants_size)
end

private def llvm_embedded_type_impl(type : CUnionType)
llvm_struct_type type
private def llvm_embedded_type_impl(type : CUnionType, wants_size)
llvm_struct_type(type, wants_size)
end

private def llvm_embedded_type_impl(type : ProcInstanceType)
llvm_type type
private def llvm_embedded_type_impl(type : ProcInstanceType, wants_size)
llvm_type(type, wants_size)
end

private def llvm_embedded_type_impl(type : PointerInstanceType)
llvm_type type
private def llvm_embedded_type_impl(type : PointerInstanceType, wants_size)
llvm_type(type, wants_size)
end

private def llvm_embedded_type_impl(type : InstanceVarContainer)
private def llvm_embedded_type_impl(type : InstanceVarContainer, wants_size)
if type.struct?
llvm_struct_type type
llvm_struct_type(type, wants_size)
else
llvm_type type
llvm_type(type, wants_size)
end
end

private def llvm_embedded_type_impl(type : StaticArrayInstanceType)
llvm_type type
private def llvm_embedded_type_impl(type : StaticArrayInstanceType, wants_size)
llvm_type(type, wants_size)
end

private def llvm_embedded_type_impl(type : NoReturnType)
private def llvm_embedded_type_impl(type : NoReturnType, wants_size)
LLVM::Int8
end

private def llvm_embedded_type_impl(type : VoidType)
private def llvm_embedded_type_impl(type : VoidType, wants_size)
LLVM::Int8
end

private def llvm_embedded_type_impl(type)
llvm_type type
private def llvm_embedded_type_impl(type, wants_size)
llvm_type(type, wants_size)
end

def llvm_embedded_c_type(type : ProcInstanceType)
def llvm_embedded_c_type(type : ProcInstanceType, wants_size = false)
proc_type(type)
end

def llvm_embedded_c_type(type)
llvm_embedded_type type
def llvm_embedded_c_type(type, wants_size = false)
llvm_embedded_type(type, wants_size)
end

def llvm_c_type(type : ProcInstanceType)
20 changes: 18 additions & 2 deletions src/compiler/crystal/semantic/recursive_struct_checker.cr
Original file line number Diff line number Diff line change
@@ -44,7 +44,7 @@ module Crystal
if struct?(type)
target = type
checked = Set(Type).new
path = [] of Var
path = [] of Var | Type
check_recursive_instance_var_container(target, type, checked, path)
end

@@ -82,6 +82,15 @@ module Crystal
type.union_types.each do |union_type|
check_recursive(target, union_type, checked, path)
end
when NonGenericModuleType
path.push type
# Check if the module is composed, recursively, of the target struct
type.raw_including_types.try &.each do |module_type|
path.push module_type
check_recursive(target, module_type, checked, path)
path.pop
end
path.pop
end
end

@@ -99,7 +108,14 @@ module Crystal
end

def path_to_s(path)
path.join(" -> ") { |var| "`#{var.name} : #{var.type}`" }
path.join(" -> ") do |var_or_type|
case var_or_type
when Var
"`#{var_or_type.name} : #{var_or_type.type}`"
else
"`#{var_or_type}`"
end
end
end

def struct?(type)

0 comments on commit a27799a

Please sign in to comment.