Skip to content

Commit

Permalink
don't memoize comptime functions if they can mutate state via parameters
Browse files Browse the repository at this point in the history
closes #639
  • Loading branch information
andrewrk committed Mar 9, 2018
1 parent aaf2230 commit 6db9be8
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 9 deletions.
9 changes: 8 additions & 1 deletion src/all_types.hpp
Expand Up @@ -1168,10 +1168,17 @@ struct TypeTableEntry {
LLVMTypeRef type_ref;
ZigLLVMDIType *di_type;

bool zero_bits;
bool zero_bits; // this is denormalized data
bool is_copyable;
bool gen_h_loop_flag;

// This is denormalized data. The simplest type that has this
// flag set to true is a mutable pointer. A const pointer has
// the same value for this flag as the child type.
// If a struct has any fields that have this flag true, then
// the flag is true for the struct.
bool can_mutate_state_through_it;

union {
TypeTableEntryPointer pointer;
TypeTableEntryInt integral;
Expand Down
28 changes: 28 additions & 0 deletions src/analyze.cpp
Expand Up @@ -398,6 +398,7 @@ TypeTableEntry *get_pointer_to_type_extra(CodeGen *g, TypeTableEntry *child_type

TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdPointer);
entry->is_copyable = true;
entry->can_mutate_state_through_it = is_const ? child_type->can_mutate_state_through_it : true;

const char *const_str = is_const ? "const " : "";
const char *volatile_str = is_volatile ? "volatile " : "";
Expand Down Expand Up @@ -482,6 +483,7 @@ TypeTableEntry *get_maybe_type(CodeGen *g, TypeTableEntry *child_type) {
assert(child_type->type_ref || child_type->zero_bits);
assert(child_type->di_type);
entry->is_copyable = type_is_copyable(g, child_type);
entry->can_mutate_state_through_it = child_type->can_mutate_state_through_it;

buf_resize(&entry->name, 0);
buf_appendf(&entry->name, "?%s", buf_ptr(&child_type->name));
Expand Down Expand Up @@ -572,6 +574,7 @@ TypeTableEntry *get_error_union_type(CodeGen *g, TypeTableEntry *err_set_type, T
entry->is_copyable = true;
assert(payload_type->di_type);
ensure_complete_type(g, payload_type);
entry->can_mutate_state_through_it = payload_type->can_mutate_state_through_it;

buf_resize(&entry->name, 0);
buf_appendf(&entry->name, "%s!%s", buf_ptr(&err_set_type->name), buf_ptr(&payload_type->name));
Expand Down Expand Up @@ -730,6 +733,7 @@ TypeTableEntry *get_slice_type(CodeGen *g, TypeTableEntry *ptr_type) {

TypeTableEntry *entry = new_type_table_entry(TypeTableEntryIdStruct);
entry->is_copyable = true;
entry->can_mutate_state_through_it = ptr_type->can_mutate_state_through_it;

// replace the & with [] to go from a ptr type name to a slice type name
buf_resize(&entry->name, 0);
Expand Down Expand Up @@ -1735,6 +1739,8 @@ TypeTableEntry *get_struct_type(CodeGen *g, const char *type_name, const char *f
struct_type->data.structure.gen_field_count += 1;
} else {
field->gen_index = SIZE_MAX;
struct_type->can_mutate_state_through_it = struct_type->can_mutate_state_through_it ||
field->type_entry->can_mutate_state_through_it;
}

auto prev_entry = struct_type->data.structure.fields_by_name.put_unique(field->name, field);
Expand Down Expand Up @@ -2475,6 +2481,9 @@ static void resolve_struct_zero_bits(CodeGen *g, TypeTableEntry *struct_type) {
if (!type_has_bits(field_type))
continue;

struct_type->can_mutate_state_through_it = struct_type->can_mutate_state_through_it ||
field_type->can_mutate_state_through_it;

if (gen_field_index == 0) {
if (struct_type->data.structure.layout == ContainerLayoutPacked) {
struct_type->data.structure.abi_alignment = 1;
Expand Down Expand Up @@ -2662,6 +2671,8 @@ static void resolve_union_zero_bits(CodeGen *g, TypeTableEntry *union_type) {
}
}
union_field->type_entry = field_type;
union_type->can_mutate_state_through_it = union_type->can_mutate_state_through_it ||
field_type->can_mutate_state_through_it;

if (field_node->data.struct_field.value != nullptr && !decl_node->data.container_decl.auto_enum) {
ErrorMsg *msg = add_node_error(g, field_node->data.struct_field.value,
Expand Down Expand Up @@ -4565,6 +4576,23 @@ bool generic_fn_type_id_eql(GenericFnTypeId *a, GenericFnTypeId *b) {
return true;
}

// Decides whether the result of a comptime function call evaluated in
// `scope` may be stored in (and later served from) the memoization table.
//
// The scope chain is walked outward from `scope`. Every ScopeIdVarDecl entry
// (presumably one per argument bound for the call) is inspected: if the
// declared variable's type has can_mutate_state_through_it set (for example a
// mutable pointer), re-running the call could have observable side effects,
// so the call is not cacheable and false is returned. Reaching the enclosing
// ScopeIdFnDef without finding such a variable means the call is cacheable.
//
// Any other scope kind in the chain, or running out of parents before a
// ScopeIdFnDef is found, is treated as unreachable.
bool fn_eval_cacheable(Scope *scope) {
    for (Scope *cur = scope; cur != nullptr; cur = cur->parent) {
        switch (cur->id) {
            case ScopeIdVarDecl: {
                // ScopeVarDecl embeds its Scope header, hence the C-style cast.
                ScopeVarDecl *decl_scope = (ScopeVarDecl *)cur;
                if (decl_scope->var->value->type->can_mutate_state_through_it)
                    return false;
                break;
            }
            case ScopeIdFnDef:
                return true;
            default:
                zig_unreachable();
        }
    }
    zig_unreachable();
}

uint32_t fn_eval_hash(Scope* scope) {
uint32_t result = 0;
while (scope) {
Expand Down
1 change: 1 addition & 0 deletions src/analyze.hpp
Expand Up @@ -195,5 +195,6 @@ TypeTableEntry *get_auto_err_set_type(CodeGen *g, FnTableEntry *fn_entry);

uint32_t get_coro_frame_align_bytes(CodeGen *g);
bool fn_type_can_fail(FnTypeId *fn_type_id);
// Returns true when the result of a comptime function call evaluated in
// `scope` may be memoized: none of the variables declared between `scope` and
// the enclosing function-definition scope has a type that can mutate state
// through it (such as a mutable pointer). Returns false otherwise.
bool fn_eval_cacheable(Scope *scope);

#endif
17 changes: 11 additions & 6 deletions src/ir.cpp
Expand Up @@ -11830,12 +11830,15 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
return_type = specified_return_type;
}

IrInstruction *result;
bool cacheable = fn_eval_cacheable(exec_scope);
IrInstruction *result = nullptr;
if (cacheable) {
auto entry = ira->codegen->memoized_fn_eval_table.maybe_get(exec_scope);
if (entry)
result = entry->value;
}

auto entry = ira->codegen->memoized_fn_eval_table.maybe_get(exec_scope);
if (entry) {
result = entry->value;
} else {
if (result == nullptr) {
// Analyze the fn body block like any other constant expression.
AstNode *body_node = fn_entry->body_node;
result = ir_eval_const_value(ira->codegen, exec_scope, body_node, return_type,
Expand All @@ -11859,7 +11862,9 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
}
}

ira->codegen->memoized_fn_eval_table.put(exec_scope, result);
if (cacheable) {
ira->codegen->memoized_fn_eval_table.put(exec_scope, result);
}

if (type_is_invalid(result->value.type))
return ira->codegen->builtin_types.entry_invalid;
Expand Down
3 changes: 1 addition & 2 deletions std/sort.zig
Expand Up @@ -964,8 +964,7 @@ fn u8desc(lhs: &const u8, rhs: &const u8) bool {

test "stable sort" {
// Exercise the stable sort both at runtime and at comptime. The comptime run
// was previously disabled pending https://github.com/zig-lang/zig/issues/639.
testStableSort();
comptime testStableSort();
}
fn testStableSort() void {
var expected = []IdAndValue {
Expand Down
28 changes: 28 additions & 0 deletions test/cases/eval.zig
Expand Up @@ -420,3 +420,31 @@ test "binary math operator in partially inlined function" {
assert(s[2] == 0x90a0b0c);
assert(s[3] == 0xd0e0f10);
}


// Calling a comptime function twice with the same comptime argument must
// yield the same result (the call is memoized). A different argument must
// yield a different result.
test "comptime function with the same args is memoized" {
comptime {
assert(MakeType(i32) == MakeType(i32));
assert(MakeType(i32) != MakeType(f64));
}
}

// Returns a struct type with a single field `field` of type `T`.
// The memoization test compares the results of two calls with `==`.
// Presumably each evaluation of a `struct` expression creates a distinct
// type, so equal results indicate the second call was served from the memo
// table.
fn MakeType(comptime T: type) type {
return struct {
field: T,
};
}

// A comptime call that receives a mutable pointer must be re-evaluated every
// time. If the second `increment(ptr)` were served from the memo table, its
// side effect would be skipped and `x` would end up as 2 instead of 3.
test "comptime function with mutable pointer is not memoized" {
comptime {
var x: i32 = 1;
const ptr = &x;
increment(ptr);
increment(ptr);
assert(x == 3);
}
}

// Adds one to the integer that `value` points at. Because the argument is a
// mutable pointer, comptime calls to this function mutate state through their
// parameter and therefore must not be memoized.
fn increment(value: &i32) void {
*value = *value + 1;
}

0 comments on commit 6db9be8

Please sign in to comment.