Skip to content

Commit 0d10ae7

Browse files
committed Dec 19, 2014
rpc: support all data types as parameters
1 parent 44e7b99 commit 0d10ae7

File tree

5 files changed

+181
-75
lines changed

5 files changed

+181
-75
lines changed
 

artiq/coredevice/comm_serial.py

+45 -20
Original file line number | Diff line number | Diff line change
@@ -171,33 +171,58 @@ def run(self, kname):
171171
_write_exactly(self.port, struct.pack(
172172
">lbl", 0x5a5a5a5a, _H2DMsgType.RUN_KERNEL.value, len(kname)))
173173
for c in kname:
174-
_write_exactly(self.port, struct.pack("B", ord(c)))
174+
_write_exactly(self.port, struct.pack(">B", ord(c)))
175175
logger.debug("running kernel: {}".format(kname))
176176

177+
def _receive_rpc_values(self):
178+
r = []
179+
while True:
180+
type_tag = chr(struct.unpack(">B", _read_exactly(self.port, 1))[0])
181+
if type_tag == "\x00":
182+
return r
183+
if type_tag == "n":
184+
r.append(None)
185+
if type_tag == "b":
186+
r.append(bool(struct.unpack(">B",
187+
_read_exactly(self.port, 1))[0]))
188+
if type_tag == "i":
189+
r.append(struct.unpack(">l", _read_exactly(self.port, 4))[0])
190+
if type_tag == "I":
191+
r.append(struct.unpack(">q", _read_exactly(self.port, 8))[0])
192+
if type_tag == "f":
193+
r.append(struct.unpack(">d", _read_exactly(self.port, 8))[0])
194+
if type_tag == "F":
195+
n, d = struct.unpack(">qq", _read_exactly(self.port, 16))
196+
r.append(Fraction(n, d))
197+
if type_tag == "l":
198+
r.append(self._receive_rpc_values())
199+
200+
def _serve_rpc(self, rpc_map):
201+
(rpc_num, ) = struct.unpack(">h", _read_exactly(self.port, 2))
202+
args = self._receive_rpc_values()
203+
logger.debug("rpc service: {} ({})".format(rpc_num, args))
204+
r = rpc_map[rpc_num](*args)
205+
if r is None:
206+
r = 0
207+
_write_exactly(self.port, struct.pack(">l", r))
208+
logger.debug("rpc service: {} ({}) == {}".format(
209+
rpc_num, args, r))
210+
211+
def _serve_exception(self, user_exception_map):
212+
(eid, ) = struct.unpack(">l", _read_exactly(self.port, 4))
213+
if eid < core_language.first_user_eid:
214+
exception = runtime_exceptions.exception_map[eid]
215+
else:
216+
exception = user_exception_map[eid]
217+
raise exception
218+
177219
def serve(self, rpc_map, user_exception_map):
178220
while True:
179221
msg = self._get_device_msg()
180222
if msg == _D2HMsgType.RPC_REQUEST:
181-
rpc_num, n_args = struct.unpack(">hB",
182-
_read_exactly(self.port, 3))
183-
args = []
184-
for i in range(n_args):
185-
args.append(*struct.unpack(">l",
186-
_read_exactly(self.port, 4)))
187-
logger.debug("rpc service: {} ({})".format(rpc_num, args))
188-
r = rpc_map[rpc_num](*args)
189-
if r is None:
190-
r = 0
191-
_write_exactly(self.port, struct.pack(">l", r))
192-
logger.debug("rpc service: {} ({}) == {}".format(
193-
rpc_num, args, r))
223+
self._serve_rpc(rpc_map)
194224
elif msg == _D2HMsgType.KERNEL_EXCEPTION:
195-
(eid, ) = struct.unpack(">l", _read_exactly(self.port, 4))
196-
if eid < core_language.first_user_eid:
197-
exception = runtime_exceptions.exception_map[eid]
198-
else:
199-
exception = user_exception_map[eid]
200-
raise exception
225+
self._serve_exception(user_exception_map)
201226
elif msg == _D2HMsgType.KERNEL_FINISHED:
202227
return
203228
else:

artiq/coredevice/runtime.py

+88 -38
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import llvmlite.ir as ll
44
import llvmlite.binding as llvm
55

6-
from artiq.py2llvm import base_types
6+
from artiq.py2llvm import base_types, fractions, lists
77
from artiq.language import units
88

99

@@ -12,7 +12,6 @@
1212
llvm.initialize_all_asmprinters()
1313

1414
_syscalls = {
15-
"rpc": "i+:i",
1615
"gpio_set": "ib:n",
1716
"rtio_oe": "ib:n",
1817
"rtio_set": "Iii:n",
@@ -23,50 +22,72 @@
2322
"dds_program": "Iiiiibb:n",
2423
}
2524

26-
_chr_to_type = {
27-
"n": lambda: ll.VoidType(),
28-
"b": lambda: ll.IntType(1),
29-
"i": lambda: ll.IntType(32),
30-
"I": lambda: ll.IntType(64)
31-
}
3225

33-
_chr_to_value = {
34-
"n": lambda: base_types.VNone(),
35-
"b": lambda: base_types.VBool(),
36-
"i": lambda: base_types.VInt(),
37-
"I": lambda: base_types.VInt(64)
38-
}
26+
def _chr_to_type(c):
27+
if c == "n":
28+
return ll.VoidType()
29+
if c == "b":
30+
return ll.IntType(1)
31+
if c == "i":
32+
return ll.IntType(32)
33+
if c == "I":
34+
return ll.IntType(64)
35+
raise ValueError
3936

4037

4138
def _str_to_functype(s):
4239
assert(s[-2] == ":")
43-
type_ret = _chr_to_type[s[-1]]()
44-
45-
var_arg_fixcount = None
46-
type_args = []
47-
for n, c in enumerate(s[:-2]):
48-
if c == "+":
49-
type_args.append(ll.IntType(32))
50-
var_arg_fixcount = n
51-
elif c != "n":
52-
type_args.append(_chr_to_type[c]())
53-
return (var_arg_fixcount,
54-
ll.FunctionType(type_ret, type_args,
55-
var_arg=var_arg_fixcount is not None))
40+
type_ret = _chr_to_type(s[-1])
41+
type_args = [_chr_to_type(c) for c in s[:-2] if c != "n"]
42+
return ll.FunctionType(type_ret, type_args)
43+
44+
45+
def _chr_to_value(c):
46+
if c == "n":
47+
return base_types.VNone()
48+
if c == "b":
49+
return base_types.VBool()
50+
if c == "i":
51+
return base_types.VInt()
52+
if c == "I":
53+
return base_types.VInt(64)
54+
raise ValueError
55+
56+
57+
def _value_to_str(v):
58+
if isinstance(v, base_types.VNone):
59+
return "n"
60+
if isinstance(v, base_types.VBool):
61+
return "b"
62+
if isinstance(v, base_types.VInt):
63+
if v.nbits == 32:
64+
return "i"
65+
if v.nbits == 64:
66+
return "I"
67+
raise ValueError
68+
if isinstance(v, base_types.VFloat):
69+
return "f"
70+
if isinstance(v, fractions.VFraction):
71+
return "F"
72+
if isinstance(v, lists.VList):
73+
return "l" + _value_to_str(v.el_type)
74+
raise ValueError
5675

5776

5877
class LinkInterface:
5978
def init_module(self, module):
6079
self.module = module
6180
llvm_module = self.module.llvm_module
6281

82+
# RPC
83+
func_type = ll.FunctionType(ll.IntType(32), [ll.IntType(32)],
84+
var_arg=1)
85+
self.rpc = ll.Function(llvm_module, func_type, "__syscall_rpc")
86+
6387
# syscalls
6488
self.syscalls = dict()
65-
self.var_arg_fixcount = dict()
6689
for func_name, func_type_str in _syscalls.items():
67-
var_arg_fixcount, func_type = _str_to_functype(func_type_str)
68-
if var_arg_fixcount is not None:
69-
self.var_arg_fixcount[func_name] = var_arg_fixcount
90+
func_type = _str_to_functype(func_type_str)
7091
self.syscalls[func_name] = ll.Function(
7192
llvm_module, func_type, "__syscall_" + func_name)
7293

@@ -91,19 +112,48 @@ def init_module(self, module):
91112
self.eh_raise = ll.Function(llvm_module, func_type, "__eh_raise")
92113
self.eh_raise.attributes.add("noreturn")
93114

94-
def build_syscall(self, syscall_name, args, builder):
95-
r = _chr_to_value[_syscalls[syscall_name][-1]]()
115+
def _build_rpc(self, args, builder):
116+
r = base_types.VInt()
117+
if builder is not None:
118+
new_args = []
119+
new_args.append(args[0].auto_load(builder)) # RPC number
120+
for arg in args[1:]:
121+
# type tag
122+
arg_type_str = _value_to_str(arg)
123+
arg_type_int = 0
124+
for c in reversed(arg_type_str):
125+
arg_type_int <<= 8
126+
arg_type_int |= ord(c)
127+
new_args.append(ll.Constant(ll.IntType(32), arg_type_int))
128+
129+
# pointer to value
130+
if not isinstance(arg, base_types.VNone):
131+
if isinstance(arg.llvm_value.type, ll.PointerType):
132+
new_args.append(arg.llvm_value)
133+
else:
134+
arg_ptr = arg.new()
135+
arg_ptr.alloca(builder)
136+
arg_ptr.auto_store(builder, arg.llvm_value)
137+
new_args.append(arg_ptr.llvm_value)
138+
# end marker
139+
new_args.append(ll.Constant(ll.IntType(32), 0))
140+
r.auto_store(builder, builder.call(self.rpc, new_args))
141+
return r
142+
143+
def _build_regular_syscall(self, syscall_name, args, builder):
144+
r = _chr_to_value(_syscalls[syscall_name][-1])
96145
if builder is not None:
97146
args = [arg.auto_load(builder) for arg in args]
98-
if syscall_name in self.var_arg_fixcount:
99-
fixcount = self.var_arg_fixcount[syscall_name]
100-
args = args[:fixcount] \
101-
+ [ll.Constant(ll.IntType(32), len(args) - fixcount)] \
102-
+ args[fixcount:]
103147
r.auto_store(builder, builder.call(self.syscalls[syscall_name],
104148
args))
105149
return r
106150

151+
def build_syscall(self, syscall_name, args, builder):
152+
if syscall_name == "rpc":
153+
return self._build_rpc(args, builder)
154+
else:
155+
return self._build_regular_syscall(syscall_name, args, builder)
156+
107157
def build_catch(self, builder):
108158
jmpbuf = builder.call(self.eh_push, [])
109159
exception_occured = builder.call(self.eh_setjmp, [jmpbuf])

artiq/transforms/inline.py

-11
Original file line number | Diff line number | Diff line change
@@ -447,17 +447,6 @@ def get_attr_writeback(attribute_namespace, rpc_mapper, loc_node):
447447
attr_writeback = []
448448
for (_, attr), attr_info in attribute_namespace.items():
449449
if attr_info.read_write:
450-
# HACK/FIXME: since RPC of non-int is not supported yet, skip
451-
# writeback of other types for now.
452-
# This code breaks if an int is promoted to int64
453-
if hasattr(attr_info.obj, attr):
454-
val = getattr(attr_info.obj, attr)
455-
if (not isinstance(val, int)
456-
or isinstance(val, core_language.int64)
457-
or isinstance(val, bool)):
458-
continue
459-
#
460-
461450
setter = partial(setattr, attr_info.obj, attr)
462451
func = ast.copy_location(
463452
ast.Name("syscall", ast.Load()), loc_node)

soc/runtime/comm.h

+1 -1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@ typedef int (*object_loader)(void *, int);
1111
typedef int (*kernel_runner)(const char *, int *);
1212

1313
void comm_serve(object_loader load_object, kernel_runner run_kernel);
14-
int comm_rpc(int rpc_num, int n_args, ...);
14+
int comm_rpc(int rpc_num, ...);
1515
void comm_log(const char *fmt, ...);
1616

1717
#endif /* __COMM_H */

soc/runtime/comm_serial.c

+47 -5
Original file line number | Diff line number | Diff line change
@@ -181,17 +181,59 @@ void comm_serve(object_loader load_object, kernel_runner run_kernel)
181181
}
182182
}
183183

184-
int comm_rpc(int rpc_num, int n_args, ...)
184+
static int send_value(int type_tag, void *value)
185185
{
186+
char base_type;
187+
int i, p;
188+
int len;
189+
190+
base_type = type_tag;
191+
send_char(base_type);
192+
switch(base_type) {
193+
case 'n':
194+
return 0;
195+
case 'b':
196+
if(*(char *)value)
197+
send_char(1);
198+
else
199+
send_char(0);
200+
return 1;
201+
case 'i':
202+
send_int(*(int *)value);
203+
return 4;
204+
case 'I':
205+
case 'f':
206+
send_int(*(int *)value);
207+
send_int(*((int *)value + 1));
208+
return 8;
209+
case 'F':
210+
for(i=0;i<4;i++)
211+
send_int(*((int *)value + i));
212+
return 16;
213+
case 'l':
214+
len = *(int *)value;
215+
p = 4;
216+
for(i=0;i<len;i++)
217+
p += send_value(type_tag >> 8, (char *)value + p);
218+
send_char(0);
219+
return p;
220+
}
221+
return 0;
222+
}
223+
224+
int comm_rpc(int rpc_num, ...)
225+
{
226+
int type_tag;
227+
186228
send_char(MSGTYPE_RPC_REQUEST);
187229
send_sint(rpc_num);
188-
send_char(n_args);
189230

190231
va_list args;
191-
va_start(args, n_args);
192-
while(n_args--)
193-
send_int(va_arg(args, int));
232+
va_start(args, rpc_num);
233+
while((type_tag = va_arg(args, int)))
234+
send_value(type_tag, type_tag == 'n' ? NULL : va_arg(args, void *));
194235
va_end(args);
236+
send_char(0);
195237

196238
return receive_int();
197239
}

0 commit comments

Comments
 (0)
Please sign in to comment.