add @newStackCall builtin function
See #1006
andrewrk committed May 12, 2018
1 parent 4277762 commit a6ae451
Showing 8 changed files with 298 additions and 16 deletions.
47 changes: 44 additions & 3 deletions doc/langref.html.in
@@ -4485,17 +4485,58 @@ mem.set(u8, dest, c);</code></pre>
If no overflow or underflow occurs, returns <code>false</code>.
</p>
{#header_close#}
{#header_open|@newStackCall#}
<pre><code class="zig">@newStackCall(new_stack: []u8, function: var, args: ...) -&gt; var</code></pre>
<p>
This calls a function, in the same way that invoking an expression with parentheses does. However,
instead of using the same stack as the caller, the function uses the stack provided in the <code>new_stack</code>
parameter.
</p>
{#code_begin|test#}
const std = @import("std");
const assert = std.debug.assert;

var new_stack_bytes: [1024]u8 = undefined;

test "calling a function with a new stack" {
const arg = 1234;

const a = @newStackCall(new_stack_bytes[0..512], targetFunction, arg);
const b = @newStackCall(new_stack_bytes[512..], targetFunction, arg);
_ = targetFunction(arg);

assert(arg == 1234);
assert(a < b);
}

fn targetFunction(x: i32) usize {
assert(x == 1234);

var local_variable: i32 = 42;
const ptr = &local_variable;
*ptr += 1;

assert(local_variable == 43);
return @ptrToInt(ptr);
}
{#code_end#}
{#header_close#}
{#header_open|@noInlineCall#}
<pre><code class="zig">@noInlineCall(function: var, args: ...) -&gt; var</code></pre>
<p>
This calls a function, in the same way that invoking an expression with parentheses does:
</p>
{#code_begin|test#}
const assert = @import("std").debug.assert;

test "noinline function call" {
assert(@noInlineCall(add, 3, 9) == 12);
}

fn add(a: i32, b: i32) i32 {
return a + b;
}
{#code_end#}
<p>
Unlike a normal function call, however, <code>@noInlineCall</code> guarantees that the call
will not be inlined. If the call must be inlined, a compile error is emitted.
@@ -6451,7 +6492,7 @@ hljs.registerLanguage("zig", function(t) {
a = t.IR + "\\s*\\(",
c = {
keyword: "const align var extern stdcallcc nakedcc volatile export pub noalias inline struct packed enum union break return try catch test continue unreachable comptime and or asm defer errdefer if else switch while for fn use bool f32 f64 void type noreturn error i8 u8 i16 u16 i32 u32 i64 u64 isize usize i8w u8w i16w i32w u32w i64w u64w isizew usizew c_short c_ushort c_int c_uint c_long c_ulong c_longlong c_ulonglong",
built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchgStrong cmpxchgWeak fence divExact truncate atomicRmw sqrt field typeInfo",
built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchgStrong cmpxchgWeak fence divExact truncate atomicRmw sqrt field typeInfo newStackCall",
literal: "true false null undefined"
},
n = [e, t.CLCM, t.CBCM, s, r];
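For orientation, here is a minimal usage sketch in the 2018-era Zig syntax this commit documents. It is hypothetical, not part of the commit: the buffer size and the `sum` function are illustrative. A recursive call runs with its frames placed in a dedicated global buffer instead of the caller's stack.

const std = @import("std");
const assert = std.debug.assert;

// a generously sized buffer so the recursion below fits comfortably
var deep_stack: [64 * 1024]u8 = undefined;

// plain recursion; each frame is allocated on whichever stack is active
fn sum(n: u32) u32 {
    if (n == 0) return 0;
    return n + sum(n - 1);
}

test "recursion on a dedicated stack" {
    // the frames of sum() live inside deep_stack, not the test runner's stack
    const result = @newStackCall(deep_stack[0..], sum, 100);
    assert(result == 5050);
}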
7 changes: 7 additions & 0 deletions src/all_types.hpp
@@ -1340,6 +1340,7 @@ enum BuiltinFnId {
BuiltinFnIdOffsetOf,
BuiltinFnIdInlineCall,
BuiltinFnIdNoInlineCall,
BuiltinFnIdNewStackCall,
BuiltinFnIdTypeId,
BuiltinFnIdShlExact,
BuiltinFnIdShrExact,
@@ -1656,8 +1657,13 @@ struct CodeGen {
LLVMValueRef coro_alloc_helper_fn_val;
LLVMValueRef merge_err_ret_traces_fn_val;
LLVMValueRef add_error_return_trace_addr_fn_val;
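// lazily created declarations of the LLVM intrinsics used by @newStackCall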
LLVMValueRef stacksave_fn_val;
LLVMValueRef stackrestore_fn_val;
LLVMValueRef write_register_fn_val;
bool error_during_imports;

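// cached metadata node naming the target's stack pointer register, for llvm.write_register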
LLVMValueRef sp_md_node;

const char **clang_argv;
size_t clang_argv_len;
ZigList<const char *> lib_dirs;
@@ -2280,6 +2286,7 @@ struct IrInstructionCall {
bool is_async;

IrInstruction *async_allocator;
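// set by @newStackCall: the slice of bytes to use as the callee's stack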
IrInstruction *new_stack;
};

struct IrInstructionConst {
99 changes: 97 additions & 2 deletions src/codegen.cpp
@@ -938,6 +938,53 @@ static LLVMValueRef get_memcpy_fn_val(CodeGen *g) {
return g->memcpy_fn_val;
}

static LLVMValueRef get_stacksave_fn_val(CodeGen *g) {
if (g->stacksave_fn_val)
return g->stacksave_fn_val;

// declare i8* @llvm.stacksave()

LLVMTypeRef fn_type = LLVMFunctionType(LLVMPointerType(LLVMInt8Type(), 0), nullptr, 0, false);
g->stacksave_fn_val = LLVMAddFunction(g->module, "llvm.stacksave", fn_type);
assert(LLVMGetIntrinsicID(g->stacksave_fn_val));

return g->stacksave_fn_val;
}

static LLVMValueRef get_stackrestore_fn_val(CodeGen *g) {
if (g->stackrestore_fn_val)
return g->stackrestore_fn_val;

// declare void @llvm.stackrestore(i8* %ptr)

LLVMTypeRef param_type = LLVMPointerType(LLVMInt8Type(), 0);
LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), &param_type, 1, false);
g->stackrestore_fn_val = LLVMAddFunction(g->module, "llvm.stackrestore", fn_type);
assert(LLVMGetIntrinsicID(g->stackrestore_fn_val));

return g->stackrestore_fn_val;
}

static LLVMValueRef get_write_register_fn_val(CodeGen *g) {
if (g->write_register_fn_val)
return g->write_register_fn_val;

// declare void @llvm.write_register.i64(metadata, i64 @value)
// !0 = !{!"sp\00"}

LLVMTypeRef param_types[] = {
LLVMMetadataTypeInContext(LLVMGetGlobalContext()),
LLVMIntType(g->pointer_size_bytes * 8),
};

LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 2, false);
Buf *name = buf_sprintf("llvm.write_register.i%d", g->pointer_size_bytes * 8);
g->write_register_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
assert(LLVMGetIntrinsicID(g->write_register_fn_val));

return g->write_register_fn_val;
}

static LLVMValueRef get_coro_destroy_fn_val(CodeGen *g) {
if (g->coro_destroy_fn_val)
return g->coro_destroy_fn_val;
@@ -2901,6 +2948,38 @@ static size_t get_async_err_code_arg_index(CodeGen *g, FnTypeId *fn_type_id) {
return 1 + get_async_allocator_arg_index(g, fn_type_id);
}


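// Compute the initial stack pointer for a @newStackCall: the address just past
// the end of the new_stack slice, rounded down to the ABI alignment of usize,
// since stacks grow downward on the supported targets.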
static LLVMValueRef get_new_stack_addr(CodeGen *g, LLVMValueRef new_stack) {
LLVMValueRef ptr_field_ptr = LLVMBuildStructGEP(g->builder, new_stack, (unsigned)slice_ptr_index, "");
LLVMValueRef len_field_ptr = LLVMBuildStructGEP(g->builder, new_stack, (unsigned)slice_len_index, "");

LLVMValueRef ptr_value = gen_load_untyped(g, ptr_field_ptr, 0, false, "");
LLVMValueRef len_value = gen_load_untyped(g, len_field_ptr, 0, false, "");

LLVMValueRef ptr_addr = LLVMBuildPtrToInt(g->builder, ptr_value, LLVMTypeOf(len_value), "");
LLVMValueRef end_addr = LLVMBuildNUWAdd(g->builder, ptr_addr, len_value, "");
LLVMValueRef align_amt = LLVMConstInt(LLVMTypeOf(end_addr), get_abi_alignment(g, g->builtin_types.entry_usize), false);
LLVMValueRef align_adj = LLVMBuildURem(g->builder, end_addr, align_amt, "");
return LLVMBuildNUWSub(g->builder, end_addr, align_adj, "");
}

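// Overwrite the stack pointer register via the llvm.write_register intrinsic,
// identifying the register with a cached metadata string node (e.g. "sp").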
static void gen_set_stack_pointer(CodeGen *g, LLVMValueRef aligned_end_addr) {
LLVMValueRef write_register_fn_val = get_write_register_fn_val(g);

if (g->sp_md_node == nullptr) {
Buf *sp_reg_name = buf_create_from_str(arch_stack_pointer_register_name(&g->zig_target.arch));
LLVMValueRef str_node = LLVMMDString(buf_ptr(sp_reg_name), buf_len(sp_reg_name) + 1);
g->sp_md_node = LLVMMDNode(&str_node, 1);
}

LLVMValueRef params[] = {
g->sp_md_node,
aligned_end_addr,
};

LLVMBuildCall(g->builder, write_register_fn_val, params, 2, "");
}

static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstructionCall *instruction) {
LLVMValueRef fn_val;
TypeTableEntry *fn_type;
@@ -2967,8 +3046,23 @@ static LLVMValueRef ir_render_call(CodeGen *g, IrExecutable *executable, IrInstr
}

LLVMCallConv llvm_cc = get_llvm_cc(g, fn_type->data.fn.fn_type_id.cc);
LLVMValueRef result;

if (instruction->new_stack == nullptr) {
result = ZigLLVMBuildCall(g->builder, fn_val,
gen_param_values, (unsigned)gen_param_index, llvm_cc, fn_inline, "");
} else {
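// Calling on a new stack: save the current stack pointer, redirect it into
// the caller-provided buffer, emit the call, then restore the saved pointer.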
LLVMValueRef stacksave_fn_val = get_stacksave_fn_val(g);
LLVMValueRef stackrestore_fn_val = get_stackrestore_fn_val(g);

LLVMValueRef new_stack_addr = get_new_stack_addr(g, ir_llvm_value(g, instruction->new_stack));
LLVMValueRef old_stack_ref = LLVMBuildCall(g->builder, stacksave_fn_val, nullptr, 0, "");
gen_set_stack_pointer(g, new_stack_addr);
result = ZigLLVMBuildCall(g->builder, fn_val,
gen_param_values, (unsigned)gen_param_index, llvm_cc, fn_inline, "");
LLVMBuildCall(g->builder, stackrestore_fn_val, &old_stack_ref, 1, "");
}


for (size_t param_i = 0; param_i < fn_type_id->param_count; param_i += 1) {
FnGenParamInfo *gen_info = &fn_type->data.fn.gen_param_info[param_i];
@@ -6171,6 +6265,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdSqrt, "sqrt", 2);
create_builtin_fn(g, BuiltinFnIdInlineCall, "inlineCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdNoInlineCall, "noInlineCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdNewStackCall, "newStackCall", SIZE_MAX);
create_builtin_fn(g, BuiltinFnIdTypeId, "typeId", 1);
create_builtin_fn(g, BuiltinFnIdShlExact, "shlExact", 2);
create_builtin_fn(g, BuiltinFnIdShrExact, "shrExact", 2);
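To see the effect of the codegen above from the language side, here is a hypothetical test, not part of this commit, written in the 2018-era syntax. It checks that a callee local lands inside the supplied buffer, which is what the stack-pointer redirection guarantees; it assumes a target whose stack grows downward (as on x86-64).

const std = @import("std");
const assert = std.debug.assert;

var buf: [4096]u8 = undefined;

fn localAddr() usize {
    // a local variable whose address we escape, pinning it to the active stack
    var local: i32 = 1;
    return @ptrToInt(&local);
}

test "callee locals land inside the supplied buffer" {
    const addr = @newStackCall(buf[0..], localAddr);
    const base = @ptrToInt(&buf[0]);
    assert(addr >= base and addr < base + buf.len);
}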
66 changes: 57 additions & 9 deletions src/ir.cpp
@@ -1102,7 +1102,8 @@ static IrInstruction *ir_build_union_field_ptr_from(IrBuilder *irb, IrInstructio

static IrInstruction *ir_build_call(IrBuilder *irb, Scope *scope, AstNode *source_node,
FnTableEntry *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args,
bool is_comptime, FnInline fn_inline, bool is_async, IrInstruction *async_allocator,
IrInstruction *new_stack)
{
IrInstructionCall *call_instruction = ir_build_instruction<IrInstructionCall>(irb, scope, source_node);
call_instruction->fn_entry = fn_entry;
@@ -1113,23 +1114,27 @@ static IrInstruction *ir_build_call(IrBuilder *irb, Scope *scope, AstNode *sourc
call_instruction->arg_count = arg_count;
call_instruction->is_async = is_async;
call_instruction->async_allocator = async_allocator;
call_instruction->new_stack = new_stack;

if (fn_ref)
ir_ref_instruction(fn_ref, irb->current_basic_block);
for (size_t i = 0; i < arg_count; i += 1)
ir_ref_instruction(args[i], irb->current_basic_block);
if (async_allocator)
ir_ref_instruction(async_allocator, irb->current_basic_block);
if (new_stack != nullptr)
ir_ref_instruction(new_stack, irb->current_basic_block);

return &call_instruction->base;
}

static IrInstruction *ir_build_call_from(IrBuilder *irb, IrInstruction *old_instruction,
FnTableEntry *fn_entry, IrInstruction *fn_ref, size_t arg_count, IrInstruction **args,
bool is_comptime, FnInline fn_inline, bool is_async, IrInstruction *async_allocator,
IrInstruction *new_stack)
{
IrInstruction *new_instruction = ir_build_call(irb, old_instruction->scope,
old_instruction->source_node, fn_entry, fn_ref, arg_count, args, is_comptime, fn_inline, is_async, async_allocator, new_stack);
ir_link_new_instruction(new_instruction, old_instruction);
return new_instruction;
}
@@ -4303,7 +4308,37 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
}
FnInline fn_inline = (builtin_fn->id == BuiltinFnIdInlineCall) ? FnInlineAlways : FnInlineNever;

IrInstruction *call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, fn_inline, false, nullptr, nullptr);
return ir_lval_wrap(irb, scope, call, lval);
}
case BuiltinFnIdNewStackCall:
{
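// @newStackCall(new_stack, function, args...): operand 0 is the stack
// slice, operand 1 is the callee, and the rest are forwarded arguments.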
if (node->data.fn_call_expr.params.length < 2) {
add_node_error(irb->codegen, node, buf_sprintf("expected at least 2 arguments, found %" ZIG_PRI_usize,
node->data.fn_call_expr.params.length));
return irb->codegen->invalid_instruction;
}

AstNode *new_stack_node = node->data.fn_call_expr.params.at(0);
IrInstruction *new_stack = ir_gen_node(irb, new_stack_node, scope);
if (new_stack == irb->codegen->invalid_instruction)
return new_stack;

AstNode *fn_ref_node = node->data.fn_call_expr.params.at(1);
IrInstruction *fn_ref = ir_gen_node(irb, fn_ref_node, scope);
if (fn_ref == irb->codegen->invalid_instruction)
return fn_ref;

size_t arg_count = node->data.fn_call_expr.params.length - 2;

IrInstruction **args = allocate<IrInstruction*>(arg_count);
for (size_t i = 0; i < arg_count; i += 1) {
AstNode *arg_node = node->data.fn_call_expr.params.at(i + 2);
args[i] = ir_gen_node(irb, arg_node, scope);
if (args[i] == irb->codegen->invalid_instruction)
return args[i];
}

IrInstruction *call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, FnInlineAuto, false, nullptr, new_stack);
return ir_lval_wrap(irb, scope, call, lval);
}
case BuiltinFnIdTypeId:
@@ -4513,7 +4548,7 @@ static IrInstruction *ir_gen_fn_call(IrBuilder *irb, Scope *scope, AstNode *node
}
}

IrInstruction *fn_call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, FnInlineAuto, is_async, async_allocator);
IrInstruction *fn_call = ir_build_call(irb, scope, node, nullptr, fn_ref, arg_count, args, false, FnInlineAuto, is_async, async_allocator, nullptr);
return ir_lval_wrap(irb, scope, fn_call, lval);
}

@@ -6825,7 +6860,7 @@ bool ir_gen(CodeGen *codegen, AstNode *node, Scope *scope, IrExecutable *ir_exec
IrInstruction **args = allocate<IrInstruction *>(arg_count);
args[0] = implicit_allocator_ptr; // self
args[1] = mem_slice; // old_mem
ir_build_call(irb, scope, node, nullptr, free_fn, arg_count, args, false, FnInlineAuto, false, nullptr);
ir_build_call(irb, scope, node, nullptr, free_fn, arg_count, args, false, FnInlineAuto, false, nullptr, nullptr);

IrBasicBlock *resume_block = ir_create_basic_block(irb, scope, "Resume");
ir_build_cond_br(irb, scope, node, resume_awaiter, resume_block, irb->exec->coro_suspend_block, const_bool_false);
@@ -11992,7 +12027,7 @@ static IrInstruction *ir_analyze_async_call(IrAnalyze *ira, IrInstructionCall *c
TypeTableEntry *async_return_type = get_error_union_type(ira->codegen, alloc_fn_error_set_type, promise_type);

IrInstruction *result = ir_build_call(&ira->new_irb, call_instruction->base.scope, call_instruction->base.source_node,
fn_entry, fn_ref, arg_count, casted_args, false, FnInlineAuto, true, async_allocator_inst, nullptr);
result->value.type = async_return_type;
return result;
}
@@ -12362,6 +12397,19 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
return ir_finish_anal(ira, return_type);
}

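// For @newStackCall, the new_stack operand must implicitly cast to []u8.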
IrInstruction *casted_new_stack = nullptr;
if (call_instruction->new_stack != nullptr) {
TypeTableEntry *u8_ptr = get_pointer_to_type(ira->codegen, ira->codegen->builtin_types.entry_u8, false);
TypeTableEntry *u8_slice = get_slice_type(ira->codegen, u8_ptr);
IrInstruction *new_stack = call_instruction->new_stack->other;
if (type_is_invalid(new_stack->value.type))
return ira->codegen->builtin_types.entry_invalid;

casted_new_stack = ir_implicit_cast(ira, new_stack, u8_slice);
if (type_is_invalid(casted_new_stack->value.type))
return ira->codegen->builtin_types.entry_invalid;
}

if (fn_type->data.fn.is_generic) {
if (!fn_entry) {
ir_add_error(ira, call_instruction->fn_ref,
@@ -12588,7 +12636,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal
assert(async_allocator_inst == nullptr);
IrInstruction *new_call_instruction = ir_build_call_from(&ira->new_irb, &call_instruction->base,
impl_fn, nullptr, impl_param_count, casted_args, false, fn_inline,
call_instruction->is_async, nullptr, casted_new_stack);

ir_add_alloca(ira, new_call_instruction, return_type);

@@ -12679,7 +12727,7 @@ static TypeTableEntry *ir_analyze_fn_call(IrAnalyze *ira, IrInstructionCall *cal


IrInstruction *new_call_instruction = ir_build_call_from(&ira->new_irb, &call_instruction->base,
fn_entry, fn_ref, call_param_count, casted_args, false, fn_inline, false, nullptr, casted_new_stack);

ir_add_alloca(ira, new_call_instruction, return_type);
return ir_finish_anal(ira, return_type);
