Skip to content

Commit

Permalink
Showing 10 changed files with 469 additions and 105 deletions.
124 changes: 124 additions & 0 deletions spec/compiler/codegen/class_var_spec.cr
Original file line number Diff line number Diff line change
@@ -373,4 +373,128 @@ describe "Codegen: class var" do
z
)).to_i.should eq(10)
end

it "doesn't inherit class var value in subclass" do
  # Bar gets its own @@var (initialized to 1); assigning through Foo must not be visible via Bar.
  program = %(
class Foo
@@var = 1
def self.var
@@var
end
def self.var=(@@var)
end
end
class Bar < Foo
end
Foo.var = 2
Bar.var
)
  run(program).to_i.should eq(1)
end

it "doesn't inherit class var value in module" do
  # Foo includes Moo but keeps its own @@var, so writing Moo.var leaves Foo's copy at 1.
  program = %(
module Moo
@@var = 1
def var
@@var
end
def self.var=(@@var)
end
end
class Foo
include Moo
end
Moo.var = 2
Foo.new.var
)
  run(program).to_i.should eq(1)
end

it "reads class var from virtual type" do
  # The pointer is typed Foo but holds a Bar, so the read must resolve to Bar's @@var (set to 2).
  program = %(
class Foo
@@var = 1
def self.var=(@@var)
end
def self.var
@@var
end
def var
@@var
end
end
class Bar < Foo
end
Bar.var = 2
ptr = Pointer(Foo).malloc(1_u64)
ptr.value = Bar.new
ptr.value.var
)
  run(program).to_i.should eq(2)
end

it "reads class var from virtual type metaclass" do
  # Same as the virtual-type read, but through a Foo.class pointer holding Bar.
  program = %(
class Foo
@@var = 1
def self.var=(@@var)
end
def self.var
@@var
end
end
class Bar < Foo
end
Bar.var = 2
ptr = Pointer(Foo.class).malloc(1_u64)
ptr.value = Bar
ptr.value.var
)
  run(program).to_i.should eq(2)
end

it "writes class var from virtual type" do
  # A write through a Foo-typed value holding a Bar must land in Bar's @@var.
  program = %(
class Foo
@@var = 1
def self.var=(@@var)
end
def self.var
@@var
end
def var=(@@var)
end
end
class Bar < Foo
end
ptr = Pointer(Foo).malloc(1_u64)
ptr.value = Bar.new
ptr.value.var = 2
Bar.var
)
  run(program).to_i.should eq(2)
end
end
36 changes: 0 additions & 36 deletions spec/compiler/type_inference/class_spec.cr
Original file line number Diff line number Diff line change
@@ -983,40 +983,4 @@ describe "Type inference: class" do
initialize
)) { int32 }
end

it "error when using class var on virtual type" do
  # Reads @@a through an instance statically typed as Foo (virtual) and expects a compile error.
  source = %(
class Foo
@@a = 1
def a
@@a
end
end
class Bar < Foo
end
(Bar.new as Foo).a
)
  assert_error source,
    "can't access class variable from a type that is Foo or any of its subclasses"
end

it "error when using class var on virtual metaclass type" do
  # Same as above, but through the virtual metaclass (Foo+.class).
  source = %(
class Foo
@@a = 1
def self.a
@@a
end
end
class Bar < Foo
end
(Bar.new as Foo).class.a
)
  assert_error source,
    "can't access class variable from a type that is Foo or any of its subclasses"
end
end
75 changes: 75 additions & 0 deletions spec/compiler/type_inference/class_var_spec.cr
Original file line number Diff line number Diff line change
@@ -342,4 +342,79 @@ describe "Type inference: class var" do
),
"can't use Int as the type of a class variable yet, use a more specific type"
end

it "can find class var in subclass" do
  # @@var is only declared in Foo; a class method of Bar must still see it as Int32.
  code = %(
class Foo
@@var = 1
end
class Bar < Foo
def self.var
@@var
end
end
Bar.var
)
  assert_type(code) { int32 }
end

it "can find class var through included module" do
  # @@var comes from the included module Moo.
  code = %(
module Moo
@@var = 1
end
class Bar
include Moo
def self.var
@@var
end
end
Bar.var
)
  assert_type(code) { int32 }
end

it "errors if redefining class var type in subclass" do
  # Explicit declarations with conflicting types in parent and child.
  source = %(
class Foo
@@x : Int32
end
class Bar < Foo
@@x : Float64
end
)
  assert_error source,
    "class variable '@@x' of Bar is already defined as Int32 in Foo"
end

it "errors if redefining class var type in subclass, with guess" do
  # Types come from the guesser (1 => Int32, 'a' => Char) instead of explicit declarations.
  source = %(
class Foo
@@x = 1
end
class Bar < Foo
@@x = 'a'
end
)
  assert_error source,
    "class variable '@@x' of Bar is already defined as Int32 in Foo"
end

it "errors if redefining class var type in included module" do
  # The conflicting ancestor here is an included module rather than a superclass.
  source = %(
module Moo
@@x : Int32
end
class Bar
include Moo
@@x : Float64
end
)
  assert_error source,
    "class variable '@@x' of Bar is already defined as Int32 in Moo"
end
end
198 changes: 155 additions & 43 deletions src/compiler/crystal/codegen/class_var.cr
Original file line number Diff line number Diff line change
@@ -6,29 +6,29 @@ require "./codegen"
# variable is read. There's an "initialized" flag too.

class Crystal::CodeGenVisitor
# Returns the LLVM global backing class variable `name` of `owner`, adding it
# (with `llvm_type(type)`) to the main module if it is not there yet.
#
# The linkage and thread-local settings are applied on every call, including
# when the global already existed:
# * internal linkage when compiling to a single module (`@single_module`);
# * thread-local storage when `thread_local` is true.
def declare_class_var(owner, name, type, thread_local)
  global_name = class_var_global_name(owner, name)
  global = @main_mod.globals[global_name]? ||
           @main_mod.globals.add(llvm_type(type), global_name)
  global.linkage = LLVM::Linkage::Internal if @single_module
  global.thread_local = true if thread_local
  global
end

# Returns the Int1 global that records whether class variable `name` of
# `owner` has been initialized. On first request it is created with an
# initial value of 0 (false), internal linkage when `@single_module`, and
# thread-local storage when `thread_local` is true. An existing flag is
# returned untouched.
def declare_class_var_initialized_flag(owner, name, thread_local)
  initialized_flag_name = class_var_global_initialized_name(owner, name)
  initialized_flag = @main_mod.globals[initialized_flag_name]?
  unless initialized_flag
    initialized_flag = @main_mod.globals.add(LLVM::Int1, initialized_flag_name)
    initialized_flag.initializer = int1(0)
    initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
    initialized_flag.thread_local = true if thread_local
  end
  initialized_flag
end

# Declares (or fetches) both the storage global of class variable `name` of
# `owner` and its "initialized" flag, returning them as `{global, flag}`.
def declare_class_var_and_initialized_flag(owner, name, type, thread_local)
  global = declare_class_var(owner, name, type, thread_local)
  initialized_flag = declare_class_var_initialized_flag(owner, name, thread_local)
  {global, initialized_flag}
end

def initialize_class_var(class_var : ClassVar)
@@ -37,12 +37,16 @@ class Crystal::CodeGenVisitor

# Emits the initialization of `class_var` when it has an initializer,
# forwarding the initializer's owner, name, meta vars and expression node.
# Class variables without an initializer produce no code here.
def initialize_class_var(class_var : MetaTypeVar)
  initializer = class_var.initializer
  if initializer
    initialize_class_var(initializer.owner, initializer.name, initializer.meta_vars, initializer.node)
  end
end

def initialize_class_var(initializer : ClassVarInitializer)
class_var = initializer.owner.class_vars[initializer.name]
global, initialized_flag = declare_class_var_and_initialized_flag(class_var)
def initialize_class_var(owner : ClassVarContainer, name : String, meta_vars : MetaVars, node : ASTNode)
class_var = owner.lookup_class_var(name)

global, initialized_flag = declare_class_var_and_initialized_flag(owner, name, class_var.type, class_var.thread_local?)

initialized_block, not_initialized_block = new_blocks "initialized", "not_initialized"

@@ -52,8 +56,9 @@ class Crystal::CodeGenVisitor
position_at_end not_initialized_block
store int1(1), initialized_flag

init_function_name = "~#{class_var_global_initialized_name(class_var)}"
func = @main_mod.functions[init_function_name]? || create_initialize_class_var_function(init_function_name, class_var)
init_function_name = "~#{class_var_global_initialized_name(owner, name)}"
func = @main_mod.functions[init_function_name]? ||
create_initialize_class_var_function(init_function_name, owner, name, class_var.type, class_var.thread_local?, meta_vars, node)
func = check_main_fun init_function_name, func
call func

@@ -64,37 +69,36 @@ class Crystal::CodeGenVisitor
global
end

def create_initialize_class_var_function(fun_name, class_var)
global, initialized_flag = declare_class_var_and_initialized_flag(class_var)
initializer = class_var.initializer.not_nil!
def create_initialize_class_var_function(fun_name, owner, name, type, thread_local, meta_vars, node)
global, initialized_flag = declare_class_var_and_initialized_flag(owner, name, type, thread_local)

define_main_function(fun_name, ([] of LLVM::Type), LLVM::Void, needs_alloca: true) do |func|
with_cloned_context do
# "self" in a constant is the class_var owner
context.type = class_var.owner
context.type = owner

# Start with fresh variables
context.vars = LLVMVars.new

alloca_vars initializer.meta_vars
alloca_vars meta_vars

request_value do
accept initializer.node
accept node
end

node_type = initializer.node.type
node_type = node.type

if node_type.nil_type? && !class_var.type.nil_type?
global.initializer = llvm_type(class_var.type).null
elsif @last.constant? && (class_var.type.is_a?(PrimitiveType) || class_var.type.is_a?(EnumType))
if node_type.nil_type? && !type.nil_type?
global.initializer = llvm_type(type).null
elsif @last.constant? && (type.is_a?(PrimitiveType) || type.is_a?(EnumType))
global.initializer = @last
else
if class_var.type.passed_by_value?
global.initializer = llvm_type(class_var.type).undef
if type.passed_by_value?
global.initializer = llvm_type(type).undef
else
global.initializer = llvm_type(class_var.type).null
global.initializer = llvm_type(type).null
end
assign global, class_var.type, initializer.node.type, @last
assign global, type, node.type, @last
end

ret
@@ -104,22 +108,129 @@ class Crystal::CodeGenVisitor

# Reads the value of the class variable that `node` refers to.
def read_class_var(node : ClassVar)
  read_class_var(node, node.var)
end

# Loads the value of `class_var` through the pointer produced by
# `read_class_var_ptr`, converted with `to_lhs` for the variable's type.
def read_class_var(node, class_var : MetaTypeVar)
  to_lhs read_class_var_ptr(node, class_var), class_var.type
end

# Returns a pointer to the storage of the class variable that `node` refers to.
def read_class_var_ptr(node : ClassVar)
  read_class_var_ptr(node, node.var)
end

# Returns a pointer to the storage of `class_var`.
#
# * Virtual owners (`VirtualType`, `VirtualMetaclassType`) are dispatched at
#   runtime on the type id of `self` (see the read_virtual_* helpers).
# * Without an initializer, the global itself is returned.
# * Otherwise a `~<global name>:read` function is created once and called;
#   it runs the lazy initializer if needed and yields the global's pointer.
def read_class_var_ptr(node, class_var : MetaTypeVar)
  owner = class_var.owner
  case owner
  when VirtualType
    return read_virtual_class_var_ptr(node, class_var, owner)
  when VirtualMetaclassType
    return read_virtual_metaclass_class_var_ptr(node, class_var, owner)
  end

  initializer = class_var.initializer
  unless initializer
    return get_global class_var_global_name(class_var.owner, class_var.name), class_var.type, class_var
  end

  read_function_name = "~#{class_var_global_name(class_var.owner, class_var.name)}:read"
  func = @main_mod.functions[read_function_name]? ||
         create_read_class_var_function(read_function_name, class_var.owner, class_var.name, class_var.type, class_var.thread_local?, initializer.meta_vars, initializer.node)
  func = check_main_fun read_function_name, func
  call func
end

# Reads a class variable whose owner is a virtual type (`Foo+`): the runtime
# type id of `self` is passed to a generated dispatcher function (created once
# per owner/variable) that returns the pointer for that concrete type.
def read_virtual_class_var_ptr(node, class_var, owner)
  runtime_type_id = type_id(llvm_self, owner)
  dispatcher_name = "~#{class_var_global_name(owner, class_var.name)}:read"
  dispatcher = @main_mod.functions[dispatcher_name]? ||
               create_read_virtual_class_var_ptr_function(dispatcher_name, node, class_var, owner)
  dispatcher = check_main_fun dispatcher_name, dispatcher
  call dispatcher, runtime_type_id
end

# Defines `fun_name(type_id : Int32) : T*`, which compares the given type id
# first with `owner.base_type`, then with each ClassVarContainer in
# `owner.base_type.all_subclasses`. On a match it returns the pointer to that
# type's own class variable named `node.name`. An id matching none of them
# reaches `unreachable`.
def create_read_virtual_class_var_ptr_function(fun_name, node, class_var, owner)
  define_main_function(fun_name, [LLVM::Int32], llvm_type(class_var.type).pointer) do |func|
    type_id_param = func.params[0]
    base = owner.base_type

    matches = equal?(type_id_param, type_id(base))
    match_block, no_match_block = new_blocks "current_type", "next_type"
    cond matches, match_block, no_match_block

    position_at_end match_block
    ret read_class_var_ptr(node, base.lookup_class_var(node.name))

    position_at_end no_match_block

    base.all_subclasses.each do |subclass|
      # Only types that can hold class variables get a branch.
      next unless subclass.is_a?(ClassVarContainer)

      matches = equal?(type_id_param, type_id(subclass))
      match_block, no_match_block = new_blocks "current_type", "next_type"
      cond matches, match_block, no_match_block

      position_at_end match_block
      ret read_class_var_ptr(node, subclass.lookup_class_var(node.name))

      position_at_end no_match_block
    end

    unreachable
  end
end

# Reads a class variable whose owner is a virtual metaclass (`Foo+.class`):
# the runtime type id of `self` is passed to a generated dispatcher function
# (created once per owner/variable) that returns the pointer for that
# concrete metaclass.
def read_virtual_metaclass_class_var_ptr(node, class_var, owner)
  self_type_id = type_id(llvm_self, owner)
  read_function_name = "~#{class_var_global_name(owner, class_var.name)}:read"
  func = @main_mod.functions[read_function_name]? ||
         create_read_virtual_metaclass_var_ptr_function(read_function_name, node, class_var, owner)
  func = check_main_fun read_function_name, func
  # The dispatcher takes the type id; its result is a pointer, not a value.
  call func, self_type_id
end

# Defines `fun_name(type_id : Int32) : T*` for a virtual metaclass owner. It
# compares the id first with `owner.base_type.metaclass`, then with the
# metaclass of each ClassVarContainer in
# `owner.base_type.instance_type.all_subclasses`. On a match it returns the
# pointer to that type's own class variable named `node.name`. An unmatched
# id reaches `unreachable`.
#
# NOTE(review): if `VirtualMetaclassType#base_type` already returns a
# metaclass, `base_type.metaclass` below would be the metaclass's metaclass.
# Confirm it matches the type id of the base class' metaclass.
def create_read_virtual_metaclass_var_ptr_function(fun_name, node, class_var, owner)
  define_main_function(fun_name, [LLVM::Int32], llvm_type(class_var.type).pointer) do |func|
    type_id_param = func.params[0]
    base = owner.base_type

    matches = equal?(type_id_param, type_id(base.metaclass))
    match_block, no_match_block = new_blocks "current_type", "next_type"
    cond matches, match_block, no_match_block

    position_at_end match_block
    ret read_class_var_ptr(node, base.lookup_class_var(node.name))

    position_at_end no_match_block

    base.instance_type.all_subclasses.each do |subclass|
      # Only types that can hold class variables get a branch.
      next unless subclass.is_a?(ClassVarContainer)

      matches = equal?(type_id_param, type_id(subclass.metaclass))
      match_block, no_match_block = new_blocks "current_type", "next_type"
      cond matches, match_block, no_match_block

      position_at_end match_block
      ret read_class_var_ptr(node, subclass.lookup_class_var(node.name))

      position_at_end no_match_block
    end
    unreachable
  end
end

def create_read_class_var_function(fun_name, class_var)
global, initialized_flag = declare_class_var_and_initialized_flag(class_var)
def create_read_class_var_function(fun_name, owner, name, type, thread_local, meta_vars, node)
global, initialized_flag = declare_class_var_and_initialized_flag(owner, name, type, thread_local)

define_main_function(fun_name, ([] of LLVM::Type), llvm_type(class_var.type).pointer) do |func|
define_main_function(fun_name, ([] of LLVM::Type), llvm_type(type).pointer) do |func|
initialized_block, not_initialized_block = new_blocks "initialized", "not_initialized"

initialized = load(initialized_flag)
@@ -128,8 +239,9 @@ class Crystal::CodeGenVisitor
position_at_end not_initialized_block
store int1(1), initialized_flag

init_function_name = "~#{class_var_global_initialized_name(class_var)}"
func = @main_mod.functions[init_function_name]? || create_initialize_class_var_function(init_function_name, class_var)
init_function_name = "~#{class_var_global_initialized_name(owner, name)}"
func = @main_mod.functions[init_function_name]? ||
create_initialize_class_var_function(init_function_name, owner, name, type, thread_local, meta_vars, node)
call func

br initialized_block
@@ -140,11 +252,11 @@ class Crystal::CodeGenVisitor
end
end

# Name of the LLVM global holding class variable `name` of `owner`: the
# owner's string form followed by `name` with every `@` replaced by `:`
# (for example `Foo` and `@@var` give `Foo::var`, assuming `owner` prints
# as `Foo`).
def class_var_global_name(owner : Type, name : String)
  "#{owner}#{name.gsub('@', ':')}"
end

# Name of the Int1 global flag recording whether that class variable has been
# initialized: the global name above with an `:init` suffix.
def class_var_global_initialized_name(owner : Type, name : String)
  "#{class_var_global_name(owner, name)}:init"
end
end
12 changes: 8 additions & 4 deletions src/compiler/crystal/codegen/codegen.cr
Original file line number Diff line number Diff line change
@@ -197,7 +197,7 @@ module Crystal
class_var = initializer.owner.class_vars[initializer.name]
next if class_var.thread_local?

initialize_class_var(initializer)
initialize_class_var(initializer.owner, initializer.name, initializer.meta_vars, initializer.node)
end
end
end
@@ -407,7 +407,7 @@ module Crystal
if node_exp.var.initializer
initialize_class_var(node_exp)
end
get_global class_var_global_name(node_exp.var), node_exp.type, node_exp.var
get_global class_var_global_name(node_exp.var.owner, node_exp.var.name), node_exp.type, node_exp.var
when Global
get_global node_exp.name, node_exp.type, node_exp.var
when Path
@@ -842,14 +842,16 @@ module Crystal

return if value.no_returns?

last = @last

set_current_debug_location node if @debug
ptr = case target
when InstanceVar
instance_var_ptr (context.type.as(InstanceVarContainer)), target.name, llvm_self_ptr
when Global
get_global target.name, target_type, target.var
when ClassVar
get_global class_var_global_name(target.var), target_type, target.var
read_class_var_ptr(target)
when Var
# Can't assign void
return if target.type.void?
@@ -869,6 +871,8 @@ module Crystal
node.raise "Unknown assign target in codegen: #{target}"
end

@last = last

if target.is_a?(Var) && target.special_var? && !target_type.reference_like?
# For special vars that are not reference-like, the function argument will
# be a pointer to the struct value. So, we need to first cast the value to
@@ -1018,7 +1022,7 @@ module Crystal
end

# Reads a class variable expression, leaving the loaded value in `@last`.
# The read is emitted exactly once; a bare extra call would duplicate the
# generated load/initializer call.
def visit(node : ClassVar)
  @last = read_class_var(node)
end

def read_global(name, type, real_var)
10 changes: 1 addition & 9 deletions src/compiler/crystal/semantic/base_type_visitor.cr
Original file line number Diff line number Diff line change
@@ -919,20 +919,12 @@ module Crystal
node.raise "can't use class variables in generic types"
end

if scope.is_a?(VirtualType)
node.raise "can't access class variable from a type that is #{scope.base_type.instance_type} or any of its subclasses"
end

if scope.is_a?(VirtualMetaclassType)
node.raise "can't access class variable from a type that is #{scope.base_type.instance_type} or any of its subclasses"
end

scope.as(ClassVarContainer)
end

def lookup_class_var(node)
class_var_owner = class_var_owner(node)
var = class_var_owner.class_vars[node.name]?
var = class_var_owner.lookup_class_var?(node.name)
unless var
undefined_class_variable(node, class_var_owner)
end
30 changes: 27 additions & 3 deletions src/compiler/crystal/semantic/type_declaration_processor.cr
Original file line number Diff line number Diff line change
@@ -151,7 +151,9 @@ module Crystal
# give an error
check_nilable_instance_vars

check_errors
check_cant_use_type_errors

check_class_var_errors(type_decl_visitor.class_vars, type_guess_visitor.class_vars)

node
end
@@ -189,7 +191,7 @@ module Crystal
# If the variable is gueseed to be nilable because it is not initialized
# in all of the initialize methods, and the explicit type is not nilable,
# give an error right now
if !var.type.includes_type?(@program.nil)
if instance_var && !var.type.includes_type?(@program.nil)
if nilable_instance_var?(owner, name)
raise_not_initialized_in_all_initialize(var, name, owner)
end
@@ -547,7 +549,7 @@ module Crystal
@errors[type]?.try &.delete(name)
end

private def check_errors
private def check_cant_use_type_errors
@errors.each do |type, entries|
entries.each do |name, error|
case name
@@ -562,6 +564,28 @@ module Crystal
end
end

# Checks every class variable collected from explicit declarations and from
# type guessing against the same-named class variable declared directly on
# the owner's ancestors. Raises a TypeException at the variable's recorded
# location on the first type mismatch.
private def check_class_var_errors(type_decl_class_vars, guesser_class_vars)
  {type_decl_class_vars, guesser_class_vars}.each do |collected|
    collected.each do |owner, vars_by_name|
      vars_by_name.each do |name, info|
        check_class_var_redefinition(owner, name, info)
      end
    end
  end
end

# Compares `owner`'s class variable `name` with the one declared directly on
# each ClassVarContainer ancestor, in `owner.ancestors` order. Does nothing
# if `owner` has no such variable.
private def check_class_var_redefinition(owner, name, info)
  own_var = owner.lookup_class_var?(name)
  return unless own_var

  owner.ancestors.each do |ancestor|
    next unless ancestor.is_a?(ClassVarContainer)

    inherited_var = ancestor.class_vars?.try &.[name]?
    next unless inherited_var
    next unless own_var.type != inherited_var.type

    raise TypeException.new("class variable '#{name}' of #{owner} is already defined as #{inherited_var.type} in #{ancestor}", info.location)
  end
end

# Raises a compile error located at `node`: instance variable `name` of
# `owner` is not assigned in every `initialize`, which makes it nilable.
private def raise_not_initialized_in_all_initialize(node : ASTNode, name, owner)
  message = "instance variable '#{name}' of #{owner} was not initialized in all of the 'initialize' methods, rendering it nilable"
  node.raise message
end
6 changes: 3 additions & 3 deletions src/compiler/crystal/semantic/type_declaration_visitor.cr
Original file line number Diff line number Diff line change
@@ -35,7 +35,7 @@ module Crystal

# The type of class variables. The last one wins.
# This is type => variables.
@class_vars = {} of ClassVarContainer => Hash(String, Type)
@class_vars = {} of ClassVarContainer => Hash(String, TypeDeclarationWithLocation)
end

def visit(node : ClassDef)
@@ -164,8 +164,8 @@ module Crystal
owner = class_var_owner(node)
var_type = lookup_type(node.declared_type)
var_type = check_declare_var_type(node, var_type, "a class variable")
owner_vars = @class_vars[owner] ||= {} of String => Type
owner_vars[var.name] = var_type.virtual_type
owner_vars = @class_vars[owner] ||= {} of String => TypeDeclarationWithLocation
owner_vars[var.name] = TypeDeclarationWithLocation.new(var_type.virtual_type, node.location.not_nil!)
end

def declare_global_var(node, var)
15 changes: 8 additions & 7 deletions src/compiler/crystal/semantic/type_guess_visitor.cr
Original file line number Diff line number Diff line change
@@ -17,8 +17,9 @@ module Crystal
# A guessed type for a variable, plus:
# * `outside_def`: starts false; `add_type_info` sets it when the guess is made
#   outside a def.
# * `location`: where the guess came from, so later errors (such as class
#   variable redefinition checks) can point at the assignment.
class TypeInfo
  property type
  property outside_def
  getter location

  # Single constructor: `location` is required so `@location` is always set.
  def initialize(@type : Type, @location : Location)
    @outside_def = false
  end
end
@@ -257,11 +258,11 @@ module Crystal
next if owner.class_vars[target.name]?

owner_vars = @class_vars[owner] ||= {} of String => TypeInfo
add_type_info(owner_vars, target.name, tuple_type)
add_type_info(owner_vars, target.name, tuple_type, target)
when Global
next if @mod.global_vars[target.name]?

add_type_info(@globals, target.name, tuple_type)
add_type_info(@globals, target.name, tuple_type, target)
end
end
end
@@ -277,7 +278,7 @@ module Crystal

type = guess_type(value)
if type
add_type_info(@globals, target.name, type)
add_type_info(@globals, target.name, type, target)
end
type
end
@@ -293,7 +294,7 @@ module Crystal
type = guess_type(value)
if type
owner_vars = @class_vars[owner] ||= {} of String => TypeInfo
add_type_info(owner_vars, target.name, type)
add_type_info(owner_vars, target.name, type, target)
end
type
end
@@ -432,10 +433,10 @@ module Crystal
type_vars
end

def add_type_info(vars, name, type)
def add_type_info(vars, name, type, node)
info = vars[name]?
unless info
info = TypeInfo.new(type)
info = TypeInfo.new(type, node.location.not_nil!)
info.outside_def = true if @outside_def
vars[name] = info
else
68 changes: 68 additions & 0 deletions src/compiler/crystal/types.cr
Original file line number Diff line number Diff line change
@@ -801,6 +801,36 @@ module Crystal
# Class variables declared on (or cached for) this type, keyed by name.
# The hash is created lazily on first access.
def class_vars
  @class_vars ||= {} of String => MetaTypeVar
end

# Like `class_vars`, but never allocates: returns nil when no class variable
# table has been created for this type yet.
def class_vars?
  @class_vars
end

# Like `lookup_class_var?`, for callers that require the class variable to
# exist (codegen, for example). Raises a descriptive "Bug:" error naming the
# variable and the type, instead of a bare nil assertion, when it is missing.
def lookup_class_var(name)
  lookup_class_var?(name) || ::raise "Bug: class variable #{name} not found in #{self}"
end

# Finds class variable `name` for this type, or returns nil.
#
# A variable already in this type's table is returned directly. Otherwise
# `ancestors` are searched in order. The first ClassVarContainer declaring
# `name` serves as a template for a new MetaTypeVar owned by `self`: same
# type, thread-local flag and initializer, bound to the ancestor's variable.
# That new variable is cached in `class_vars`, so this type gets a separate
# variable object rather than sharing the ancestor's.
def lookup_class_var?(name)
  own = @class_vars.try &.[name]?
  return own if own

  ancestors.each do |ancestor|
    next unless ancestor.is_a?(ClassVarContainer)

    template = ancestor.class_vars?.try &.[name]?
    next unless template

    copy = MetaTypeVar.new(name, template.type)
    copy.owner = self
    copy.thread_local = template.thread_local?
    copy.initializer = template.initializer
    copy.bind_to(template)
    self.class_vars[name] = copy
    return copy
  end

  nil
end
end

module SubclassObservable
@@ -2960,6 +2990,7 @@ module Crystal
include DefInstanceContainer
include VirtualTypeLookup
include InstanceVarContainer
include ClassVarContainer

getter program : Program
getter base_type : NonGenericClassType
@@ -3036,6 +3067,24 @@ module Crystal
end
end

# Finds class variable `name` for this virtual type (`Foo+`), or returns nil.
#
# If the base type knows the variable, a MetaTypeVar owned by this virtual
# type is created: same type, thread-local flag and initializer, bound to the
# base type's variable. It is cached in `class_vars`. Because its owner is the
# virtual type, codegen dispatches reads on the runtime type of `self`.
def lookup_class_var?(name)
  cached = @class_vars.try &.[name]?
  return cached if cached

  from_base = base_type.lookup_class_var?(name)
  return nil unless from_base

  proxy = MetaTypeVar.new(name, from_base.type)
  proxy.owner = self
  proxy.thread_local = from_base.thread_local?
  proxy.initializer = from_base.initializer
  proxy.bind_to(from_base)
  self.class_vars[name] = proxy
  proxy
end

def to_s_with_options(io : IO, skip_union_parens : Bool = false, generic_args : Bool = true)
base_type.to_s(io)
io << "+"
@@ -3049,6 +3098,7 @@ module Crystal
class VirtualMetaclassType < Type
include DefInstanceContainer
include VirtualTypeLookup
include ClassVarContainer

getter program : Program
getter instance_type : VirtualType
@@ -3090,6 +3140,24 @@ module Crystal
end
end

# Finds class variable `name` for this virtual metaclass (`Foo+.class`), or
# returns nil.
#
# The variable is looked up on `base_type.instance_type`. If found, a
# MetaTypeVar owned by this virtual metaclass is created: same type,
# thread-local flag and initializer, bound to the found variable. It is cached
# in `class_vars`, so codegen dispatches reads on the runtime metaclass of
# `self`.
def lookup_class_var?(name)
  cached = @class_vars.try &.[name]?
  return cached if cached

  from_instance = base_type.instance_type.lookup_class_var?(name)
  return nil unless from_instance

  proxy = MetaTypeVar.new(name, from_instance.type)
  proxy.owner = self
  proxy.thread_local = from_instance.thread_local?
  proxy.initializer = from_instance.initializer
  proxy.bind_to(from_instance)
  self.class_vars[name] = proxy
  proxy
end

def to_s_with_options(io : IO, skip_union_parens : Bool = false, generic_args : Bool = true)
instance_type.to_s(io)
io << ":Class"

0 comments on commit 4e8aa74

Please sign in to comment.