improve cmpxchg
 * remove @cmpxchg, add @cmpxchgWeak and @cmpxchgStrong
   - See explanations in the langref.
 * add operand type as first parameter
 * return type is ?T where T is the operand type

closes #461
andrewrk committed Apr 18, 2018
1 parent 96ebd8b commit f1f998e
Showing 8 changed files with 148 additions and 63 deletions.
54 changes: 49 additions & 5 deletions doc/langref.html.in
@@ -4065,16 +4065,60 @@ comptime {
 </p>
 
 {#header_close#}
-{#header_open|@cmpxchg#}
-<pre><code class="zig">@cmpxchg(ptr: &T, cmp: T, new: T, success_order: AtomicOrder, fail_order: AtomicOrder) -&gt; bool</code></pre>
+{#header_open|@cmpxchgStrong#}
+<pre><code class="zig">@cmpxchgStrong(comptime T: type, ptr: &T, expected_value: T, new_value: T, success_order: AtomicOrder, fail_order: AtomicOrder) -&gt; ?T</code></pre>
 <p>
-This function performs an atomic compare exchange operation.
+This function performs a strong atomic compare exchange operation. It's the equivalent of this code,
+except atomic:
 </p>
+{#code_begin|syntax#}
+fn cmpxchgStrongButNotAtomic(comptime T: type, ptr: &T, expected_value: T, new_value: T) ?T {
+    const old_value = *ptr;
+    if (old_value == expected_value) {
+        *ptr = new_value;
+        return null;
+    } else {
+        return old_value;
+    }
+}
+{#code_end#}
+<p>
+If you are using cmpxchg in a loop, {#link|@cmpxchgWeak#} is the better choice, because it can be implemented
+more efficiently in machine instructions.
+</p>
 <p>
 <code>AtomicOrder</code> can be found with <code>@import("builtin").AtomicOrder</code>.
 </p>
 <p><code>@typeOf(ptr).alignment</code> must be <code>&gt;= @sizeOf(T)</code>.</p>
-{#see_also|Compile Variables#}
+{#see_also|Compile Variables|cmpxchgWeak#}
 {#header_close#}
+{#header_open|@cmpxchgWeak#}
+<pre><code class="zig">@cmpxchgWeak(comptime T: type, ptr: &T, expected_value: T, new_value: T, success_order: AtomicOrder, fail_order: AtomicOrder) -&gt; ?T</code></pre>
+<p>
+This function performs a weak atomic compare exchange operation. It's the equivalent of this code,
+except atomic:
+</p>
+{#code_begin|syntax#}
+fn cmpxchgWeakButNotAtomic(comptime T: type, ptr: &T, expected_value: T, new_value: T) ?T {
+    const old_value = *ptr;
+    if (old_value == expected_value and usuallyTrueButSometimesFalse()) {
+        *ptr = new_value;
+        return null;
+    } else {
+        return old_value;
+    }
+}
+{#code_end#}
+<p>
+If you are using cmpxchg in a loop, the sporadic failure is not a problem, and <code>cmpxchgWeak</code>
+is the better choice, because it can be implemented more efficiently in machine instructions.
+However, if you need a stronger guarantee, use {#link|@cmpxchgStrong#}.
+</p>
+<p>
+<code>AtomicOrder</code> can be found with <code>@import("builtin").AtomicOrder</code>.
+</p>
+<p><code>@typeOf(ptr).alignment</code> must be <code>&gt;= @sizeOf(T)</code>.</p>
+{#see_also|Compile Variables|cmpxchgStrong#}
+{#header_close#}
 {#header_open|@compileError#}
 <pre><code class="zig">@compileError(comptime msg: []u8)</code></pre>
@@ -6020,7 +6064,7 @@ hljs.registerLanguage("zig", function(t) {
 a = t.IR + "\\s*\\(",
 c = {
     keyword: "const align var extern stdcallcc nakedcc volatile export pub noalias inline struct packed enum union break return try catch test continue unreachable comptime and or asm defer errdefer if else switch while for fn use bool f32 f64 void type noreturn error i8 u8 i16 u16 i32 u32 i64 u64 isize usize i8w u8w i16w i32w u32w i64w u64w isizew usizew c_short c_ushort c_int c_uint c_long c_ulong c_longlong c_ulonglong",
-    built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchg fence divExact truncate atomicRmw sqrt",
+    built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage setGlobalSection divTrunc divFloor enumTagName intToPtr ptrToInt panic canImplicitCast ptrCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchgStrong cmpxchgWeak fence divExact truncate atomicRmw sqrt",
     literal: "true false null undefined"
 },
 n = [e, t.CLCM, t.CBCM, s, r];
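An editorial aside, not part of the commit: with the new signatures above, the canonical retry loop looks like the sketch below. It uses the era's pointer syntax and assumes @atomicLoad's (T, ptr, order) form for the initial read; the function name is illustrative only.

    const builtin = @import("builtin");
    const AtomicOrder = builtin.AtomicOrder;

    fn atomicIncrement(ptr: &i32) void {
        // Read the current value, then retry until the exchange succeeds.
        // @cmpxchgWeak returns null on success; on failure (including a
        // spurious one) it returns the observed value, which we retry with.
        var expected = @atomicLoad(i32, ptr, AtomicOrder.SeqCst);
        while (@cmpxchgWeak(i32, ptr, expected, expected + 1,
                AtomicOrder.SeqCst, AtomicOrder.SeqCst)) |actual|
        {
            expected = actual;
        }
    }

The weak variant is the right default in such a loop, since the loop already handles failure; @cmpxchgStrong only pays off when a single attempt must not fail spuriously.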
9 changes: 8 additions & 1 deletion src/all_types.hpp
@@ -1310,7 +1310,8 @@ enum BuiltinFnId {
     BuiltinFnIdReturnAddress,
     BuiltinFnIdFrameAddress,
     BuiltinFnIdEmbedFile,
-    BuiltinFnIdCmpExchange,
+    BuiltinFnIdCmpxchgWeak,
+    BuiltinFnIdCmpxchgStrong,
     BuiltinFnIdFence,
     BuiltinFnIdDivExact,
     BuiltinFnIdDivTrunc,
@@ -2528,15 +2529,21 @@ struct IrInstructionEmbedFile {
 struct IrInstructionCmpxchg {
     IrInstruction base;
 
+    IrInstruction *type_value;
     IrInstruction *ptr;
     IrInstruction *cmp_value;
     IrInstruction *new_value;
     IrInstruction *success_order_value;
     IrInstruction *failure_order_value;
 
     // if this instruction gets to runtime then we know these values:
+    TypeTableEntry *type;
     AtomicOrder success_order;
     AtomicOrder failure_order;
+
+    bool is_weak;
+
+    LLVMValueRef tmp_ptr;
 };
 
 struct IrInstructionFence {
21 changes: 18 additions & 3 deletions src/codegen.cpp
@@ -3558,9 +3558,20 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutable *executable, IrIn
     LLVMAtomicOrdering failure_order = to_LLVMAtomicOrdering(instruction->failure_order);
 
     LLVMValueRef result_val = ZigLLVMBuildCmpXchg(g->builder, ptr_val, cmp_val, new_val,
-        success_order, failure_order);
+        success_order, failure_order, instruction->is_weak);
 
-    return LLVMBuildExtractValue(g->builder, result_val, 1, "");
+    assert(instruction->tmp_ptr != nullptr);
+    assert(type_has_bits(instruction->type));
+
+    LLVMValueRef payload_val = LLVMBuildExtractValue(g->builder, result_val, 0, "");
+    LLVMValueRef val_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, maybe_child_index, "");
+    gen_assign_raw(g, val_ptr, get_pointer_to_type(g, instruction->type, false), payload_val);
+
+    LLVMValueRef success_bit = LLVMBuildExtractValue(g->builder, result_val, 1, "");
+    LLVMValueRef nonnull_bit = LLVMBuildNot(g->builder, success_bit, "");
+    LLVMValueRef maybe_ptr = LLVMBuildStructGEP(g->builder, instruction->tmp_ptr, maybe_null_index, "");
+    gen_store_untyped(g, nonnull_bit, maybe_ptr, 0, false);
+    return instruction->tmp_ptr;
 }
 
 static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutable *executable, IrInstructionFence *instruction) {
@@ -5588,6 +5599,9 @@ static void do_code_gen(CodeGen *g) {
         } else if (instruction->id == IrInstructionIdErrWrapCode) {
             IrInstructionErrWrapCode *err_wrap_code_instruction = (IrInstructionErrWrapCode *)instruction;
             slot = &err_wrap_code_instruction->tmp_ptr;
+        } else if (instruction->id == IrInstructionIdCmpxchg) {
+            IrInstructionCmpxchg *cmpxchg_instruction = (IrInstructionCmpxchg *)instruction;
+            slot = &cmpxchg_instruction->tmp_ptr;
         } else {
             zig_unreachable();
         }
@@ -6115,7 +6129,8 @@ static void define_builtin_fns(CodeGen *g) {
     create_builtin_fn(g, BuiltinFnIdTypeName, "typeName", 1);
     create_builtin_fn(g, BuiltinFnIdCanImplicitCast, "canImplicitCast", 2);
     create_builtin_fn(g, BuiltinFnIdEmbedFile, "embedFile", 1);
-    create_builtin_fn(g, BuiltinFnIdCmpExchange, "cmpxchg", 5);
+    create_builtin_fn(g, BuiltinFnIdCmpxchgWeak, "cmpxchgWeak", 6);
+    create_builtin_fn(g, BuiltinFnIdCmpxchgStrong, "cmpxchgStrong", 6);
     create_builtin_fn(g, BuiltinFnIdFence, "fence", 1);
     create_builtin_fn(g, BuiltinFnIdTruncate, "truncate", 2);
     create_builtin_fn(g, BuiltinFnIdCompileErr, "compileError", 1);
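An editorial note on the new lowering, not part of the commit text: ir_render_cmpxchg no longer returns the raw success bit. It extracts the old value (index 0) and the success bit (index 1) from the LLVM cmpxchg result, then stores the old value and the inverted success bit into the pre-allocated tmp_ptr slot, which is shaped like a nullable. Assuming maybe_child_index is 0 and maybe_null_index is 1, the slot looks roughly like this sketch:

    // Sketch of the ?T stack slot the two StructGEPs address; the index
    // values are an assumption, but the field roles come from the code above.
    const CmpxchgSlot = struct {
        payload: i32, // old value: extract-value index 0 of the cmpxchg result
        non_null: bool, // NOT of the success bit: a null ?T means success
    };

This is why do_code_gen now reserves a tmp_ptr for IrInstructionCmpxchg alongside the other instructions that need a stack slot.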
85 changes: 45 additions & 40 deletions src/ir.cpp
@@ -110,6 +110,7 @@ static IrInstruction *ir_analyze_container_field_ptr(IrAnalyze *ira, Buf *field_
     IrInstruction *source_instr, IrInstruction *container_ptr, TypeTableEntry *container_type);
 static IrInstruction *ir_get_var_ptr(IrAnalyze *ira, IrInstruction *instruction,
     VariableTableEntry *var, bool is_const_ptr, bool is_volatile_ptr);
+static TypeTableEntry *ir_resolve_atomic_operand_type(IrAnalyze *ira, IrInstruction *op);
 
 ConstExprValue *const_ptr_pointee(CodeGen *g, ConstExprValue *const_val) {
     assert(const_val->type->id == TypeTableEntryIdPointer);
@@ -1832,38 +1833,34 @@ static IrInstruction *ir_build_embed_file(IrBuilder *irb, Scope *scope, AstNode
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_cmpxchg(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *ptr,
-    IrInstruction *cmp_value, IrInstruction *new_value, IrInstruction *success_order_value, IrInstruction *failure_order_value,
-    AtomicOrder success_order, AtomicOrder failure_order)
+static IrInstruction *ir_build_cmpxchg(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *type_value,
+    IrInstruction *ptr, IrInstruction *cmp_value, IrInstruction *new_value,
+    IrInstruction *success_order_value, IrInstruction *failure_order_value,
+    bool is_weak,
+    TypeTableEntry *type, AtomicOrder success_order, AtomicOrder failure_order)
 {
     IrInstructionCmpxchg *instruction = ir_build_instruction<IrInstructionCmpxchg>(irb, scope, source_node);
+    instruction->type_value = type_value;
     instruction->ptr = ptr;
     instruction->cmp_value = cmp_value;
     instruction->new_value = new_value;
     instruction->success_order_value = success_order_value;
     instruction->failure_order_value = failure_order_value;
+    instruction->is_weak = is_weak;
+    instruction->type = type;
     instruction->success_order = success_order;
     instruction->failure_order = failure_order;
 
+    if (type_value != nullptr) ir_ref_instruction(type_value, irb->current_basic_block);
     ir_ref_instruction(ptr, irb->current_basic_block);
     ir_ref_instruction(cmp_value, irb->current_basic_block);
     ir_ref_instruction(new_value, irb->current_basic_block);
-    ir_ref_instruction(success_order_value, irb->current_basic_block);
-    ir_ref_instruction(failure_order_value, irb->current_basic_block);
+    if (type_value != nullptr) ir_ref_instruction(success_order_value, irb->current_basic_block);
+    if (type_value != nullptr) ir_ref_instruction(failure_order_value, irb->current_basic_block);
 
     return &instruction->base;
 }
 
-static IrInstruction *ir_build_cmpxchg_from(IrBuilder *irb, IrInstruction *old_instruction, IrInstruction *ptr,
-    IrInstruction *cmp_value, IrInstruction *new_value, IrInstruction *success_order_value, IrInstruction *failure_order_value,
-    AtomicOrder success_order, AtomicOrder failure_order)
-{
-    IrInstruction *new_instruction = ir_build_cmpxchg(irb, old_instruction->scope, old_instruction->source_node,
-        ptr, cmp_value, new_value, success_order_value, failure_order_value, success_order, failure_order);
-    ir_link_new_instruction(new_instruction, old_instruction);
-    return new_instruction;
-}
-
 static IrInstruction *ir_build_fence(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *order_value, AtomicOrder order) {
     IrInstructionFence *instruction = ir_build_instruction<IrInstructionFence>(irb, scope, source_node);
     instruction->order_value = order_value;
@@ -3771,7 +3768,8 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
 
                 return ir_build_embed_file(irb, scope, node, arg0_value);
             }
-        case BuiltinFnIdCmpExchange:
+        case BuiltinFnIdCmpxchgWeak:
+        case BuiltinFnIdCmpxchgStrong:
             {
                 AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
                 IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
@@ -3798,9 +3796,14 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
                 if (arg4_value == irb->codegen->invalid_instruction)
                     return arg4_value;
 
+                AstNode *arg5_node = node->data.fn_call_expr.params.at(5);
+                IrInstruction *arg5_value = ir_gen_node(irb, arg5_node, scope);
+                if (arg5_value == irb->codegen->invalid_instruction)
+                    return arg5_value;
+
                 return ir_build_cmpxchg(irb, scope, node, arg0_value, arg1_value,
-                    arg2_value, arg3_value, arg4_value,
-                    AtomicOrderUnordered, AtomicOrderUnordered);
+                    arg2_value, arg3_value, arg4_value, arg5_value, (builtin_fn->id == BuiltinFnIdCmpxchgWeak),
+                    nullptr, AtomicOrderUnordered, AtomicOrderUnordered);
             }
         case BuiltinFnIdFence:
             {
@@ -15730,10 +15733,20 @@ static TypeTableEntry *ir_analyze_instruction_embed_file(IrAnalyze *ira, IrInstr
 }
 
 static TypeTableEntry *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstructionCmpxchg *instruction) {
+    TypeTableEntry *operand_type = ir_resolve_atomic_operand_type(ira, instruction->type_value->other);
+    if (type_is_invalid(operand_type))
+        return ira->codegen->builtin_types.entry_invalid;
+
     IrInstruction *ptr = instruction->ptr->other;
     if (type_is_invalid(ptr->value.type))
         return ira->codegen->builtin_types.entry_invalid;
 
+    // TODO let this be volatile
+    TypeTableEntry *ptr_type = get_pointer_to_type(ira->codegen, operand_type, false);
+    IrInstruction *casted_ptr = ir_implicit_cast(ira, ptr, ptr_type);
+    if (type_is_invalid(casted_ptr->value.type))
+        return ira->codegen->builtin_types.entry_invalid;
+
     IrInstruction *cmp_value = instruction->cmp_value->other;
     if (type_is_invalid(cmp_value->value.type))
         return ira->codegen->builtin_types.entry_invalid;
@@ -15758,28 +15771,11 @@ static TypeTableEntry *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstruct
     if (!ir_resolve_atomic_order(ira, failure_order_value, &failure_order))
         return ira->codegen->builtin_types.entry_invalid;
 
-    if (ptr->value.type->id != TypeTableEntryIdPointer) {
-        ir_add_error(ira, instruction->ptr,
-            buf_sprintf("expected pointer argument, found '%s'", buf_ptr(&ptr->value.type->name)));
-        return ira->codegen->builtin_types.entry_invalid;
-    }
-
-    TypeTableEntry *child_type = ptr->value.type->data.pointer.child_type;
-
-    uint32_t align_bytes = ptr->value.type->data.pointer.alignment;
-    uint64_t size_bytes = type_size(ira->codegen, child_type);
-    if (align_bytes < size_bytes) {
-        ir_add_error(ira, instruction->ptr,
-            buf_sprintf("expected pointer alignment of at least %" ZIG_PRI_u64 ", found %" PRIu32,
-                size_bytes, align_bytes));
-        return ira->codegen->builtin_types.entry_invalid;
-    }
-
-    IrInstruction *casted_cmp_value = ir_implicit_cast(ira, cmp_value, child_type);
+    IrInstruction *casted_cmp_value = ir_implicit_cast(ira, cmp_value, operand_type);
     if (type_is_invalid(casted_cmp_value->value.type))
         return ira->codegen->builtin_types.entry_invalid;
 
-    IrInstruction *casted_new_value = ir_implicit_cast(ira, new_value, child_type);
+    IrInstruction *casted_new_value = ir_implicit_cast(ira, new_value, operand_type);
     if (type_is_invalid(casted_new_value->value.type))
         return ira->codegen->builtin_types.entry_invalid;
 
@@ -15804,9 +15800,17 @@ static TypeTableEntry *ir_analyze_instruction_cmpxchg(IrAnalyze *ira, IrInstruct
         return ira->codegen->builtin_types.entry_invalid;
     }
 
-    ir_build_cmpxchg_from(&ira->new_irb, &instruction->base, ptr, casted_cmp_value, casted_new_value,
-        success_order_value, failure_order_value, success_order, failure_order);
-    return ira->codegen->builtin_types.entry_bool;
+    if (instr_is_comptime(casted_ptr) && instr_is_comptime(casted_cmp_value) && instr_is_comptime(casted_new_value)) {
+        zig_panic("TODO compile-time execution of cmpxchg");
+    }
+
+    IrInstruction *result = ir_build_cmpxchg(&ira->new_irb, instruction->base.scope, instruction->base.source_node,
+        nullptr, casted_ptr, casted_cmp_value, casted_new_value, nullptr, nullptr, instruction->is_weak,
+        operand_type, success_order, failure_order);
+    result->value.type = get_maybe_type(ira->codegen, operand_type);
+    ir_link_new_instruction(result, &instruction->base);
+    ir_add_alloca(ira, result, result->value.type);
+    return result->value.type;
 }
 
 static TypeTableEntry *ir_analyze_instruction_fence(IrAnalyze *ira, IrInstructionFence *instruction) {
@@ -17981,6 +17985,7 @@ static TypeTableEntry *ir_analyze_instruction_atomic_rmw(IrAnalyze *ira, IrInstr
     if (type_is_invalid(ptr_inst->value.type))
         return ira->codegen->builtin_types.entry_invalid;
 
+    // TODO let this be volatile
     TypeTableEntry *ptr_type = get_pointer_to_type(ira->codegen, operand_type, false);
     IrInstruction *casted_ptr = ir_implicit_cast(ira, ptr_inst, ptr_type);
     if (type_is_invalid(casted_ptr->value.type))
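Editorial sketch, mirroring the style of the updated test below rather than code from the commit: because the analysis now resolves the operand type from the first argument and implicitly casts the pointer and both values to it, integer literals coerce directly, and the builtin's result is the optional of the operand type.

    const std = @import("std");
    const assert = std.debug.assert;
    const builtin = @import("builtin");
    const AtomicOrder = builtin.AtomicOrder;

    test "cmpxchg operands coerce to the explicit operand type" {
        var x: u8 = 100;
        // 100 and 42 implicitly cast to u8; the result is ?u8:
        // null on success, the observed old value on failure.
        assert(@cmpxchgStrong(u8, &x, 100, 42, AtomicOrder.SeqCst, AtomicOrder.SeqCst) == null);
        assert(x == 42);
    }

Note also the zig_panic above: fully comptime operands are explicitly not handled yet, so a sketch like this one must run at runtime.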
8 changes: 5 additions & 3 deletions src/zig_llvm.cpp
@@ -765,10 +765,12 @@ static AtomicOrdering mapFromLLVMOrdering(LLVMAtomicOrdering Ordering) {
 
 LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp,
         LLVMValueRef new_val, LLVMAtomicOrdering success_ordering,
-        LLVMAtomicOrdering failure_ordering)
+        LLVMAtomicOrdering failure_ordering, bool is_weak)
 {
-    return wrap(unwrap(builder)->CreateAtomicCmpXchg(unwrap(ptr), unwrap(cmp), unwrap(new_val),
-        mapFromLLVMOrdering(success_ordering), mapFromLLVMOrdering(failure_ordering)));
+    AtomicCmpXchgInst *inst = unwrap(builder)->CreateAtomicCmpXchg(unwrap(ptr), unwrap(cmp),
+        unwrap(new_val), mapFromLLVMOrdering(success_ordering), mapFromLLVMOrdering(failure_ordering));
+    inst->setWeak(is_weak);
+    return wrap(inst);
 }
 
 LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS,
2 changes: 1 addition & 1 deletion src/zig_llvm.h
@@ -66,7 +66,7 @@ ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildCall(LLVMBuilderRef B, LLVMValueRef Fn, LL
 
 ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildCmpXchg(LLVMBuilderRef builder, LLVMValueRef ptr, LLVMValueRef cmp,
         LLVMValueRef new_val, LLVMAtomicOrdering success_ordering,
-        LLVMAtomicOrdering failure_ordering);
+        LLVMAtomicOrdering failure_ordering, bool is_weak);
 
 ZIG_EXTERN_C LLVMValueRef ZigLLVMBuildNSWShl(LLVMBuilderRef builder, LLVMValueRef LHS, LLVMValueRef RHS,
         const char *name);
16 changes: 14 additions & 2 deletions test/cases/atomics.zig
@@ -1,12 +1,24 @@
-const assert = @import("std").debug.assert;
+const std = @import("std");
+const assert = std.debug.assert;
 const builtin = @import("builtin");
 const AtomicRmwOp = builtin.AtomicRmwOp;
 const AtomicOrder = builtin.AtomicOrder;
 
 test "cmpxchg" {
     var x: i32 = 1234;
-    while (!@cmpxchg(&x, 1234, 5678, AtomicOrder.SeqCst, AtomicOrder.SeqCst)) {}
+    if (@cmpxchgWeak(i32, &x, 99, 5678, AtomicOrder.SeqCst, AtomicOrder.SeqCst)) |x1| {
+        assert(x1 == 1234);
+    } else {
+        @panic("cmpxchg should have failed");
+    }
+
+    while (@cmpxchgWeak(i32, &x, 1234, 5678, AtomicOrder.SeqCst, AtomicOrder.SeqCst)) |x1| {
+        assert(x1 == 1234);
+    }
     assert(x == 5678);
+
+    assert(@cmpxchgStrong(i32, &x, 5678, 42, AtomicOrder.SeqCst, AtomicOrder.SeqCst) == null);
+    assert(x == 42);
 }
 
 test "fence" {
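For callers migrating off the removed builtin, the mapping is mechanical. A hedged before/after sketch (the old call survives only as a comment since it no longer compiles; imports as in the test file above):

    test "migrated cmpxchg loop" {
        var x: i32 = 1234;
        // before: while (!@cmpxchg(&x, 1234, 5678, AtomicOrder.SeqCst, AtomicOrder.SeqCst)) {}
        // after: a non-null result means the exchange failed, so keep looping.
        while (@cmpxchgWeak(i32, &x, 1234, 5678, AtomicOrder.SeqCst, AtomicOrder.SeqCst) != null) {}
        assert(x == 5678);
    }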
