@@ -171,6 +171,21 @@ def __init__(self, engine, module_name, target):
171
171
self .phis = []
172
172
self .debug_info_emitter = DebugInfoEmitter (self .llmodule )
173
173
174
def needs_sret(self, lltyp, may_be_large=True):
    """Tell whether a value of LLVM type ``lltyp`` has to be returned
    through a caller-provided result slot instead of directly.

    :param lltyp: the llvmlite IR type of the return value
    :param may_be_large: when true, doubles and literal structs of at
        most two elements are also considered for direct return;
        the recursive call on struct elements passes ``False`` so that
        only small scalars qualify as elements
    :return: ``False`` if the value can be returned directly,
        ``True`` if a result slot is required
    """
    # Void and pointers are always returned directly.
    if isinstance(lltyp, (ll.VoidType, ll.PointerType)):
        return False

    # Integers are returned directly only up to 32 bits wide.
    if isinstance(lltyp, ll.IntType):
        return lltyp.width > 32

    # Everything below is only eligible at the top level.
    if not may_be_large:
        return True

    if isinstance(lltyp, ll.DoubleType):
        return False

    if isinstance(lltyp, ll.LiteralStructType) and len(lltyp.elements) <= 2:
        # A small aggregate needs a result slot exactly when one of its
        # elements does not qualify as a small scalar. The previous code
        # negated this test, so e.g. a pair of 32-bit integers was
        # classified as needing a result slot while a pair of doubles
        # was not.
        return any(self.needs_sret(elt, may_be_large=False)
                   for elt in lltyp.elements)

    return True
188
+
174
189
def llty_of_type (self , typ , bare = False , for_return = False ):
175
190
typ = typ .find ()
176
191
if types .is_tuple (typ ):
@@ -183,13 +198,28 @@ def llty_of_type(self, typ, bare=False, for_return=False):
183
198
elif types ._is_pointer (typ ):
184
199
return llptr
185
200
elif types .is_function (typ ):
201
+ sretarg = []
202
+ llretty = self .llty_of_type (typ .ret , for_return = True )
203
+ if self .needs_sret (llretty ):
204
+ sretarg = [llretty .as_pointer ()]
205
+ llretty = llvoid
206
+
186
207
envarg = llptr
187
- llty = ll .FunctionType (args = [envarg ] +
208
+ llty = ll .FunctionType (args = sretarg + [envarg ] +
188
209
[self .llty_of_type (typ .args [arg ])
189
210
for arg in typ .args ] +
190
211
[self .llty_of_type (ir .TOption (typ .optargs [arg ]))
191
212
for arg in typ .optargs ],
192
- return_type = self .llty_of_type (typ .ret , for_return = True ))
213
+ return_type = llretty )
214
+
215
+ # TODO: actually mark the first argument as sret (also noalias nocapture).
216
+ # llvmlite currently does not have support for this;
217
+ # https://github.com/numba/llvmlite/issues/91.
218
+ if sretarg :
219
+ llty .__has_sret = True
220
+ else :
221
+ llty .__has_sret = False
222
+
193
223
if bare :
194
224
return llty
195
225
else :
@@ -896,8 +926,22 @@ def process_Call(self, insn):
896
926
name = insn .name )
897
927
else :
898
928
llfun , llargs = self ._prepare_closure_call (insn )
899
- return self .llbuilder .call (llfun , llargs ,
900
- name = insn .name )
929
+
930
+ if llfun .type .pointee .__has_sret :
931
+ llstackptr = self .llbuilder .call (self .llbuiltin ("llvm.stacksave" ), [])
932
+
933
+ llresultslot = self .llbuilder .alloca (llfun .type .pointee .args [0 ].pointee )
934
+ print (llfun )
935
+ print (llresultslot )
936
+ self .llbuilder .call (llfun , [llresultslot ] + llargs )
937
+ llresult = self .llbuilder .load (llresultslot )
938
+
939
+ self .llbuilder .call (self .llbuiltin ("llvm.stackrestore" ), [llstackptr ])
940
+
941
+ return llresult
942
+ else :
943
+ return self .llbuilder .call (llfun , llargs ,
944
+ name = insn .name )
901
945
902
946
def process_Invoke (self , insn ):
903
947
llnormalblock = self .map (insn .normal_target ())
@@ -937,7 +981,11 @@ def process_Return(self, insn):
937
981
if builtins .is_none (insn .value ().type ):
938
982
return self .llbuilder .ret_void ()
939
983
else :
940
- return self .llbuilder .ret (self .map (insn .value ()))
984
+ if self .llfunction .type .pointee .__has_sret :
985
+ self .llbuilder .store (self .map (insn .value ()), self .llfunction .args [0 ])
986
+ return self .llbuilder .ret_void ()
987
+ else :
988
+ return self .llbuilder .ret (self .map (insn .value ()))
941
989
942
990
def process_Unreachable (self , insn ):
943
991
return self .llbuilder .unreachable ()
0 commit comments