Skip to content

Commit 5d4a02c

Browse files
authored Jul 30, 2018
Merge pull request #1307 from ziglang/cancel-semantics
improved coroutine cancel semantics
·
0.15.1 … 0.3.0
2 parents 608ff52 + cfe03c7 commit 5d4a02c

File tree

10 files changed (+486 additions, -142 deletions)
 

‎doc/langref.html.in‎

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4665,24 +4665,24 @@ async fn testSuspendBlock() void {
46654665
block, while the old thread continued executing the suspend block.
46664666
</p>
46674667
<p>
4668-
However, if you use labeled <code>break</code> on the suspend block, the coroutine
4668+
However, the coroutine can be directly resumed from the suspend block, in which case it
46694669
never returns to its resumer and continues executing.
46704670
</p>
46714671
{#code_begin|test#}
46724672
const std = @import("std");
46734673
const assert = std.debug.assert;
46744674

4675-
test "break from suspend" {
4675+
test "resume from suspend" {
46764676
var buf: [500]u8 = undefined;
46774677
var a = &std.heap.FixedBufferAllocator.init(buf[0..]).allocator;
46784678
var my_result: i32 = 1;
4679-
const p = try async<a> testBreakFromSuspend(&my_result);
4679+
const p = try async<a> testResumeFromSuspend(&my_result);
46804680
cancel p;
46814681
std.debug.assert(my_result == 2);
46824682
}
4683-
async fn testBreakFromSuspend(my_result: *i32) void {
4684-
s: suspend |p| {
4685-
break :s;
4683+
async fn testResumeFromSuspend(my_result: *i32) void {
4684+
suspend |p| {
4685+
resume p;
46864686
}
46874687
my_result.* += 1;
46884688
suspend;
@@ -7336,7 +7336,7 @@ Defer(body) = ("defer" | "deferror") body
73367336

73377337
IfExpression(body) = "if" "(" Expression ")" body option("else" BlockExpression(body))
73387338

7339-
SuspendExpression(body) = option(Symbol ":") "suspend" option(("|" Symbol "|" body))
7339+
SuspendExpression(body) = "suspend" option(("|" Symbol "|" body))
73407340

73417341
IfErrorExpression(body) = "if" "(" Expression ")" option("|" option("*") Symbol "|") body "else" "|" Symbol "|" BlockExpression(body)
73427342

‎src/all_types.hpp‎

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ struct IrExecutable {
6060
ZigList<Tld *> tld_list;
6161

6262
IrInstruction *coro_handle;
63-
IrInstruction *coro_awaiter_field_ptr; // this one is shared and in the promise
63+
IrInstruction *atomic_state_field_ptr; // this one is shared and in the promise
6464
IrInstruction *coro_result_ptr_field_ptr;
6565
IrInstruction *coro_result_field_ptr;
6666
IrInstruction *await_handle_var_ptr; // this one is where we put the one we extracted from the promise
@@ -898,7 +898,6 @@ struct AstNodeAwaitExpr {
898898
};
899899

900900
struct AstNodeSuspend {
901-
Buf *name;
902901
AstNode *block;
903902
AstNode *promise_symbol;
904903
};
@@ -1929,7 +1928,6 @@ struct ScopeLoop {
19291928
struct ScopeSuspend {
19301929
Scope base;
19311930

1932-
Buf *name;
19331931
IrBasicBlock *resume_block;
19341932
bool reported_err;
19351933
};
@@ -3245,7 +3243,7 @@ static const size_t stack_trace_ptr_count = 30;
32453243
#define RESULT_FIELD_NAME "result"
32463244
#define ASYNC_ALLOC_FIELD_NAME "allocFn"
32473245
#define ASYNC_FREE_FIELD_NAME "freeFn"
3248-
#define AWAITER_HANDLE_FIELD_NAME "awaiter_handle"
3246+
#define ATOMIC_STATE_FIELD_NAME "atomic_state"
32493247
// these point to data belonging to the awaiter
32503248
#define ERR_RET_TRACE_PTR_FIELD_NAME "err_ret_trace_ptr"
32513249
#define RESULT_PTR_FIELD_NAME "result_ptr"

‎src/analyze.cpp‎

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,6 @@ ScopeSuspend *create_suspend_scope(AstNode *node, Scope *parent) {
161161
assert(node->type == NodeTypeSuspend);
162162
ScopeSuspend *scope = allocate<ScopeSuspend>(1);
163163
init_scope(&scope->base, ScopeIdSuspend, node, parent);
164-
scope->name = node->data.suspend.name;
165164
return scope;
166165
}
167166

@@ -519,11 +518,11 @@ TypeTableEntry *get_promise_frame_type(CodeGen *g, TypeTableEntry *return_type)
519518
return return_type->promise_frame_parent;
520519
}
521520

522-
TypeTableEntry *awaiter_handle_type = get_optional_type(g, g->builtin_types.entry_promise);
521+
TypeTableEntry *atomic_state_type = g->builtin_types.entry_usize;
523522
TypeTableEntry *result_ptr_type = get_pointer_to_type(g, return_type, false);
524523

525524
ZigList<const char *> field_names = {};
526-
field_names.append(AWAITER_HANDLE_FIELD_NAME);
525+
field_names.append(ATOMIC_STATE_FIELD_NAME);
527526
field_names.append(RESULT_FIELD_NAME);
528527
field_names.append(RESULT_PTR_FIELD_NAME);
529528
if (g->have_err_ret_tracing) {
@@ -533,7 +532,7 @@ TypeTableEntry *get_promise_frame_type(CodeGen *g, TypeTableEntry *return_type)
533532
}
534533

535534
ZigList<TypeTableEntry *> field_types = {};
536-
field_types.append(awaiter_handle_type);
535+
field_types.append(atomic_state_type);
537536
field_types.append(return_type);
538537
field_types.append(result_ptr_type);
539538
if (g->have_err_ret_tracing) {
@@ -6228,7 +6227,12 @@ uint32_t get_abi_alignment(CodeGen *g, TypeTableEntry *type_entry) {
62286227
} else if (type_entry->id == TypeTableEntryIdOpaque) {
62296228
return 1;
62306229
} else {
6231-
return LLVMABIAlignmentOfType(g->target_data_ref, type_entry->type_ref);
6230+
uint32_t llvm_alignment = LLVMABIAlignmentOfType(g->target_data_ref, type_entry->type_ref);
6231+
// promises have at least alignment 8 so that we can have 3 extra bits when doing atomicrmw
6232+
if (type_entry->id == TypeTableEntryIdPromise && llvm_alignment < 8) {
6233+
return 8;
6234+
}
6235+
return llvm_alignment;
62326236
}
62336237
}
62346238

‎src/ir.cpp‎

Lines changed: 357 additions & 97 deletions
Large diffs are not rendered by default.

‎src/parser.cpp‎

Lines changed: 2 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -648,30 +648,12 @@ static AstNode *ast_parse_asm_expr(ParseContext *pc, size_t *token_index, bool m
648648
}
649649

650650
/*
651-
SuspendExpression(body) = option(Symbol ":") "suspend" option(("|" Symbol "|" body))
651+
SuspendExpression(body) = "suspend" option(("|" Symbol "|" body))
652652
*/
653653
static AstNode *ast_parse_suspend_block(ParseContext *pc, size_t *token_index, bool mandatory) {
654654
size_t orig_token_index = *token_index;
655655

656-
Token *name_token = nullptr;
657-
Token *token = &pc->tokens->at(*token_index);
658-
if (token->id == TokenIdSymbol) {
659-
*token_index += 1;
660-
Token *colon_token = &pc->tokens->at(*token_index);
661-
if (colon_token->id == TokenIdColon) {
662-
*token_index += 1;
663-
name_token = token;
664-
token = &pc->tokens->at(*token_index);
665-
} else if (mandatory) {
666-
ast_expect_token(pc, colon_token, TokenIdColon);
667-
zig_unreachable();
668-
} else {
669-
*token_index = orig_token_index;
670-
return nullptr;
671-
}
672-
}
673-
674-
Token *suspend_token = token;
656+
Token *suspend_token = &pc->tokens->at(*token_index);
675657
if (suspend_token->id == TokenIdKeywordSuspend) {
676658
*token_index += 1;
677659
} else if (mandatory) {
@@ -693,9 +675,6 @@ static AstNode *ast_parse_suspend_block(ParseContext *pc, size_t *token_index, b
693675
}
694676

695677
AstNode *node = ast_create_node(pc, NodeTypeSuspend, suspend_token);
696-
if (name_token != nullptr) {
697-
node->data.suspend.name = token_buf(name_token);
698-
}
699678
node->data.suspend.promise_symbol = ast_parse_symbol(pc, token_index);
700679
ast_eat_token(pc, token_index, TokenIdBinOr);
701680
node->data.suspend.block = ast_parse_block(pc, token_index, true);

‎std/debug/index.zig‎

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ pub fn warn(comptime fmt: []const u8, args: ...) void {
2727
const stderr = getStderrStream() catch return;
2828
stderr.print(fmt, args) catch return;
2929
}
30-
fn getStderrStream() !*io.OutStream(io.FileOutStream.Error) {
30+
pub fn getStderrStream() !*io.OutStream(io.FileOutStream.Error) {
3131
if (stderr_stream) |st| {
3232
return st;
3333
} else {
@@ -172,6 +172,16 @@ pub fn writeStackTrace(stack_trace: *const builtin.StackTrace, out_stream: var,
172172
}
173173
}
174174

175+
pub inline fn getReturnAddress(frame_count: usize) usize {
176+
var fp = @ptrToInt(@frameAddress());
177+
var i: usize = 0;
178+
while (fp != 0 and i < frame_count) {
179+
fp = @intToPtr(*const usize, fp).*;
180+
i += 1;
181+
}
182+
return @intToPtr(*const usize, fp + @sizeOf(usize)).*;
183+
}
184+
175185
pub fn writeCurrentStackTrace(out_stream: var, allocator: *mem.Allocator, debug_info: *ElfStackTrace, tty_color: bool, start_addr: ?usize) !void {
176186
const AddressState = union(enum) {
177187
NotLookingForStartAddress,
@@ -205,7 +215,7 @@ pub fn writeCurrentStackTrace(out_stream: var, allocator: *mem.Allocator, debug_
205215
}
206216
}
207217

208-
fn printSourceAtAddress(debug_info: *ElfStackTrace, out_stream: var, address: usize, tty_color: bool) !void {
218+
pub fn printSourceAtAddress(debug_info: *ElfStackTrace, out_stream: var, address: usize, tty_color: bool) !void {
209219
switch (builtin.os) {
210220
builtin.Os.windows => return error.UnsupportedDebugInfo,
211221
builtin.Os.macosx => {

‎std/event/loop.zig‎

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ pub const Loop = struct {
5555
/// After initialization, call run().
5656
/// TODO copy elision / named return values so that the threads referencing *Loop
5757
/// have the correct pointer value.
58-
fn initSingleThreaded(self: *Loop, allocator: *mem.Allocator) !void {
58+
pub fn initSingleThreaded(self: *Loop, allocator: *mem.Allocator) !void {
5959
return self.initInternal(allocator, 1);
6060
}
6161

@@ -64,7 +64,7 @@ pub const Loop = struct {
6464
/// After initialization, call run().
6565
/// TODO copy elision / named return values so that the threads referencing *Loop
6666
/// have the correct pointer value.
67-
fn initMultiThreaded(self: *Loop, allocator: *mem.Allocator) !void {
67+
pub fn initMultiThreaded(self: *Loop, allocator: *mem.Allocator) !void {
6868
const core_count = try std.os.cpuCount(allocator);
6969
return self.initInternal(allocator, core_count);
7070
}

‎test/behavior.zig‎

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ comptime {
1616
_ = @import("cases/bugs/828.zig");
1717
_ = @import("cases/bugs/920.zig");
1818
_ = @import("cases/byval_arg_var.zig");
19+
_ = @import("cases/cancel.zig");
1920
_ = @import("cases/cast.zig");
2021
_ = @import("cases/const_slice_child.zig");
2122
_ = @import("cases/coroutine_await_struct.zig");

‎test/cases/cancel.zig‎

Lines changed: 92 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,92 @@
1+
const std = @import("std");
2+
3+
var defer_f1: bool = false;
4+
var defer_f2: bool = false;
5+
var defer_f3: bool = false;
6+
7+
test "cancel forwards" {
8+
var da = std.heap.DirectAllocator.init();
9+
defer da.deinit();
10+
11+
const p = async<&da.allocator> f1() catch unreachable;
12+
cancel p;
13+
std.debug.assert(defer_f1);
14+
std.debug.assert(defer_f2);
15+
std.debug.assert(defer_f3);
16+
}
17+
18+
async fn f1() void {
19+
defer {
20+
defer_f1 = true;
21+
}
22+
await (async f2() catch unreachable);
23+
}
24+
25+
async fn f2() void {
26+
defer {
27+
defer_f2 = true;
28+
}
29+
await (async f3() catch unreachable);
30+
}
31+
32+
async fn f3() void {
33+
defer {
34+
defer_f3 = true;
35+
}
36+
suspend;
37+
}
38+
39+
var defer_b1: bool = false;
40+
var defer_b2: bool = false;
41+
var defer_b3: bool = false;
42+
var defer_b4: bool = false;
43+
44+
test "cancel backwards" {
45+
var da = std.heap.DirectAllocator.init();
46+
defer da.deinit();
47+
48+
const p = async<&da.allocator> b1() catch unreachable;
49+
cancel p;
50+
std.debug.assert(defer_b1);
51+
std.debug.assert(defer_b2);
52+
std.debug.assert(defer_b3);
53+
std.debug.assert(defer_b4);
54+
}
55+
56+
async fn b1() void {
57+
defer {
58+
defer_b1 = true;
59+
}
60+
await (async b2() catch unreachable);
61+
}
62+
63+
var b4_handle: promise = undefined;
64+
65+
async fn b2() void {
66+
const b3_handle = async b3() catch unreachable;
67+
resume b4_handle;
68+
cancel b4_handle;
69+
defer {
70+
defer_b2 = true;
71+
}
72+
const value = await b3_handle;
73+
@panic("unreachable");
74+
}
75+
76+
async fn b3() i32 {
77+
defer {
78+
defer_b3 = true;
79+
}
80+
await (async b4() catch unreachable);
81+
return 1234;
82+
}
83+
84+
async fn b4() void {
85+
defer {
86+
defer_b4 = true;
87+
}
88+
suspend |p| {
89+
b4_handle = p;
90+
}
91+
suspend;
92+
}

‎test/cases/coroutines.zig‎

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -244,8 +244,8 @@ test "break from suspend" {
244244
std.debug.assert(my_result == 2);
245245
}
246246
async fn testBreakFromSuspend(my_result: *i32) void {
247-
s: suspend |p| {
248-
break :s;
247+
suspend |p| {
248+
resume p;
249249
}
250250
my_result.* += 1;
251251
suspend;

0 commit comments

Comments (0)
Please sign in to comment.