1 | //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This pass modifies function signatures containing aggregate arguments |
10 | // and/or return value before IRTranslator. Information about the original |
11 | // signatures is stored in metadata. It is used during call lowering to |
12 | // restore correct SPIR-V types of function arguments and return values. |
13 | // This pass also substitutes some llvm intrinsic calls with calls to newly |
14 | // generated functions (as the Khronos LLVM/SPIR-V Translator does). |
15 | // |
16 | // NOTE: this pass is a module-level one due to the necessity to modify |
17 | // GVs/functions. |
18 | // |
19 | //===----------------------------------------------------------------------===// |
20 | |
21 | #include "SPIRV.h" |
22 | #include "SPIRVSubtarget.h" |
23 | #include "SPIRVTargetMachine.h" |
24 | #include "SPIRVUtils.h" |
25 | #include "llvm/ADT/StringExtras.h" |
26 | #include "llvm/Analysis/ValueTracking.h" |
27 | #include "llvm/CodeGen/IntrinsicLowering.h" |
28 | #include "llvm/IR/IRBuilder.h" |
29 | #include "llvm/IR/IntrinsicInst.h" |
30 | #include "llvm/IR/Intrinsics.h" |
31 | #include "llvm/IR/IntrinsicsSPIRV.h" |
32 | #include "llvm/Transforms/Utils/Cloning.h" |
33 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
34 | #include <regex> |
35 | |
36 | using namespace llvm; |
37 | |
38 | namespace { |
39 | |
40 | class SPIRVPrepareFunctions : public ModulePass { |
41 | const SPIRVTargetMachine &TM; |
42 | bool substituteIntrinsicCalls(Function *F); |
43 | Function *removeAggregateTypesFromSignature(Function *F); |
44 | |
45 | public: |
46 | static char ID; |
47 | SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) |
48 | : ModulePass(ID), TM(TM) {} |
49 | |
50 | bool runOnModule(Module &M) override; |
51 | |
52 | StringRef getPassName() const override { return "SPIRV prepare functions" ; } |
53 | |
54 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
55 | ModulePass::getAnalysisUsage(AU); |
56 | } |
57 | }; |
58 | |
59 | } // namespace |
60 | |
// Storage for the pass identifier declared in the class and handed to
// ModulePass(ID) by the constructor.
char SPIRVPrepareFunctions::ID = 0;

// Registers the pass under the command-line name "prepare-functions" with the
// human-readable description "SPIRV prepare functions". NOTE(review): the two
// trailing flags are presumably INITIALIZE_PASS's cfg-only / is-analysis
// parameters, both false here — confirm against PassSupport.h.
INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
65 | |
66 | static std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { |
67 | Function *IntrinsicFunc = II->getCalledFunction(); |
68 | assert(IntrinsicFunc && "Missing function" ); |
69 | std::string FuncName = IntrinsicFunc->getName().str(); |
70 | llvm::replace(Range&: FuncName, OldValue: '.', NewValue: '_'); |
71 | FuncName = "spirv." + FuncName; |
72 | return FuncName; |
73 | } |
74 | |
75 | static Function *getOrCreateFunction(Module *M, Type *RetTy, |
76 | ArrayRef<Type *> ArgTypes, |
77 | StringRef Name) { |
78 | FunctionType *FT = FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: false); |
79 | Function *F = M->getFunction(Name); |
80 | if (F && F->getFunctionType() == FT) |
81 | return F; |
82 | Function *NewF = Function::Create(Ty: FT, Linkage: GlobalValue::ExternalLinkage, N: Name, M); |
83 | if (F) |
84 | NewF->setDSOLocal(F->isDSOLocal()); |
85 | NewF->setCallingConv(CallingConv::SPIR_FUNC); |
86 | return NewF; |
87 | } |
88 | |
89 | static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) { |
90 | // For @llvm.memset.* intrinsic cases with constant value and length arguments |
91 | // are emulated via "storing" a constant array to the destination. For other |
92 | // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the |
93 | // intrinsic to a loop via expandMemSetAsLoop(). |
94 | if (auto *MSI = dyn_cast<MemSetInst>(Val: Intrinsic)) |
95 | if (isa<Constant>(Val: MSI->getValue()) && isa<ConstantInt>(Val: MSI->getLength())) |
96 | return false; // It is handled later using OpCopyMemorySized. |
97 | |
98 | Module *M = Intrinsic->getModule(); |
99 | std::string FuncName = lowerLLVMIntrinsicName(II: Intrinsic); |
100 | if (Intrinsic->isVolatile()) |
101 | FuncName += ".volatile" ; |
102 | // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* |
103 | Function *F = M->getFunction(Name: FuncName); |
104 | if (F) { |
105 | Intrinsic->setCalledFunction(F); |
106 | return true; |
107 | } |
108 | // TODO copy arguments attributes: nocapture writeonly. |
109 | FunctionCallee FC = |
110 | M->getOrInsertFunction(Name: FuncName, T: Intrinsic->getFunctionType()); |
111 | auto IntrinsicID = Intrinsic->getIntrinsicID(); |
112 | Intrinsic->setCalledFunction(FC); |
113 | |
114 | F = dyn_cast<Function>(Val: FC.getCallee()); |
115 | assert(F && "Callee must be a function" ); |
116 | |
117 | switch (IntrinsicID) { |
118 | case Intrinsic::memset: { |
119 | auto *MSI = static_cast<MemSetInst *>(Intrinsic); |
120 | Argument *Dest = F->getArg(i: 0); |
121 | Argument *Val = F->getArg(i: 1); |
122 | Argument *Len = F->getArg(i: 2); |
123 | Argument *IsVolatile = F->getArg(i: 3); |
124 | Dest->setName("dest" ); |
125 | Val->setName("val" ); |
126 | Len->setName("len" ); |
127 | IsVolatile->setName("isvolatile" ); |
128 | BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry" , Parent: F); |
129 | IRBuilder<> IRB(EntryBB); |
130 | auto *MemSet = IRB.CreateMemSet(Ptr: Dest, Val, Size: Len, Align: MSI->getDestAlign(), |
131 | isVolatile: MSI->isVolatile()); |
132 | IRB.CreateRetVoid(); |
133 | expandMemSetAsLoop(MemSet: cast<MemSetInst>(Val: MemSet)); |
134 | MemSet->eraseFromParent(); |
135 | break; |
136 | } |
137 | case Intrinsic::bswap: { |
138 | BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry" , Parent: F); |
139 | IRBuilder<> IRB(EntryBB); |
140 | auto *BSwap = IRB.CreateIntrinsic(ID: Intrinsic::bswap, Types: Intrinsic->getType(), |
141 | Args: F->getArg(i: 0)); |
142 | IRB.CreateRet(V: BSwap); |
143 | IntrinsicLowering IL(M->getDataLayout()); |
144 | IL.LowerIntrinsicCall(CI: BSwap); |
145 | break; |
146 | } |
147 | default: |
148 | break; |
149 | } |
150 | return true; |
151 | } |
152 | |
153 | static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) { |
154 | if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(Val: AnnoVal)) |
155 | AnnoVal = Ref->getOperand(i_nocapture: 0); |
156 | if (auto *Ref = dyn_cast_or_null<BitCastInst>(Val: OptAnnoVal)) |
157 | OptAnnoVal = Ref->getOperand(i_nocapture: 0); |
158 | |
159 | std::string Anno; |
160 | if (auto *C = dyn_cast_or_null<Constant>(Val: AnnoVal)) { |
161 | StringRef Str; |
162 | if (getConstantStringInfo(V: C, Str)) |
163 | Anno = Str; |
164 | } |
165 | // handle optional annotation parameter in a way that Khronos Translator do |
166 | // (collect integers wrapped in a struct) |
167 | if (auto *C = dyn_cast_or_null<Constant>(Val: OptAnnoVal); |
168 | C && C->getNumOperands()) { |
169 | Value *MaybeStruct = C->getOperand(i: 0); |
170 | if (auto *Struct = dyn_cast<ConstantStruct>(Val: MaybeStruct)) { |
171 | for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) { |
172 | if (auto *CInt = dyn_cast<ConstantInt>(Val: Struct->getOperand(i_nocapture: I))) |
173 | Anno += (I == 0 ? ": " : ", " ) + |
174 | std::to_string(val: CInt->getType()->getIntegerBitWidth() == 1 |
175 | ? CInt->getZExtValue() |
176 | : CInt->getSExtValue()); |
177 | } |
178 | } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(Val: MaybeStruct)) { |
179 | // { i32 i32 ... } zeroinitializer |
180 | for (unsigned I = 0, E = Struct->getType()->getStructNumElements(); |
181 | I != E; ++I) |
182 | Anno += I == 0 ? ": 0" : ", 0" ; |
183 | } |
184 | } |
185 | return Anno; |
186 | } |
187 | |
188 | static SmallVector<Metadata *> parseAnnotation(Value *I, |
189 | const std::string &Anno, |
190 | LLVMContext &Ctx, |
191 | Type *Int32Ty) { |
192 | // Try to parse the annotation string according to the following rules: |
193 | // annotation := ({kind} | {kind:value,value,...})+ |
194 | // kind := number |
195 | // value := number | string |
196 | static const std::regex R( |
197 | "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}" ); |
198 | SmallVector<Metadata *> MDs; |
199 | int Pos = 0; |
200 | for (std::sregex_iterator |
201 | It = std::sregex_iterator(Anno.begin(), Anno.end(), R), |
202 | ItEnd = std::sregex_iterator(); |
203 | It != ItEnd; ++It) { |
204 | if (It->position() != Pos) |
205 | return SmallVector<Metadata *>{}; |
206 | Pos = It->position() + It->length(); |
207 | std::smatch Match = *It; |
208 | SmallVector<Metadata *> MDsItem; |
209 | for (std::size_t i = 1; i < Match.size(); ++i) { |
210 | std::ssub_match SMatch = Match[i]; |
211 | std::string Item = SMatch.str(); |
212 | if (Item.length() == 0) |
213 | break; |
214 | if (Item[0] == '"') { |
215 | Item = Item.substr(pos: 1, n: Item.length() - 2); |
216 | // Acceptable format of the string snippet is: |
217 | static const std::regex RStr("^(\\d+)(?:,(\\d+))*$" ); |
218 | if (std::smatch MatchStr; std::regex_match(s: Item, m&: MatchStr, re: RStr)) { |
219 | for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx) |
220 | if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length()) |
221 | MDsItem.push_back(Elt: ConstantAsMetadata::get( |
222 | C: ConstantInt::get(Ty: Int32Ty, V: std::stoi(str: SubStr)))); |
223 | } else { |
224 | MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item)); |
225 | } |
226 | } else if (int32_t Num; llvm::to_integer(S: StringRef(Item), Num, Base: 10)) { |
227 | MDsItem.push_back( |
228 | Elt: ConstantAsMetadata::get(C: ConstantInt::get(Ty: Int32Ty, V: Num))); |
229 | } else { |
230 | MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item)); |
231 | } |
232 | } |
233 | if (MDsItem.size() == 0) |
234 | return SmallVector<Metadata *>{}; |
235 | MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: MDsItem)); |
236 | } |
237 | return Pos == static_cast<int>(Anno.length()) ? MDs |
238 | : SmallVector<Metadata *>{}; |
239 | } |
240 | |
241 | static void lowerPtrAnnotation(IntrinsicInst *II) { |
242 | LLVMContext &Ctx = II->getContext(); |
243 | Type *Int32Ty = Type::getInt32Ty(C&: Ctx); |
244 | |
245 | // Retrieve an annotation string from arguments. |
246 | Value *PtrArg = nullptr; |
247 | if (auto *BI = dyn_cast<BitCastInst>(Val: II->getArgOperand(i: 0))) |
248 | PtrArg = BI->getOperand(i_nocapture: 0); |
249 | else |
250 | PtrArg = II->getOperand(i_nocapture: 0); |
251 | std::string Anno = |
252 | getAnnotation(AnnoVal: II->getArgOperand(i: 1), |
253 | OptAnnoVal: 4 < II->arg_size() ? II->getArgOperand(i: 4) : nullptr); |
254 | |
255 | // Parse the annotation. |
256 | SmallVector<Metadata *> MDs = parseAnnotation(I: II, Anno, Ctx, Int32Ty); |
257 | |
258 | // If the annotation string is not parsed successfully we don't know the |
259 | // format used and output it as a general UserSemantic decoration. |
260 | // Otherwise MDs is a Metadata tuple (a decoration list) in the format |
261 | // expected by `spirv.Decorations`. |
262 | if (MDs.size() == 0) { |
263 | auto UserSemantic = ConstantAsMetadata::get(C: ConstantInt::get( |
264 | Ty: Int32Ty, V: static_cast<uint32_t>(SPIRV::Decoration::UserSemantic))); |
265 | MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: {UserSemantic, MDString::get(Context&: Ctx, Str: Anno)})); |
266 | } |
267 | |
268 | // Build the internal intrinsic function. |
269 | IRBuilder<> IRB(II->getParent()); |
270 | IRB.SetInsertPoint(II); |
271 | IRB.CreateIntrinsic( |
272 | ID: Intrinsic::spv_assign_decoration, Types: {PtrArg->getType()}, |
273 | Args: {PtrArg, MetadataAsValue::get(Context&: Ctx, MD: MDNode::get(Context&: Ctx, MDs))}); |
274 | II->replaceAllUsesWith(V: II->getOperand(i_nocapture: 0)); |
275 | } |
276 | |
277 | static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { |
278 | // Get a separate function - otherwise, we'd have to rework the CFG of the |
279 | // current one. Then simply replace the intrinsic uses with a call to the new |
280 | // function. |
281 | // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) |
282 | Module *M = FSHIntrinsic->getModule(); |
283 | FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); |
284 | Type *FSHRetTy = FSHFuncTy->getReturnType(); |
285 | const std::string FuncName = lowerLLVMIntrinsicName(II: FSHIntrinsic); |
286 | Function *FSHFunc = |
287 | getOrCreateFunction(M, RetTy: FSHRetTy, ArgTypes: FSHFuncTy->params(), Name: FuncName); |
288 | |
289 | if (!FSHFunc->empty()) { |
290 | FSHIntrinsic->setCalledFunction(FSHFunc); |
291 | return; |
292 | } |
293 | BasicBlock *RotateBB = BasicBlock::Create(Context&: M->getContext(), Name: "rotate" , Parent: FSHFunc); |
294 | IRBuilder<> IRB(RotateBB); |
295 | Type *Ty = FSHFunc->getReturnType(); |
296 | // Build the actual funnel shift rotate logic. |
297 | // In the comments, "int" is used interchangeably with "vector of int |
298 | // elements". |
299 | FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Val: Ty); |
300 | Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; |
301 | unsigned BitWidth = IntTy->getIntegerBitWidth(); |
302 | ConstantInt *BitWidthConstant = IRB.getInt(AI: {BitWidth, BitWidth}); |
303 | Value *BitWidthForInsts = |
304 | VectorTy |
305 | ? IRB.CreateVectorSplat(NumElts: VectorTy->getNumElements(), V: BitWidthConstant) |
306 | : BitWidthConstant; |
307 | Value *RotateModVal = |
308 | IRB.CreateURem(/*Rotate*/ LHS: FSHFunc->getArg(i: 2), RHS: BitWidthForInsts); |
309 | Value *FirstShift = nullptr, *SecShift = nullptr; |
310 | if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { |
311 | // Shift the less significant number right, the "rotate" number of bits |
312 | // will be 0-filled on the left as a result of this regular shift. |
313 | FirstShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: RotateModVal); |
314 | } else { |
315 | // Shift the more significant number left, the "rotate" number of bits |
316 | // will be 0-filled on the right as a result of this regular shift. |
317 | FirstShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: RotateModVal); |
318 | } |
319 | // We want the "rotate" number of the more significant int's LSBs (MSBs) to |
320 | // occupy the leftmost (rightmost) "0 space" left by the previous operation. |
321 | // Therefore, subtract the "rotate" number from the integer bitsize... |
322 | Value *SubRotateVal = IRB.CreateSub(LHS: BitWidthForInsts, RHS: RotateModVal); |
323 | if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { |
324 | // ...and left-shift the more significant int by this number, zero-filling |
325 | // the LSBs. |
326 | SecShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: SubRotateVal); |
327 | } else { |
328 | // ...and right-shift the less significant int by this number, zero-filling |
329 | // the MSBs. |
330 | SecShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: SubRotateVal); |
331 | } |
332 | // A simple binary addition of the shifted ints yields the final result. |
333 | IRB.CreateRet(V: IRB.CreateOr(LHS: FirstShift, RHS: SecShift)); |
334 | |
335 | FSHIntrinsic->setCalledFunction(FSHFunc); |
336 | } |
337 | |
338 | static void lowerExpectAssume(IntrinsicInst *II) { |
339 | // If we cannot use the SPV_KHR_expect_assume extension, then we need to |
340 | // ignore the intrinsic and move on. It should be removed later on by LLVM. |
341 | // Otherwise we should lower the intrinsic to the corresponding SPIR-V |
342 | // instruction. |
343 | // For @llvm.assume we have OpAssumeTrueKHR. |
344 | // For @llvm.expect we have OpExpectKHR. |
345 | // |
346 | // We need to lower this into a builtin and then the builtin into a SPIR-V |
347 | // instruction. |
348 | if (II->getIntrinsicID() == Intrinsic::assume) { |
349 | Function *F = Intrinsic::getOrInsertDeclaration( |
350 | M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_assume); |
351 | II->setCalledFunction(F); |
352 | } else if (II->getIntrinsicID() == Intrinsic::expect) { |
353 | Function *F = Intrinsic::getOrInsertDeclaration( |
354 | M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_expect, |
355 | Tys: {II->getOperand(i_nocapture: 0)->getType()}); |
356 | II->setCalledFunction(F); |
357 | } else { |
358 | llvm_unreachable("Unknown intrinsic" ); |
359 | } |
360 | } |
361 | |
362 | static bool toSpvOverloadedIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID, |
363 | ArrayRef<unsigned> OpNos) { |
364 | Function *F = nullptr; |
365 | if (OpNos.empty()) { |
366 | F = Intrinsic::getOrInsertDeclaration(M: II->getModule(), id: NewID); |
367 | } else { |
368 | SmallVector<Type *, 4> Tys; |
369 | for (unsigned OpNo : OpNos) |
370 | Tys.push_back(Elt: II->getOperand(i_nocapture: OpNo)->getType()); |
371 | F = Intrinsic::getOrInsertDeclaration(M: II->getModule(), id: NewID, Tys); |
372 | } |
373 | II->setCalledFunction(F); |
374 | return true; |
375 | } |
376 | |
377 | // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics |
378 | // or calls to proper generated functions. Returns True if F was modified. |
379 | bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { |
380 | bool Changed = false; |
381 | const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F: *F); |
382 | for (BasicBlock &BB : *F) { |
383 | for (Instruction &I : BB) { |
384 | auto Call = dyn_cast<CallInst>(Val: &I); |
385 | if (!Call) |
386 | continue; |
387 | Function *CF = Call->getCalledFunction(); |
388 | if (!CF || !CF->isIntrinsic()) |
389 | continue; |
390 | auto *II = cast<IntrinsicInst>(Val: Call); |
391 | switch (II->getIntrinsicID()) { |
392 | case Intrinsic::memset: |
393 | case Intrinsic::bswap: |
394 | Changed |= lowerIntrinsicToFunction(Intrinsic: II); |
395 | break; |
396 | case Intrinsic::fshl: |
397 | case Intrinsic::fshr: |
398 | lowerFunnelShifts(FSHIntrinsic: II); |
399 | Changed = true; |
400 | break; |
401 | case Intrinsic::assume: |
402 | case Intrinsic::expect: |
403 | if (STI.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume)) |
404 | lowerExpectAssume(II); |
405 | Changed = true; |
406 | break; |
407 | case Intrinsic::lifetime_start: |
408 | if (!STI.isShader()) { |
409 | Changed |= toSpvOverloadedIntrinsic( |
410 | II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_start, OpNos: {1}); |
411 | } |
412 | break; |
413 | case Intrinsic::lifetime_end: |
414 | if (!STI.isShader()) { |
415 | Changed |= toSpvOverloadedIntrinsic( |
416 | II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_end, OpNos: {1}); |
417 | } |
418 | break; |
419 | case Intrinsic::ptr_annotation: |
420 | lowerPtrAnnotation(II); |
421 | Changed = true; |
422 | break; |
423 | } |
424 | } |
425 | } |
426 | return Changed; |
427 | } |
428 | |
429 | // Returns F if aggregate argument/return types are not present or cloned F |
430 | // function with the types replaced by i32 types. The change in types is |
431 | // noted in 'spv.cloned_funcs' metadata for later restoration. |
432 | Function * |
433 | SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { |
434 | bool IsRetAggr = F->getReturnType()->isAggregateType(); |
435 | // Allow intrinsics with aggregate return type to reach GlobalISel |
436 | if (F->isIntrinsic() && IsRetAggr) |
437 | return F; |
438 | |
439 | IRBuilder<> B(F->getContext()); |
440 | |
441 | bool HasAggrArg = llvm::any_of(Range: F->args(), P: [](Argument &Arg) { |
442 | return Arg.getType()->isAggregateType(); |
443 | }); |
444 | bool DoClone = IsRetAggr || HasAggrArg; |
445 | if (!DoClone) |
446 | return F; |
447 | SmallVector<std::pair<int, Type *>, 4> ChangedTypes; |
448 | Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType(); |
449 | if (IsRetAggr) |
450 | ChangedTypes.push_back(Elt: std::pair<int, Type *>(-1, F->getReturnType())); |
451 | SmallVector<Type *, 4> ArgTypes; |
452 | for (const auto &Arg : F->args()) { |
453 | if (Arg.getType()->isAggregateType()) { |
454 | ArgTypes.push_back(Elt: B.getInt32Ty()); |
455 | ChangedTypes.push_back( |
456 | Elt: std::pair<int, Type *>(Arg.getArgNo(), Arg.getType())); |
457 | } else |
458 | ArgTypes.push_back(Elt: Arg.getType()); |
459 | } |
460 | FunctionType *NewFTy = |
461 | FunctionType::get(Result: RetType, Params: ArgTypes, isVarArg: F->getFunctionType()->isVarArg()); |
462 | Function *NewF = |
463 | Function::Create(Ty: NewFTy, Linkage: F->getLinkage(), N: F->getName(), M&: *F->getParent()); |
464 | |
465 | ValueToValueMapTy VMap; |
466 | auto NewFArgIt = NewF->arg_begin(); |
467 | for (auto &Arg : F->args()) { |
468 | StringRef ArgName = Arg.getName(); |
469 | NewFArgIt->setName(ArgName); |
470 | VMap[&Arg] = &(*NewFArgIt++); |
471 | } |
472 | SmallVector<ReturnInst *, 8> Returns; |
473 | |
474 | CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly, |
475 | Returns); |
476 | NewF->takeName(V: F); |
477 | |
478 | NamedMDNode *FuncMD = |
479 | F->getParent()->getOrInsertNamedMetadata(Name: "spv.cloned_funcs" ); |
480 | SmallVector<Metadata *, 2> MDArgs; |
481 | MDArgs.push_back(Elt: MDString::get(Context&: B.getContext(), Str: NewF->getName())); |
482 | for (auto &ChangedTyP : ChangedTypes) |
483 | MDArgs.push_back(Elt: MDNode::get( |
484 | Context&: B.getContext(), |
485 | MDs: {ConstantAsMetadata::get(C: B.getInt32(C: ChangedTyP.first)), |
486 | ValueAsMetadata::get(V: Constant::getNullValue(Ty: ChangedTyP.second))})); |
487 | MDNode *ThisFuncMD = MDNode::get(Context&: B.getContext(), MDs: MDArgs); |
488 | FuncMD->addOperand(M: ThisFuncMD); |
489 | |
490 | for (auto *U : make_early_inc_range(Range: F->users())) { |
491 | if (auto *CI = dyn_cast<CallInst>(Val: U)) |
492 | CI->mutateFunctionType(FTy: NewF->getFunctionType()); |
493 | U->replaceUsesOfWith(From: F, To: NewF); |
494 | } |
495 | |
496 | // register the mutation |
497 | if (RetType != F->getReturnType()) |
498 | TM.getSubtarget<SPIRVSubtarget>(F: *F).getSPIRVGlobalRegistry()->addMutated( |
499 | Val: NewF, Ty: F->getReturnType()); |
500 | return NewF; |
501 | } |
502 | |
503 | bool SPIRVPrepareFunctions::runOnModule(Module &M) { |
504 | bool Changed = false; |
505 | for (Function &F : M) { |
506 | Changed |= substituteIntrinsicCalls(F: &F); |
507 | Changed |= sortBlocks(F); |
508 | } |
509 | |
510 | std::vector<Function *> FuncsWorklist; |
511 | for (auto &F : M) |
512 | FuncsWorklist.push_back(x: &F); |
513 | |
514 | for (auto *F : FuncsWorklist) { |
515 | Function *NewF = removeAggregateTypesFromSignature(F); |
516 | |
517 | if (NewF != F) { |
518 | F->eraseFromParent(); |
519 | Changed = true; |
520 | } |
521 | } |
522 | return Changed; |
523 | } |
524 | |
525 | ModulePass * |
526 | llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) { |
527 | return new SPIRVPrepareFunctions(TM); |
528 | } |
529 | |