Skip to content

Commit

Permalink
Implement receiving exceptions from RPCs.
Browse files — browse the repository at this point in the history
  • Loading branch information
whitequark committed Aug 9, 2015
1 parent 8b7d38d commit 02b1543
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 73 deletions.
156 changes: 116 additions & 40 deletions artiq/compiler/transforms/llvm_ir_generator.py
Expand Up @@ -146,6 +146,10 @@ def llbuiltin(self, name):
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.IntType(32)])
elif name == "llvm.copysign.f64":
llty = ll.FunctionType(ll.DoubleType(), [ll.DoubleType(), ll.DoubleType()])
elif name == "llvm.stacksave":
llty = ll.FunctionType(ll.IntType(8).as_pointer(), [])
elif name == "llvm.stackrestore":
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()])
elif name == self.target.print_function:
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(8).as_pointer()], var_arg=True)
elif name == "__artiq_personality":
Expand All @@ -155,8 +159,10 @@ def llbuiltin(self, name):
elif name == "__artiq_reraise":
llty = ll.FunctionType(ll.VoidType(), [])
elif name == "send_rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(32), ll.IntType(8).as_pointer()],
llty = ll.FunctionType(ll.VoidType(), [ll.IntType(32), ll.IntType(8).as_pointer()],
var_arg=True)
elif name == "recv_rpc":
llty = ll.FunctionType(ll.IntType(32), [ll.IntType(8).as_pointer().as_pointer()])
else:
assert False

Expand Down Expand Up @@ -559,12 +565,18 @@ def process_Closure(self, insn):
name=insn.name)
return llvalue

# See session.c:send_rpc_value.
def _rpc_tag(self, typ, root_type, root_loc):
def _prepare_closure_call(self, insn):
    """Split a closure-call instruction into an LLVM callee and argument list.

    The mapped call target is an aggregate value: element 0 is extracted and
    passed as the leading argument, element 1 is extracted and returned as the
    function to call. The instruction's own arguments, each mapped to its LLVM
    value, follow the leading element.

    Returns a ``(callee, arguments)`` pair suitable for ``llbuilder.call`` or
    ``llbuilder.invoke``.
    """
    closure_value = self.map(insn.target_function())
    ir_arguments = insn.arguments()
    env_value = self.llbuilder.extract_value(closure_value, 0)
    callee = self.llbuilder.extract_value(closure_value, 1)
    # Arguments are mapped only after both extractions, matching the
    # original's lazily-evaluated map() ordering.
    call_arguments = [env_value]
    for ir_argument in ir_arguments:
        call_arguments.append(self.map(ir_argument))
    return callee, call_arguments

# See session.c:send_rpc_value and session.c:recv_rpc_value.
def _rpc_tag(self, typ, error_handler):
if types.is_tuple(typ):
assert len(typ.elts) < 256
return b"t" + bytes([len(typ.elts)]) + \
b"".join([self._rpc_tag(elt_type, root_type, root_loc)
b"".join([self._rpc_tag(elt_type, error_handler)
for elt_type in typ.elts])
elif builtins.is_none(typ):
return b"n"
Expand All @@ -580,38 +592,53 @@ def _rpc_tag(self, typ, root_type, root_loc):
return b"s"
elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
error_handler)
elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
root_type, root_loc)
error_handler)
elif ir.is_option(typ):
return b"o" + self._rpc_tag(typ.params["inner"],
root_type, root_loc)
error_handler)
else:
error_handler(typ)

def _build_rpc(self, fun_loc, fun_type, args, llnormalblock, llunwindblock):
llservice = ll.Constant(ll.IntType(32), fun_type.service)

tag = b""

for arg in args:
def arg_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(typ)},
arg.loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls",
{"type": printer.name(arg.typ)},
arg.loc)
self.engine.process(diag)
tag += self._rpc_tag(arg.type, arg_error_handler)
tag += b":"

def ret_error_handler(typ):
printer = types.TypePrinter()
note = diagnostic.Diagnostic("note",
"value of type {type}",
{"type": printer.name(root_type)},
root_loc)
diag = diagnostic.Diagnostic("error",
"type {type} is not supported in remote procedure calls",
{"type": printer.name(typ)},
root_loc)
fun_loc)
diag = diagnostic.Diagnostic("error",
"return type {type} is not supported in remote procedure calls",
{"type": printer.name(fun_type.ret)},
fun_loc)
self.engine.process(diag)
tag += self._rpc_tag(fun_type.ret, ret_error_handler)
tag += b"\x00"

def _build_rpc(self, service, args, return_type):
llservice = ll.Constant(ll.IntType(32), service)
lltag = self.llconst_of_const(ir.Constant(tag + b"\x00", builtins.TStr()))

tag = b""
for arg in args:
if isinstance(arg, ir.Constant):
# Constants don't have locations, but conveniently
# they also never fail to serialize.
tag += self._rpc_tag(arg.type, arg.type, None)
else:
tag += self._rpc_tag(arg.type, arg.type, arg.loc)
tag += b"\x00"
lltag = self.llconst_of_const(ir.Constant(tag, builtins.TStr()))
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [])

llargs = []
for arg in args:
Expand All @@ -620,30 +647,79 @@ def _build_rpc(self, service, args, return_type):
self.llbuilder.store(llarg, llargslot)
llargs.append(llargslot)

return self.llbuiltin("send_rpc"), [llservice, lltag] + llargs
self.llbuilder.call(self.llbuiltin("send_rpc"),
[llservice, lltag] + llargs)

# Don't waste stack space on saved arguments.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])

# T result = {
# void *ptr = NULL;
# loop: int size = rpc_recv("tag", ptr);
# if(size) { ptr = alloca(size); goto loop; }
# else *(T*)ptr
# }
llprehead = self.llbuilder.basic_block
llhead = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head")
if llunwindblock:
llheadu = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.head.unwind")
llalloc = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.alloc")
lltail = self.llbuilder.append_basic_block(name=llprehead.name + ".rpc.tail")

llslot = self.llbuilder.alloca(ll.IntType(8).as_pointer())
self.llbuilder.store(ll.Constant(ll.IntType(8).as_pointer(), None), llslot)
self.llbuilder.branch(llhead)

self.llbuilder.position_at_end(llhead)
if llunwindblock:
llsize = self.llbuilder.invoke(self.llbuiltin("recv_rpc"), [llslot],
llheadu, llunwindblock)
self.llbuilder.position_at_end(llheadu)
else:
llsize = self.llbuilder.call(self.llbuiltin("recv_rpc"), [llslot])
lldone = self.llbuilder.icmp_unsigned('==', llsize, ll.Constant(llsize.type, 0))
self.llbuilder.cbranch(lldone, lltail, llalloc)

self.llbuilder.position_at_end(llalloc)
llalloca = self.llbuilder.alloca(ll.IntType(8), llsize)
self.llbuilder.store(llalloca, llslot)
self.llbuilder.branch(llhead)

self.llbuilder.position_at_end(lltail)
llretty = self.llty_of_type(fun_type.ret, for_return=True)
llretptr = self.llbuilder.bitcast(llslot, llretty.as_pointer())
llret = self.llbuilder.load(llretptr)
if not builtins.is_allocated(fun_type.ret):
# We didn't allocate anything except the slot for the value itself.
# Don't waste stack space.
self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])
if llnormalblock:
self.llbuilder.branch(llnormalblock)
return llret

def prepare_call(self, insn):
def process_Call(self, insn):
if types.is_rpc_function(insn.target_function().type):
return self._build_rpc(insn.target_function().type.service,
return self._build_rpc(insn.target_function().loc,
insn.target_function().type,
insn.arguments(),
insn.target_function().type.ret)
llnormalblock=None, llunwindblock=None)
else:
llclosure, llargs = self.map(insn.target_function()), map(self.map, insn.arguments())
llenv = self.llbuilder.extract_value(llclosure, 0)
llfun = self.llbuilder.extract_value(llclosure, 1)
return llfun, [llenv] + list(llargs)

def process_Call(self, insn):
llfun, llargs = self.prepare_call(insn)
return self.llbuilder.call(llfun, llargs,
name=insn.name)
llfun, llargs = self._prepare_closure_call(insn)
return self.llbuilder.call(llfun, llargs,
name=insn.name)

def process_Invoke(self, insn):
llfun, llargs = self.prepare_call(insn)
llnormalblock = self.map(insn.normal_target())
llunwindblock = self.map(insn.exception_target())
return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock,
name=insn.name)
if types.is_rpc_function(insn.target_function().type):
return self._build_rpc(insn.target_function().loc,
insn.target_function().type,
insn.arguments(),
llnormalblock, llunwindblock)
else:
llfun, llargs = self._prepare_closure_call(insn)
return self.llbuilder.invoke(llfun, llargs, llnormalblock, llunwindblock,
name=insn.name)

def process_Select(self, insn):
return self.llbuilder.select(self.map(insn.condition()),
Expand Down
8 changes: 5 additions & 3 deletions artiq/coredevice/comm_generic.py
Expand Up @@ -8,6 +8,7 @@


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class _H2DMsgType(Enum):
Expand Down Expand Up @@ -325,13 +326,14 @@ def _receive_rpc_args(self, rpc_map):
def _serve_rpc(self, rpc_map):
service = self._read_int32()
args = self._receive_rpc_args(rpc_map)
logger.debug("rpc service: %d %r", service, args)
return_tag = self._read_string()
logger.debug("rpc service: %d %r -> %s", service, args, return_tag)

try:
result = rpc_map[service](*args)
if not isinstance(result, int) or not (-2**31 < result < 2**31-1):
raise ValueError("An RPC must return an int(width=32)")
except ARTIQException as exn:
except core_language.ARTIQException as exn:
logger.debug("rpc service: %d %r ! %r", service, args, exn)

self._write_header(_H2DMsgType.RPC_EXCEPTION)
Expand All @@ -355,7 +357,7 @@ def _serve_rpc(self, rpc_map):
for index in range(3):
self._write_int64(0)

((filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__)
(_, (filename, line, function, _), ) = traceback.extract_tb(exn.__traceback__, 2)
self._write_string(filename)
self._write_int32(line)
self._write_int32(-1) # column not known
Expand Down
45 changes: 28 additions & 17 deletions soc/runtime/ksupport.c
Expand Up @@ -93,6 +93,7 @@ static const struct symbol runtime_exports[] = {
{"log", &log},
{"lognonl", &lognonl},
{"send_rpc", &send_rpc},
{"recv_rpc", &recv_rpc},

/* direct syscalls */
{"rtio_get_counter", &rtio_get_counter},
Expand Down Expand Up @@ -301,7 +302,7 @@ void watchdog_clear(int id)
mailbox_send_and_wait(&request);
}

int send_rpc(int service, const char *tag, ...)
void send_rpc(int service, const char *tag, ...)
{
struct msg_rpc_send request;

Expand All @@ -311,24 +312,34 @@ int send_rpc(int service, const char *tag, ...)
va_start(request.args, tag);
mailbox_send_and_wait(&request);
va_end(request.args);
}

int recv_rpc(void **slot) {
struct msg_rpc_recv_request request;
struct msg_rpc_recv_reply *reply;

// struct msg_base *reply;
// reply = mailbox_wait_and_receive();
// if(reply->type == MESSAGE_TYPE_RPC_REPLY) {
// int result = ((struct msg_rpc_reply *)reply)->result;
// mailbox_acknowledge();
// return result;
// } else if(reply->type == MESSAGE_TYPE_RPC_EXCEPTION) {
// struct artiq_exception exception;
// memcpy(&exception, ((struct msg_rpc_exception *)reply)->exception,
// sizeof(struct artiq_exception));
// mailbox_acknowledge();
// __artiq_raise(&exception);
// } else {
// log("Malformed MESSAGE_TYPE_RPC_REQUEST reply type %d",
// reply->type);
request.type = MESSAGE_TYPE_RPC_RECV_REQUEST;
request.slot = slot;
mailbox_send_and_wait(&request);

reply = mailbox_wait_and_receive();
if(reply->type != MESSAGE_TYPE_RPC_RECV_REPLY) {
log("Malformed MESSAGE_TYPE_RPC_RECV_REQUEST reply type %d",
reply->type);
while(1);
// }
}

if(reply->exception) {
struct artiq_exception exception;
memcpy(&exception, reply->exception,
sizeof(struct artiq_exception));
mailbox_acknowledge();
__artiq_raise(&exception);
} else {
int alloc_size = reply->alloc_size;
mailbox_acknowledge();
return alloc_size;
}
}

void lognonl(const char *fmt, ...)
Expand Down
3 changes: 2 additions & 1 deletion soc/runtime/ksupport.h
Expand Up @@ -5,7 +5,8 @@ long long int now_init(void);
void now_save(long long int now);
int watchdog_set(int ms);
void watchdog_clear(int id);
int send_rpc(int service, const char *tag, ...);
void send_rpc(int service, const char *tag, ...);
int recv_rpc(void **slot);
void lognonl(const char *fmt, ...);
void log(const char *fmt, ...);

Expand Down
9 changes: 2 additions & 7 deletions soc/runtime/messages.h
Expand Up @@ -17,7 +17,6 @@ enum {
MESSAGE_TYPE_RPC_SEND,
MESSAGE_TYPE_RPC_RECV_REQUEST,
MESSAGE_TYPE_RPC_RECV_REPLY,
MESSAGE_TYPE_RPC_EXCEPTION,
MESSAGE_TYPE_LOG,

MESSAGE_TYPE_BRG_READY,
Expand Down Expand Up @@ -90,16 +89,12 @@ struct msg_rpc_send {

struct msg_rpc_recv_request {
int type;
// TODO ???
void **slot;
};

struct msg_rpc_recv_reply {
int type;
// TODO ???
};

struct msg_rpc_exception {
int type;
int alloc_size;
struct artiq_exception *exception;
};

Expand Down

0 comments on commit 02b1543

Please sign in to comment.