Skip to content

Commit

Permalink
LLVMIRGenerator: use sret when returning large structures.
Browse files Browse the repository at this point in the history
  • Loading branch information
whitequark committed Aug 19, 2015
1 parent 673512f commit 27a6979
Showing 1 changed file with 53 additions and 5 deletions.
58 changes: 53 additions & 5 deletions artiq/compiler/transforms/llvm_ir_generator.py
Expand Up @@ -171,6 +171,21 @@ def __init__(self, engine, module_name, target):
self.phis = []
self.debug_info_emitter = DebugInfoEmitter(self.llmodule)

def needs_sret(self, lltyp, may_be_large=True):
    """Return True if a value of LLVM type ``lltyp`` has to be returned
    through a caller-provided result slot (sret) rather than directly.

    :param lltyp: the LLVM IR type (an ``ll`` type object) of the return value.
    :param may_be_large: when True, ``double`` values and literal structs of
        at most two elements are additionally eligible for direct return.
        The recursive check of struct elements passes False, so a nested
        struct or a ``double`` element makes the enclosing struct use sret.
    :return: False for void, integers of at most 32 bits and pointers;
        False for ``double`` when ``may_be_large`` is set; for small literal
        structs, True exactly when some element needs sret; True otherwise.
    """
    if isinstance(lltyp, ll.VoidType):
        return False
    elif isinstance(lltyp, ll.IntType) and lltyp.width <= 32:
        return False
    elif isinstance(lltyp, ll.PointerType):
        return False
    elif may_be_large and isinstance(lltyp, ll.DoubleType):
        return False
    elif may_be_large and isinstance(lltyp, ll.LiteralStructType) \
            and len(lltyp.elements) <= 2:
        # A small struct needs sret only if at least one of its elements
        # does. (The previous `not any(...)` inverted this, sending structs
        # of plain scalars through sret and returning oversized ones directly.)
        # NOTE(review): presumably the two-element limit mirrors what the
        # target ABI returns in registers -- confirm.
        return any(self.needs_sret(elt, may_be_large=False)
                   for elt in lltyp.elements)
    else:
        return True

def llty_of_type(self, typ, bare=False, for_return=False):
typ = typ.find()
if types.is_tuple(typ):
Expand All @@ -183,13 +198,28 @@ def llty_of_type(self, typ, bare=False, for_return=False):
elif types._is_pointer(typ):
return llptr
elif types.is_function(typ):
sretarg = []
llretty = self.llty_of_type(typ.ret, for_return=True)
if self.needs_sret(llretty):
sretarg = [llretty.as_pointer()]
llretty = llvoid

envarg = llptr
llty = ll.FunctionType(args=[envarg] +
llty = ll.FunctionType(args=sretarg + [envarg] +
[self.llty_of_type(typ.args[arg])
for arg in typ.args] +
[self.llty_of_type(ir.TOption(typ.optargs[arg]))
for arg in typ.optargs],
return_type=self.llty_of_type(typ.ret, for_return=True))
return_type=llretty)

# TODO: actually mark the first argument as sret (also noalias nocapture).
# llvmlite currently does not have support for this;
# https://github.com/numba/llvmlite/issues/91.
if sretarg:
llty.__has_sret = True
else:
llty.__has_sret = False

if bare:
return llty
else:
Expand Down Expand Up @@ -896,8 +926,22 @@ def process_Call(self, insn):
name=insn.name)
else:
llfun, llargs = self._prepare_closure_call(insn)
return self.llbuilder.call(llfun, llargs,
name=insn.name)

if llfun.type.pointee.__has_sret:
llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [])

llresultslot = self.llbuilder.alloca(llfun.type.pointee.args[0].pointee)
print(llfun)
print(llresultslot)
self.llbuilder.call(llfun, [llresultslot] + llargs)
llresult = self.llbuilder.load(llresultslot)

self.llbuilder.call(self.llbuiltin("llvm.stackrestore"), [llstackptr])

return llresult
else:
return self.llbuilder.call(llfun, llargs,
name=insn.name)

def process_Invoke(self, insn):
llnormalblock = self.map(insn.normal_target())
Expand Down Expand Up @@ -937,7 +981,11 @@ def process_Return(self, insn):
if builtins.is_none(insn.value().type):
return self.llbuilder.ret_void()
else:
return self.llbuilder.ret(self.map(insn.value()))
if self.llfunction.type.pointee.__has_sret:
self.llbuilder.store(self.map(insn.value()), self.llfunction.args[0])
return self.llbuilder.ret_void()
else:
return self.llbuilder.ret(self.map(insn.value()))

# Handler for an "unreachable" IR instruction: asks the LLVM builder to emit
# its `unreachable` terminator and returns whatever the builder returns.
# `insn` is accepted for signature parity with the other process_* handlers
# visible in this file but is not read here.
def process_Unreachable(self, insn):
return self.llbuilder.unreachable()
Expand Down

0 comments on commit 27a6979

Please sign in to comment.