Skip to content

Commit d4270cf

Browse files
author
whitequark
committed Aug 9, 2015
Implement receiving data from RPCs.
1 parent 02b1543 commit d4270cf

File tree

6 files changed

+337
-75
lines changed

6 files changed

+337
-75
lines changed
 

Diff for: ‎artiq/compiler/transforms/llvm_ir_generator.py

+11-10
Original file line number | Diff line number | Diff line change
@@ -162,7 +162,7 @@ def llbuiltin(self, name):
162162
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()],
163163
var_arg=True)
164164
elif name == "recv_rpc":
165-
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer().as_pointer()])
165+
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer()])
166166
else:
167167
assert False
168168

@@ -571,7 +571,7 @@ def _prepare_closure_call(self, insn):
571571
llfun = self.llbuilder.extract_value(llclosure, 1)
572572
return llfun, [llenv] + list(llargs)
573573

574-
# See session.c:send_rpc_value and session.c:recv_rpc_value.
574+
# See session.c:{send,receive}_rpc_value and comm_generic.py:_{send,receive}_rpc_value.
575575
def _rpc_tag(self, typ, error_handler):
576576
if types.is_tuple(typ):
577577
assert len(typ.elts) < 256
@@ -666,29 +666,30 @@ def ret_error_handler(typ):
666666
llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc")
667667
lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail")
668668

669-
llslot = self.llbuilder.alloca(ll.IntType(8).as_pointer())
670-
self.llbuilder.store(ll.Constant(ll.IntType(8).as_pointer(), None), llslot)
669+
llretty = self.llty_of_type(fun_type.ret)
670+
llslot = self.llbuilder.alloca(llretty)
671+
llslotgen = self.llbuilder.bitcast(llslot, ll.IntType(8).as_pointer())
671672
self.llbuilder.branch(llhead)
672673

673674
self.llbuilder.position_at_end(llhead)
675+
llphi = self.llbuilder.phi(llslotgen.type)
676+
llphi.add_incoming(llslotgen, llprehead)
674677
if llunwindblock:
675-
llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llslot],
678+
llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llphi],
676679
llheadu, llunwindblock)
677680
self.llbuilder.position_at_end(llheadu)
678681
else:
679-
llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llslot])
682+
llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llphi])
680683
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0))
681684
self.llbuilder.cbranch(lldone, lltail, llalloc)
682685

683686
self.llbuilder.position_at_end(llalloc)
684687
llalloca = self.llbuilder.alloca(ll.IntType(8), llsize)
685-
self.llbuilder.store(llalloca, llslot)
688+
llphi.add_incoming(llalloca, llalloc)
686689
self.llbuilder.branch(llhead)
687690

688691
self.llbuilder.position_at_end(lltail)
689-
llretty = self.llty_of_type(fun_type.ret, for_return=True)
690-
llretptr = self.llbuilder.bitcast(llslot, llretty.as_pointer())
691-
llret = self.llbuilder.load(llretptr)
692+
llret = self.llbuilder.load(llslot)
692693
if not builtins.is_allocated(fun_type.ret):
693694
# We didn't allocate anything except the slot for the value itself.
694695
# Don't waste stack space.

Diff for: ‎artiq/coredevice/comm_generic.py

+100-14
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,6 @@
88

99

1010
logger = logging.getLogger(__name__)
11-
logger.setLevel(logging.DEBUG)
1211

1312

1413
class _H2DMsgType(Enum):
@@ -51,6 +50,9 @@ class _D2HMsgType(Enum):
5150
class UnsupportedDevice(Exception):
5251
pass
5352

53+
class RPCReturnValueError(ValueError):
    """Raised when a value returned by an RPC cannot be serialized as the
    return type announced by the kernel's return tags."""
55+
5456

5557
class CommGeneric:
5658
def __init__(self):
@@ -279,6 +281,7 @@ def run(self):
279281

280282
_rpc_sentinel = object()
281283

284+
# See session.c:{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
282285
def _receive_rpc_value(self, rpc_map):
283286
tag = chr(self._read_int8())
284287
if tag == "\x00":
@@ -306,11 +309,15 @@ def _receive_rpc_value(self, rpc_map):
306309
length = self._read_int32()
307310
return [self._receive_rpc_value(rpc_map) for _ in range(length)]
308311
elif tag == "r":
309-
lower = self._receive_rpc_value(rpc_map)
310-
upper = self._receive_rpc_value(rpc_map)
312+
start = self._receive_rpc_value(rpc_map)
313+
stop = self._receive_rpc_value(rpc_map)
311314
step = self._receive_rpc_value(rpc_map)
312-
return range(lower, upper, step)
315+
return range(start, stop, step)
313316
elif tag == "o":
317+
present = self._read_int8()
318+
if present:
319+
return self._receive_rpc_value(rpc_map)
320+
elif tag == "O":
314321
return rpc_map[self._read_int32()]
315322
else:
316323
raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
@@ -323,16 +330,101 @@ def _receive_rpc_args(self, rpc_map):
323330
return args
324331
args.append(value)
325332

333+
def _skip_rpc_value(self, tags):
    """Consume one type tag (and all its nested tags) from *tags*.

    *tags* is a mutable bytearray of RPC type tags; it is modified in place
    so that afterwards it starts at the tag following the skipped one.
    Compound tags (tuple, list, range, option) are skipped together with the
    tags of their elements; every other tag is a single byte.

    :param tags: bytearray of remaining RPC type tags.
    """
    # bytearray.pop() returns an int; convert it to a one-character string so
    # the comparisons below can match (as _send_rpc_value does).
    tag = chr(tags.pop(0))
    if tag == "t":
        # Tuple: an arity byte followed by one tag per element.
        length = tags.pop(0)
        for _ in range(length):
            self._skip_rpc_value(tags)
    elif tag in ("l", "r", "o"):
        # List, range and option are each followed by a single element tag.
        self._skip_rpc_value(tags)
    # Any other tag is a scalar and has already been consumed.
345+
346+
def _send_rpc_value(self, tags, value, root, function):
    """Serialize *value* to the core device according to the tag in *tags*.

    *tags* is a mutable bytearray of RPC type tags (see
    session.c:receive_rpc_value and llvm_ir_generator.py:_rpc_tag). It is
    consumed in place: on return it starts at the tag following the one that
    describes *value*, so callers can serialize consecutive values with it.

    :param tags: bytearray of remaining RPC type tags.
    :param value: the Python value to serialize.
    :param root: the complete value returned by the RPC (for error messages).
    :param function: the RPC function that returned *root* (for error messages).
    :raises RPCReturnValueError: if *value* does not match its tag.
    :raises IOError: if an unknown tag is encountered.
    """
    def check(cond, expected):
        # *expected* is a thunk so the type description is only built on error.
        if not cond:
            raise RPCReturnValueError(
                "type mismatch: cannot serialize {value} as {type}"
                " ({function} has returned {root})".format(
                    value=repr(value), type=expected(),
                    function=function, root=root))

    tag = chr(tags.pop(0))
    if tag == "t":
        length = tags.pop(0)
        check(isinstance(value, tuple) and length == len(value),
              lambda: "tuple of {}".format(length))
        for elt in value:
            self._send_rpc_value(tags, elt, root, function)
    elif tag == "n":
        check(value is None,
              lambda: "None")
    elif tag == "b":
        check(isinstance(value, bool),
              lambda: "bool")
        self._write_int8(value)
    elif tag == "i":
        # Bounds are inclusive: -2**31 and 2**31-1 are valid 32-bit ints.
        check(isinstance(value, int) and (-2**31 <= value <= 2**31-1),
              lambda: "32-bit int")
        self._write_int32(value)
    elif tag == "I":
        check(isinstance(value, int) and (-2**63 <= value <= 2**63-1),
              lambda: "64-bit int")
        self._write_int64(value)
    elif tag == "f":
        check(isinstance(value, float),
              lambda: "float")
        self._write_float64(value)
    elif tag == "F":
        check(isinstance(value, Fraction) and
              (-2**63 <= value.numerator <= 2**63-1) and
              (-2**63 <= value.denominator <= 2**63-1),
              lambda: "64-bit Fraction")
        self._write_int64(value.numerator)
        self._write_int64(value.denominator)
    elif tag == "s":
        check(isinstance(value, str) and "\x00" not in value,
              lambda: "str")
        self._write_string(value)
    elif tag == "l":
        check(isinstance(value, list),
              lambda: "list")
        self._write_int32(len(value))
        # Every element is described by the same element tag, so each one is
        # serialized against a fresh copy; the original is advanced afterwards
        # (also for empty lists).
        for elt in value:
            self._send_rpc_value(bytearray(tags), elt, root, function)
        self._skip_rpc_value(tags)
    elif tag == "r":
        check(isinstance(value, range),
              lambda: "range")
        # start, stop and step share one element tag. The last one consumes
        # it from *tags* itself so the caller's cursor advances past it.
        self._send_rpc_value(bytearray(tags), value.start, root, function)
        self._send_rpc_value(bytearray(tags), value.stop, root, function)
        self._send_rpc_value(tags, value.step, root, function)
    else:
        raise IOError("Unknown RPC value tag: {}".format(repr(tag)))
412+
326413
def _serve_rpc(self, rpc_map):
327414
service = self._read_int32()
328415
args = self._receive_rpc_args(rpc_map)
329-
return_tag = self._read_string()
330-
logger.debug("rpc service: %d %r -> %s", service, args, return_tag)
416+
return_tags = self._read_bytes()
417+
logger.debug("rpc service: %d %r -> %s", service, args, return_tags)
331418

332419
try:
333420
result = rpc_map[service](*args)
334-
if not isinstance(result, int) or not (-2**31 < result < 2**31-1):
335-
raise ValueError("An RPC must return an int(width=32)")
421+
logger.debug("rpc service: %d %r == %r", service, args, result)
422+
423+
self._write_header(_H2DMsgType.RPC_REPLY)
424+
self._write_bytes(return_tags)
425+
self._send_rpc_value(bytearray(return_tags), result, result,
426+
rpc_map[service])
427+
self._write_flush()
336428
except core_language.ARTIQException as exn:
337429
logger.debug("rpc service: %d %r ! %r", service, args, exn)
338430

@@ -364,12 +456,6 @@ def _serve_rpc(self, rpc_map):
364456
self._write_string(function)
365457

366458
self._write_flush()
367-
else:
368-
logger.debug("rpc service: %d %r == %r", service, args, result)
369-
370-
self._write_header(_H2DMsgType.RPC_REPLY)
371-
self._write_int32(result)
372-
self._write_flush()
373459

374460
def _serve_exception(self):
375461
name = self._read_string()

Diff for: ‎soc/runtime/ksupport.c

+1-1
Original file line number | Diff line number | Diff line change
@@ -314,7 +314,7 @@ void send_rpc(int service, const char *tag, ...)
314314
va_end(request.args);
315315
}
316316

317-
int recv_rpc(void **slot) {
317+
int recv_rpc(void *slot) {
318318
struct msg_rpc_recv_request request;
319319
struct msg_rpc_recv_reply *reply;
320320

Diff for: ‎soc/runtime/ksupport.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ void now_save(long long int now);
66
int watchdog_set(int ms);
77
void watchdog_clear(int id);
88
void send_rpc(int service, const char *tag, ...);
9-
int recv_rpc(void **slot);
9+
int recv_rpc(void *slot);
1010
void lognonl(const char *fmt, ...);
1111
void log(const char *fmt, ...);
1212

Diff for: ‎soc/runtime/messages.h

+1-1
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ struct msg_rpc_send {
8989

9090
struct msg_rpc_recv_request {
9191
int type;
92-
void **slot;
92+
void *slot;
9393
};
9494

9595
struct msg_rpc_recv_reply {

Diff for: ‎soc/runtime/session.c

+223-48
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ static const char *in_packet_string()
147147
{
148148
int length;
149149
const char *string = in_packet_bytes(&length);
150-
if(string[length] != 0) {
150+
if(string[length - 1] != 0) {
151151
log("session.c: string is not zero-terminated");
152152
return "";
153153
}
@@ -346,6 +346,8 @@ enum {
346346
REMOTEMSG_TYPE_FLASH_ERROR_REPLY
347347
};
348348

349+
static int receive_rpc_value(const char **tag, void **slot);
350+
349351
static int process_input(void)
350352
{
351353
switch(buffer_in.header.type) {
@@ -457,23 +459,37 @@ static int process_input(void)
457459
user_kernel_state = USER_KERNEL_RUNNING;
458460
break;
459461

460-
// case REMOTEMSG_TYPE_RPC_REPLY: {
461-
// struct msg_rpc_reply reply;
462+
case REMOTEMSG_TYPE_RPC_REPLY: {
463+
struct msg_rpc_recv_request *request;
464+
struct msg_rpc_recv_reply reply;
462465

463-
// int result = in_packet_int32();
466+
if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
467+
log("Unsolicited RPC reply");
468+
return 0; // restart session
469+
}
470+
471+
request = mailbox_wait_and_receive();
472+
if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) {
473+
log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d",
474+
request->type);
475+
return 0; // restart session
476+
}
464477

465-
// if(user_kernel_state != USER_KERNEL_WAIT_RPC) {
466-
// log("Unsolicited RPC reply");
467-
// return 0; // restart session
468-
// }
478+
const char *tag = in_packet_string();
479+
void *slot = request->slot;
480+
if(!receive_rpc_value(&tag, &slot)) {
481+
log("Failed to receive RPC reply");
482+
return 0; // restart session
483+
}
469484

470-
// reply.type = MESSAGE_TYPE_RPC_REPLY;
471-
// reply.result = result;
472-
// mailbox_send_and_wait(&reply);
485+
reply.type = MESSAGE_TYPE_RPC_RECV_REPLY;
486+
reply.alloc_size = 0;
487+
reply.exception = NULL;
488+
mailbox_send_and_wait(&reply);
473489

474-
// user_kernel_state = USER_KERNEL_RUNNING;
475-
// break;
476-
// }
490+
user_kernel_state = USER_KERNEL_RUNNING;
491+
break;
492+
}
477493

478494
case REMOTEMSG_TYPE_RPC_EXCEPTION: {
479495
struct msg_rpc_recv_request *request;
@@ -512,13 +528,191 @@ static int process_input(void)
512528
}
513529

514530
default:
531+
log("Received invalid packet type %d from host",
532+
buffer_in.header.type);
533+
return 0;
534+
}
535+
536+
return 1;
537+
}
538+
539+
// See comm_generic.py:_{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
540+
// Advance *tag past one RPC type tag, including all tags nested inside it.
// See comm_generic.py:_{send,receive}_rpc_value and llvm_ir_generator.py:_rpc_tag.
static void skip_rpc_value(const char **tag) {
    switch(*(*tag)++) {
    case 't': { // tuple: arity byte followed by one tag per element
        // The arity is a raw byte (tuples may have up to 255 elements); read
        // it unsigned so arities above 127 are not sign-extended to a
        // negative count on targets where char is signed.
        int size = (unsigned char)*(*tag)++;
        for(int i = 0; i < size; i++)
            skip_rpc_value(tag);
        break;
    }

    case 'l': // list(elt='a)
    case 'r': // range(elt='a)
    case 'o': // option: followed by the tag of its contents
        skip_rpc_value(tag);
        break;

    default: // scalar tags occupy a single byte, already consumed
        break;
    }
}
558+
559+
static int sizeof_rpc_value(const char **tag)
560+
{
561+
switch(*(*tag)++) {
562+
case 't': { // tuple
563+
int size = *(*tag)++;
564+
565+
int32_t length = 0;
566+
for(int i = 0; i < size; i++)
567+
length += sizeof_rpc_value(tag);
568+
return length;
569+
}
570+
571+
case 'n': // None
572+
return 0;
573+
574+
case 'b': // bool
575+
return sizeof(int8_t);
576+
577+
case 'i': // int(width=32)
578+
return sizeof(int32_t);
579+
580+
case 'I': // int(width=64)
581+
return sizeof(int64_t);
582+
583+
case 'f': // float
584+
return sizeof(double);
585+
586+
case 'F': // Fraction
587+
return sizeof(struct { int64_t numerator, denominator; });
588+
589+
case 's': // string
590+
return sizeof(char *);
591+
592+
case 'l': // list(elt='a)
593+
skip_rpc_value(tag);
594+
return sizeof(struct { int32_t length; struct {} *elements; });
595+
596+
case 'r': // range(elt='a)
597+
return sizeof_rpc_value(tag) * 3;
598+
599+
default:
600+
log("sizeof_rpc_value: unknown tag %02x", *((*tag) - 1));
601+
return 0;
602+
}
603+
}
604+
605+
static void *alloc_rpc_value(int size)
606+
{
607+
struct msg_rpc_recv_request *request;
608+
struct msg_rpc_recv_reply reply;
609+
610+
reply.type = MESSAGE_TYPE_RPC_RECV_REPLY;
611+
reply.alloc_size = size;
612+
reply.exception = NULL;
613+
mailbox_send_and_wait(&reply);
614+
615+
request = mailbox_wait_and_receive();
616+
if(request->type != MESSAGE_TYPE_RPC_RECV_REQUEST) {
617+
log("Expected MESSAGE_TYPE_RPC_RECV_REQUEST, got %d",
618+
request->type);
619+
return NULL;
620+
}
621+
return request->slot;
622+
}
623+
624+
static int receive_rpc_value(const char **tag, void **slot)
625+
{
626+
switch(*(*tag)++) {
627+
case 't': { // tuple
628+
int size = *(*tag)++;
629+
630+
for(int i = 0; i < size; i++) {
631+
if(!receive_rpc_value(tag, slot))
632+
return 0;
633+
}
634+
break;
635+
}
636+
637+
case 'n': // None
638+
break;
639+
640+
case 'b': { // bool
641+
*((*(int8_t**)slot)++) = in_packet_int8();
642+
break;
643+
}
644+
645+
case 'i': { // int(width=32)
646+
*((*(int32_t**)slot)++) = in_packet_int32();
647+
break;
648+
}
649+
650+
case 'I': { // int(width=64)
651+
*((*(int64_t**)slot)++) = in_packet_int64();
652+
break;
653+
}
654+
655+
case 'f': { // float
656+
*((*(int64_t**)slot)++) = in_packet_int64();
657+
break;
658+
}
659+
660+
case 'F': { // Fraction
661+
struct { int64_t numerator, denominator; } *fraction = *slot;
662+
fraction->numerator = in_packet_int64();
663+
fraction->denominator = in_packet_int64();
664+
*slot = (void*)((intptr_t)(*slot) + sizeof(*fraction));
665+
break;
666+
}
667+
668+
case 's': { // string
669+
const char *in_string = in_packet_string();
670+
char *out_string = alloc_rpc_value(strlen(in_string) + 1);
671+
memcpy(out_string, in_string, strlen(in_string) + 1);
672+
*((*(char***)slot)++) = out_string;
673+
break;
674+
}
675+
676+
case 'l': { // list(elt='a)
677+
struct { int32_t length; struct {} *elements; } *list = *slot;
678+
list->length = in_packet_int32();
679+
680+
const char *tag_copy = *tag;
681+
list->elements = alloc_rpc_value(sizeof_rpc_value(&tag_copy) * list->length);
682+
683+
void *element = list->elements;
684+
for(int i = 0; i < list->length; i++) {
685+
const char *tag_copy = *tag;
686+
if(!receive_rpc_value(&tag_copy, &element))
687+
return 0;
688+
}
689+
skip_rpc_value(tag);
690+
break;
691+
}
692+
693+
case 'r': { // range(elt='a)
694+
const char *tag_copy;
695+
tag_copy = *tag;
696+
if(!receive_rpc_value(&tag_copy, slot)) // min
697+
return 0;
698+
tag_copy = *tag;
699+
if(!receive_rpc_value(&tag_copy, slot)) // max
700+
return 0;
701+
tag_copy = *tag;
702+
if(!receive_rpc_value(&tag_copy, slot)) // step
703+
return 0;
704+
*tag = tag_copy;
705+
break;
706+
}
707+
708+
default:
709+
log("receive_rpc_value: unknown tag %02x", *((*tag) - 1));
515710
return 0;
516711
}
517712

518713
return 1;
519714
}
520715

521-
// See llvm_ir_generator.py:_rpc_tag.
522716
static int send_rpc_value(const char **tag, void **value)
523717
{
524718
if(!out_packet_int8(**tag))
@@ -541,51 +735,33 @@ static int send_rpc_value(const char **tag, void **value)
541735
break;
542736

543737
case 'b': { // bool
544-
int size = sizeof(int8_t);
545-
if(!out_packet_chunk(*value, size))
546-
return 0;
547-
*value = (void*)((intptr_t)(*value) + size);
548-
break;
738+
return out_packet_int8(*((*(int8_t**)value)++));
549739
}
550740

551741
case 'i': { // int(width=32)
552-
int size = sizeof(int32_t);
553-
if(!out_packet_chunk(*value, size))
554-
return 0;
555-
*value = (void*)((intptr_t)(*value) + size);
556-
break;
742+
return out_packet_int32(*((*(int32_t**)value)++));
557743
}
558744

559745
case 'I': { // int(width=64)
560-
int size = sizeof(int64_t);
561-
if(!out_packet_chunk(*value, size))
562-
return 0;
563-
*value = (void*)((intptr_t)(*value) + size);
564-
break;
746+
return out_packet_int64(*((*(int64_t**)value)++));
565747
}
566748

567749
case 'f': { // float
568-
int size = sizeof(double);
569-
if(!out_packet_chunk(*value, size))
570-
return 0;
571-
*value = (void*)((intptr_t)(*value) + size);
572-
break;
750+
return out_packet_float64(*((*(double**)value)++));
573751
}
574752

575753
case 'F': { // Fraction
576-
int size = sizeof(int64_t) * 2;
577-
if(!out_packet_chunk(*value, size))
754+
struct { int64_t numerator, denominator; } *fraction = *value;
755+
if(!out_packet_int64(fraction->numerator))
756+
return 0;
757+
if(!out_packet_int64(fraction->denominator))
578758
return 0;
579-
*value = (void*)((intptr_t)(*value) + size);
759+
*value = (void*)((intptr_t)(*value) + sizeof(*fraction));
580760
break;
581761
}
582762

583763
case 's': { // string
584-
const char **string = *value;
585-
if(!out_packet_string(*string))
586-
return 0;
587-
*value = (void*)((intptr_t)(*value) + strlen(*string) + 1);
588-
break;
764+
return out_packet_string(*((*(const char***)value)++));
589765
}
590766

591767
case 'l': { // list(elt='a)
@@ -595,11 +771,11 @@ static int send_rpc_value(const char **tag, void **value)
595771
if(!out_packet_int32(list->length))
596772
return 0;
597773

598-
const char *tag_copy;
774+
const char *tag_copy = *tag;
599775
for(int i = 0; i < list->length; i++) {
600-
tag_copy = *tag;
601776
if(!send_rpc_value(&tag_copy, &element))
602777
return 0;
778+
tag_copy = *tag;
603779
}
604780
*tag = tag_copy;
605781

@@ -634,7 +810,7 @@ static int send_rpc_value(const char **tag, void **value)
634810
if(option->present) {
635811
return send_rpc_value(tag, &contents);
636812
} else {
637-
(*tag)++;
813+
skip_rpc_value(tag);
638814
break;
639815
}
640816
}
@@ -668,8 +844,7 @@ static int send_rpc_request(int service, const char *tag, va_list args)
668844
}
669845
out_packet_int8(0);
670846

671-
out_packet_string(tag + 1);
672-
847+
out_packet_string(tag + 1); // return tags
673848
out_packet_finish();
674849
return 1;
675850
}

0 commit comments

Comments
 (0)
Please sign in to comment.