Skip to content

Commit

Permalink
add @popcount intrinsic
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewrk committed Jul 7, 2018
1 parent e19f0b5 commit d8295c1
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 6 deletions.
15 changes: 13 additions & 2 deletions doc/langref.html.in
Expand Up @@ -5013,7 +5013,7 @@ comptime {
<p>
If <code>x</code> is zero, <code>@clz</code> returns <code>T.bit_count</code>.
</p>

{#see_also|@ctz|@popCount#}
{#header_close#}
{#header_open|@cmpxchgStrong#}
<pre><code class="zig">@cmpxchgStrong(comptime T: type, ptr: *T, expected_value: T, new_value: T, success_order: AtomicOrder, fail_order: AtomicOrder) ?T</code></pre>
Expand Down Expand Up @@ -5149,6 +5149,7 @@ test "main" {
<p>
If <code>x</code> is zero, <code>@ctz</code> returns <code>T.bit_count</code>.
</p>
{#see_also|@clz|@popCount#}
{#header_close#}
{#header_open|@divExact#}
<pre><code class="zig">@divExact(numerator: T, denominator: T) T</code></pre>
Expand Down Expand Up @@ -5631,6 +5632,16 @@ test "call foo" {
</ul>
{#see_also|Root Source File#}
{#header_close#}
{#header_open|@popCount#}
<pre><code class="zig">@popCount(integer: var) var</code></pre>
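<p>Counts the number of bits set in an integer. Signed integers are counted in their two's complement representation; passing a negative <code>comptime_int</code> is a compile error.</p>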
<p>
If <code>integer</code> is known at {#link|comptime#}, the return type is <code>comptime_int</code>.
Otherwise, the return type is an unsigned integer with the minimum number
of bits that can represent the bit count of the integer type.
</p>
{#see_also|@ctz|@clz#}
{#header_close#}
{#header_open|@ptrCast#}
<pre><code class="zig">@ptrCast(comptime DestType: type, value: var) DestType</code></pre>
<p>
Expand Down Expand Up @@ -7337,7 +7348,7 @@ hljs.registerLanguage("zig", function(t) {
a = t.IR + "\\s*\\(",
c = {
keyword: "const align var extern stdcallcc nakedcc volatile export pub noalias inline struct packed enum union break return try catch test continue unreachable comptime and or asm defer errdefer if else switch while for fn use bool f32 f64 void type noreturn error i8 u8 i16 u16 i32 u32 i64 u64 isize usize i8w u8w i16w i32w u32w i64w u64w isizew usizew c_short c_ushort c_int c_uint c_long c_ulong c_longlong c_ulonglong resume cancel await async orelse",
built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage divTrunc divFloor enumTagName intToPtr ptrToInt panic ptrCast intCast floatCast intToFloat floatToInt boolToInt bytesToSlice sliceToBytes errSetCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz import cImport errorName embedFile cmpxchgStrong cmpxchgWeak fence divExact truncate atomicRmw sqrt field typeInfo typeName newStackCall errorToInt intToError enumToInt intToEnum",
built_in: "atomicLoad breakpoint returnAddress frameAddress fieldParentPtr setFloatMode IntType OpaqueType compileError compileLog setCold setRuntimeSafety setEvalBranchQuota offsetOf memcpy inlineCall setGlobalLinkage divTrunc divFloor enumTagName intToPtr ptrToInt panic ptrCast intCast floatCast intToFloat floatToInt boolToInt bytesToSlice sliceToBytes errSetCast bitCast rem mod memset sizeOf alignOf alignCast maxValue minValue memberCount memberName memberType typeOf addWithOverflow subWithOverflow mulWithOverflow shlWithOverflow shlExact shrExact cInclude cDefine cUndef ctz clz popCount import cImport errorName embedFile cmpxchgStrong cmpxchgWeak fence divExact truncate atomicRmw sqrt field typeInfo typeName newStackCall errorToInt intToError enumToInt intToEnum",
literal: "true false null undefined"
},
n = [e, t.CLCM, t.CBCM, s, r];
Expand Down
12 changes: 12 additions & 0 deletions src/all_types.hpp
Expand Up @@ -1352,6 +1352,7 @@ enum BuiltinFnId {
BuiltinFnIdCompileLog,
BuiltinFnIdCtz,
BuiltinFnIdClz,
BuiltinFnIdPopCount,
BuiltinFnIdImport,
BuiltinFnIdCImport,
BuiltinFnIdErrName,
Expand Down Expand Up @@ -1477,6 +1478,7 @@ bool type_id_eql(TypeId a, TypeId b);
enum ZigLLVMFnId {
ZigLLVMFnIdCtz,
ZigLLVMFnIdClz,
ZigLLVMFnIdPopCount,
ZigLLVMFnIdOverflowArithmetic,
ZigLLVMFnIdFloor,
ZigLLVMFnIdCeil,
Expand All @@ -1499,6 +1501,9 @@ struct ZigLLVMFnKey {
struct {
uint32_t bit_count;
} clz;
struct {
uint32_t bit_count;
} pop_count;
struct {
uint32_t bit_count;
} floating;
Expand Down Expand Up @@ -2050,6 +2055,7 @@ enum IrInstructionId {
IrInstructionIdUnionTag,
IrInstructionIdClz,
IrInstructionIdCtz,
IrInstructionIdPopCount,
IrInstructionIdImport,
IrInstructionIdCImport,
IrInstructionIdCInclude,
Expand Down Expand Up @@ -2545,6 +2551,12 @@ struct IrInstructionClz {
IrInstruction *value;
};

// IR instruction for a call to the @popCount builtin.
struct IrInstructionPopCount {
IrInstruction base;

// Operand whose set bits are counted.
IrInstruction *value;
};

struct IrInstructionUnionTag {
IrInstruction base;

Expand Down
4 changes: 4 additions & 0 deletions src/analyze.cpp
Expand Up @@ -5976,6 +5976,8 @@ uint32_t zig_llvm_fn_key_hash(ZigLLVMFnKey x) {
return (uint32_t)(x.data.ctz.bit_count) * (uint32_t)810453934;
case ZigLLVMFnIdClz:
return (uint32_t)(x.data.clz.bit_count) * (uint32_t)2428952817;
case ZigLLVMFnIdPopCount:
    // Read the pop_count member: it is the one written when the key is built
    // and the one zig_llvm_fn_key_eql compares for this id.
    return (uint32_t)(x.data.pop_count.bit_count) * (uint32_t)101195049;
case ZigLLVMFnIdFloor:
return (uint32_t)(x.data.floating.bit_count) * (uint32_t)1899859168;
case ZigLLVMFnIdCeil:
Expand All @@ -5998,6 +6000,8 @@ bool zig_llvm_fn_key_eql(ZigLLVMFnKey a, ZigLLVMFnKey b) {
return a.data.ctz.bit_count == b.data.ctz.bit_count;
case ZigLLVMFnIdClz:
return a.data.clz.bit_count == b.data.clz.bit_count;
case ZigLLVMFnIdPopCount:
return a.data.pop_count.bit_count == b.data.pop_count.bit_count;
case ZigLLVMFnIdFloor:
case ZigLLVMFnIdCeil:
case ZigLLVMFnIdSqrt:
Expand Down
31 changes: 31 additions & 0 deletions src/bigint.cpp
Expand Up @@ -1593,6 +1593,37 @@ void bigint_append_buf(Buf *buf, const BigInt *op, uint64_t base) {
}
}

// Returns the number of set bits in a non-negative BigInt.
// Asserts that `bi` is not negative. Every digit is scanned as 64 bit
// positions, so a value with no digits yields 0.
size_t bigint_popcount_unsigned(const BigInt *bi) {
    assert(!bi->is_negative);

    const size_t total_bits = bi->digit_count * 64;
    size_t set_bits = 0;
    for (size_t idx = 0; idx != total_bits; idx++) {
        set_bits += bit_at_index(bi, idx) ? 1 : 0;
    }
    return set_bits;
}

size_t bigint_popcount_signed(const BigInt *bi, size_t bit_count) {
if (bit_count == 0)
return 0;
if (bi->digit_count == 0)
return 0;

BigInt twos_comp = {0};
to_twos_complement(&twos_comp, bi, bit_count);

size_t count = 0;
for (size_t i = 0; i < bit_count; i += 1) {
if (bit_at_index(&twos_comp, i))
count += 1;
}
return count;
}

size_t bigint_ctz(const BigInt *bi, size_t bit_count) {
if (bit_count == 0)
return 0;
Expand Down
2 changes: 2 additions & 0 deletions src/bigint.hpp
Expand Up @@ -81,6 +81,8 @@ void bigint_append_buf(Buf *buf, const BigInt *op, uint64_t base);

size_t bigint_ctz(const BigInt *bi, size_t bit_count);
size_t bigint_clz(const BigInt *bi, size_t bit_count);
size_t bigint_popcount_signed(const BigInt *bi, size_t bit_count);
size_t bigint_popcount_unsigned(const BigInt *bi);

size_t bigint_bits_needed(const BigInt *op);

Expand Down
21 changes: 20 additions & 1 deletion src/codegen.cpp
Expand Up @@ -3426,14 +3426,22 @@ static LLVMValueRef ir_render_unwrap_maybe(CodeGen *g, IrExecutable *executable,
static LLVMValueRef get_int_builtin_fn(CodeGen *g, TypeTableEntry *int_type, BuiltinFnId fn_id) {
ZigLLVMFnKey key = {};
const char *fn_name;
uint32_t n_args;
if (fn_id == BuiltinFnIdCtz) {
fn_name = "cttz";
n_args = 2;
key.id = ZigLLVMFnIdCtz;
key.data.ctz.bit_count = (uint32_t)int_type->data.integral.bit_count;
} else if (fn_id == BuiltinFnIdClz) {
fn_name = "ctlz";
n_args = 2;
key.id = ZigLLVMFnIdClz;
key.data.clz.bit_count = (uint32_t)int_type->data.integral.bit_count;
} else if (fn_id == BuiltinFnIdPopCount) {
fn_name = "ctpop";
n_args = 1;
key.id = ZigLLVMFnIdPopCount;
key.data.pop_count.bit_count = (uint32_t)int_type->data.integral.bit_count;
} else {
zig_unreachable();
}
Expand All @@ -3448,7 +3456,7 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, TypeTableEntry *int_type, Bui
int_type->type_ref,
LLVMInt1Type(),
};
LLVMTypeRef fn_type = LLVMFunctionType(int_type->type_ref, param_types, 2, false);
LLVMTypeRef fn_type = LLVMFunctionType(int_type->type_ref, param_types, n_args, false);
LLVMValueRef fn_val = LLVMAddFunction(g->module, llvm_name, fn_type);
assert(LLVMGetIntrinsicID(fn_val));

Expand Down Expand Up @@ -3481,6 +3489,14 @@ static LLVMValueRef ir_render_ctz(CodeGen *g, IrExecutable *executable, IrInstru
return gen_widen_or_shorten(g, false, int_type, instruction->base.value.type, wrong_size_int);
}

// Emits the LLVM call for a runtime @popCount.
// The builtin function is looked up for the operand's integer type, called
// with the operand as its single argument, and the call result is passed
// through gen_widen_or_shorten from the operand type to the instruction's
// result type.
static LLVMValueRef ir_render_pop_count(CodeGen *g, IrExecutable *executable, IrInstructionPopCount *instruction) {
    TypeTableEntry *operand_type = instruction->value->value.type;
    TypeTableEntry *result_type = instruction->base.value.type;

    LLVMValueRef intrinsic = get_int_builtin_fn(g, operand_type, BuiltinFnIdPopCount);
    LLVMValueRef arg = ir_llvm_value(g, instruction->value);
    LLVMValueRef raw_count = LLVMBuildCall(g->builder, intrinsic, &arg, 1, "");

    return gen_widen_or_shorten(g, false, operand_type, result_type, raw_count);
}

static LLVMValueRef ir_render_switch_br(CodeGen *g, IrExecutable *executable, IrInstructionSwitchBr *instruction) {
LLVMValueRef target_value = ir_llvm_value(g, instruction->target_value);
LLVMBasicBlockRef else_block = instruction->else_block->llvm_block;
Expand Down Expand Up @@ -4831,6 +4847,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
return ir_render_clz(g, executable, (IrInstructionClz *)instruction);
case IrInstructionIdCtz:
return ir_render_ctz(g, executable, (IrInstructionCtz *)instruction);
case IrInstructionIdPopCount:
return ir_render_pop_count(g, executable, (IrInstructionPopCount *)instruction);
case IrInstructionIdSwitchBr:
return ir_render_switch_br(g, executable, (IrInstructionSwitchBr *)instruction);
case IrInstructionIdPhi:
Expand Down Expand Up @@ -6342,6 +6360,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdCUndef, "cUndef", 1);
create_builtin_fn(g, BuiltinFnIdCtz, "ctz", 1);
create_builtin_fn(g, BuiltinFnIdClz, "clz", 1);
create_builtin_fn(g, BuiltinFnIdPopCount, "popCount", 1);
create_builtin_fn(g, BuiltinFnIdImport, "import", 1);
create_builtin_fn(g, BuiltinFnIdCImport, "cImport", 1);
create_builtin_fn(g, BuiltinFnIdErrName, "errorName", 1);
Expand Down
68 changes: 68 additions & 0 deletions src/ir.cpp
Expand Up @@ -427,6 +427,10 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionCtz *) {
return IrInstructionIdCtz;
}

// Overload that maps the IrInstructionPopCount struct type to its
// IrInstructionId tag; the pointer argument is unused and only selects the overload.
static constexpr IrInstructionId ir_instruction_id(IrInstructionPopCount *) {
return IrInstructionIdPopCount;
}

static constexpr IrInstructionId ir_instruction_id(IrInstructionUnionTag *) {
return IrInstructionIdUnionTag;
}
Expand Down Expand Up @@ -1725,6 +1729,15 @@ static IrInstruction *ir_build_ctz_from(IrBuilder *irb, IrInstruction *old_instr
return new_instruction;
}

// Creates an IrInstructionPopCount for `value` through the builder `irb`,
// references the operand from the builder's current basic block, and
// returns the new instruction's base.
static IrInstruction *ir_build_pop_count(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) {
    IrInstructionPopCount *pop_count = ir_build_instruction<IrInstructionPopCount>(irb, scope, source_node);
    pop_count->value = value;
    ir_ref_instruction(value, irb->current_basic_block);
    return &pop_count->base;
}

static IrInstruction *ir_build_switch_br(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *target_value,
IrBasicBlock *else_block, size_t case_count, IrInstructionSwitchBrCase *cases, IrInstruction *is_comptime,
IrInstruction *switch_prongs_void)
Expand Down Expand Up @@ -3841,6 +3854,16 @@ static IrInstruction *ir_gen_builtin_fn_call(IrBuilder *irb, Scope *scope, AstNo
IrInstruction *ctz = ir_build_ctz(irb, scope, node, arg0_value);
return ir_lval_wrap(irb, scope, ctz, lval);
}
case BuiltinFnIdPopCount:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
IrInstruction *arg0_value = ir_gen_node(irb, arg0_node, scope);
if (arg0_value == irb->codegen->invalid_instruction)
return arg0_value;

IrInstruction *instr = ir_build_pop_count(irb, scope, node, arg0_value);
return ir_lval_wrap(irb, scope, instr, lval);
}
case BuiltinFnIdClz:
{
AstNode *arg0_node = node->data.fn_call_expr.params.at(0);
Expand Down Expand Up @@ -15275,6 +15298,48 @@ static TypeTableEntry *ir_analyze_instruction_clz(IrAnalyze *ira, IrInstructionC
}
}

// Analyzes a @popCount instruction.
//
// The operand must have an integer or comptime_int type; anything else is
// reported as an error. A runtime operand produces a new pop-count
// instruction whose type comes from get_smallest_unsigned_int_type for the
// operand's bit count. A comptime operand is folded to a constant:
//   - non-negative values are counted with bigint_popcount_unsigned;
//   - a negative comptime_int is an error;
//   - a negative fixed-width integer is counted with bigint_popcount_signed
//     using the operand type's bit count.
// Returns the resulting type, or the invalid type entry on error.
static TypeTableEntry *ir_analyze_instruction_pop_count(IrAnalyze *ira, IrInstructionPopCount *instruction) {
    IrInstruction *operand = instruction->value->other;
    TypeTableEntry *operand_type = operand->value.type;
    if (type_is_invalid(operand_type))
        return ira->codegen->builtin_types.entry_invalid;

    const bool is_comptime_int = (operand_type->id == TypeTableEntryIdComptimeInt);
    if (!(is_comptime_int || operand_type->id == TypeTableEntryIdInt)) {
        ir_add_error(ira, operand,
            buf_sprintf("expected integer type, found '%s'", buf_ptr(&operand_type->name)));
        return ira->codegen->builtin_types.entry_invalid;
    }

    // Runtime operand: emit the instruction into the new IR and type it.
    if (!instr_is_comptime(operand)) {
        IrInstruction *runtime_instr = ir_build_pop_count(&ira->new_irb, instruction->base.scope,
                instruction->base.source_node, operand);
        runtime_instr->value.type = get_smallest_unsigned_int_type(ira->codegen,
                operand_type->data.integral.bit_count);
        ir_link_new_instruction(runtime_instr, &instruction->base);
        return runtime_instr->value.type;
    }

    // Comptime operand: fold to a constant.
    ConstExprValue *const_val = ir_resolve_const(ira, operand, UndefBad);
    if (!const_val)
        return ira->codegen->builtin_types.entry_invalid;

    BigInt *big = &const_val->data.x_bigint;
    size_t bits_set;
    if (bigint_cmp_zero(big) != CmpLT) {
        bits_set = bigint_popcount_unsigned(big);
    } else if (is_comptime_int) {
        // A negative comptime_int has no fixed width to count bits in.
        Buf *rendered = buf_alloc();
        bigint_append_buf(rendered, big, 10);
        ir_add_error(ira, &instruction->base,
            buf_sprintf("@popCount on negative %s value %s",
                buf_ptr(&operand_type->name), buf_ptr(rendered)));
        return ira->codegen->builtin_types.entry_invalid;
    } else {
        bits_set = bigint_popcount_signed(big, operand_type->data.integral.bit_count);
    }

    ConstExprValue *out_val = ir_build_const_from(ira, &instruction->base);
    bigint_init_unsigned(&out_val->data.x_bigint, bits_set);
    return ira->codegen->builtin_types.entry_num_lit_int;
}

static IrInstruction *ir_analyze_union_tag(IrAnalyze *ira, IrInstruction *source_instr, IrInstruction *value) {
if (type_is_invalid(value->value.type))
return ira->codegen->invalid_instruction;
Expand Down Expand Up @@ -20534,6 +20599,8 @@ static TypeTableEntry *ir_analyze_instruction_nocast(IrAnalyze *ira, IrInstructi
return ir_analyze_instruction_clz(ira, (IrInstructionClz *)instruction);
case IrInstructionIdCtz:
return ir_analyze_instruction_ctz(ira, (IrInstructionCtz *)instruction);
case IrInstructionIdPopCount:
return ir_analyze_instruction_pop_count(ira, (IrInstructionPopCount *)instruction);
case IrInstructionIdSwitchBr:
return ir_analyze_instruction_switch_br(ira, (IrInstructionSwitchBr *)instruction);
case IrInstructionIdSwitchTarget:
Expand Down Expand Up @@ -20892,6 +20959,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdUnwrapOptional:
case IrInstructionIdClz:
case IrInstructionIdCtz:
case IrInstructionIdPopCount:
case IrInstructionIdSwitchVar:
case IrInstructionIdSwitchTarget:
case IrInstructionIdUnionTag:
Expand Down
9 changes: 9 additions & 0 deletions src/ir_print.cpp
Expand Up @@ -501,6 +501,12 @@ static void ir_print_ctz(IrPrint *irp, IrInstructionCtz *instruction) {
fprintf(irp->f, ")");
}

// Writes the instruction to irp->f in the form "@popCount(<operand>)".
static void ir_print_pop_count(IrPrint *irp, IrInstructionPopCount *instruction) {
fprintf(irp->f, "@popCount(");
// The operand's textual form is produced by ir_print_other_instruction.
ir_print_other_instruction(irp, instruction->value);
fprintf(irp->f, ")");
}

static void ir_print_switch_br(IrPrint *irp, IrInstructionSwitchBr *instruction) {
fprintf(irp->f, "switch (");
ir_print_other_instruction(irp, instruction->target_value);
Expand Down Expand Up @@ -1425,6 +1431,9 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction) {
case IrInstructionIdCtz:
ir_print_ctz(irp, (IrInstructionCtz *)instruction);
break;
case IrInstructionIdPopCount:
ir_print_pop_count(irp, (IrInstructionPopCount *)instruction);
break;
case IrInstructionIdClz:
ir_print_clz(irp, (IrInstructionClz *)instruction);
break;
Expand Down
7 changes: 4 additions & 3 deletions test/behavior.zig
Expand Up @@ -8,17 +8,17 @@ comptime {
_ = @import("cases/atomics.zig");
_ = @import("cases/bitcast.zig");
_ = @import("cases/bool.zig");
_ = @import("cases/bugs/1111.zig");
_ = @import("cases/bugs/394.zig");
_ = @import("cases/bugs/655.zig");
_ = @import("cases/bugs/656.zig");
_ = @import("cases/bugs/828.zig");
_ = @import("cases/bugs/920.zig");
_ = @import("cases/bugs/1111.zig");
_ = @import("cases/byval_arg_var.zig");
_ = @import("cases/cast.zig");
_ = @import("cases/const_slice_child.zig");
_ = @import("cases/coroutines.zig");
_ = @import("cases/coroutine_await_struct.zig");
_ = @import("cases/coroutines.zig");
_ = @import("cases/defer.zig");
_ = @import("cases/enum.zig");
_ = @import("cases/enum_with_members.zig");
Expand All @@ -36,11 +36,12 @@ comptime {
_ = @import("cases/math.zig");
_ = @import("cases/merge_error_sets.zig");
_ = @import("cases/misc.zig");
_ = @import("cases/optional.zig");
_ = @import("cases/namespace_depends_on_compile_var/index.zig");
_ = @import("cases/new_stack_call.zig");
_ = @import("cases/null.zig");
_ = @import("cases/optional.zig");
_ = @import("cases/pointers.zig");
_ = @import("cases/popcount.zig");
_ = @import("cases/pub_enum/index.zig");
_ = @import("cases/ref_var_in_if_after_if_2nd_switch_prong.zig");
_ = @import("cases/reflection.zig");
Expand Down

0 comments on commit d8295c1

Please sign in to comment.