1//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass modifies function signatures containing aggregate arguments
10// and/or return value before IRTranslator. Information about the original
11// signatures is stored in metadata. It is used during call lowering to
12// restore correct SPIR-V types of function arguments and return values.
13// This pass also substitutes some llvm intrinsic calls with calls to newly
14// generated functions (as the Khronos LLVM/SPIR-V Translator does).
15//
16// NOTE: this pass is a module-level one due to the necessity to modify
17// GVs/functions.
18//
19//===----------------------------------------------------------------------===//
20
21#include "SPIRV.h"
22#include "SPIRVSubtarget.h"
23#include "SPIRVTargetMachine.h"
24#include "SPIRVUtils.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/Analysis/ValueTracking.h"
27#include "llvm/CodeGen/IntrinsicLowering.h"
28#include "llvm/IR/IRBuilder.h"
29#include "llvm/IR/InstIterator.h"
30#include "llvm/IR/Instructions.h"
31#include "llvm/IR/IntrinsicInst.h"
32#include "llvm/IR/Intrinsics.h"
33#include "llvm/IR/IntrinsicsSPIRV.h"
34#include "llvm/Transforms/Utils/Cloning.h"
35#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
36#include <regex>
37
38using namespace llvm;
39
40namespace {
41
42class SPIRVPrepareFunctions : public ModulePass {
43 const SPIRVTargetMachine &TM;
44 bool substituteIntrinsicCalls(Function *F);
45 Function *removeAggregateTypesFromSignature(Function *F);
46 bool removeAggregateTypesFromCalls(Function *F);
47
48public:
49 static char ID;
50 SPIRVPrepareFunctions(const SPIRVTargetMachine &TM)
51 : ModulePass(ID), TM(TM) {}
52
53 bool runOnModule(Module &M) override;
54
55 StringRef getPassName() const override { return "SPIRV prepare functions"; }
56
57 void getAnalysisUsage(AnalysisUsage &AU) const override {
58 ModulePass::getAnalysisUsage(AU);
59 }
60};
61
62static cl::list<std::string> SPVAllowUnknownIntrinsics(
63 "spv-allow-unknown-intrinsics", cl::CommaSeparated,
64 cl::desc("Emit unknown intrinsics as calls to external functions. A "
65 "comma-separated input list of intrinsic prefixes must be "
66 "provided, and only intrinsics carrying a listed prefix get "
67 "emitted as described."),
68 cl::value_desc("intrinsic_prefix_0,intrinsic_prefix_1"), cl::ValueOptional);
69} // namespace
70
71char SPIRVPrepareFunctions::ID = 0;
72
73INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
74 "SPIRV prepare functions", false, false)
75
76static std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
77 Function *IntrinsicFunc = II->getCalledFunction();
78 assert(IntrinsicFunc && "Missing function");
79 std::string FuncName = IntrinsicFunc->getName().str();
80 llvm::replace(Range&: FuncName, OldValue: '.', NewValue: '_');
81 FuncName = "spirv." + FuncName;
82 return FuncName;
83}
84
85static Function *getOrCreateFunction(Module *M, Type *RetTy,
86 ArrayRef<Type *> ArgTypes,
87 StringRef Name) {
88 FunctionType *FT = FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: false);
89 Function *F = M->getFunction(Name);
90 if (F && F->getFunctionType() == FT)
91 return F;
92 Function *NewF = Function::Create(Ty: FT, Linkage: GlobalValue::ExternalLinkage, N: Name, M);
93 if (F)
94 NewF->setDSOLocal(F->isDSOLocal());
95 NewF->setCallingConv(CallingConv::SPIR_FUNC);
96 return NewF;
97}
98
99static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
100 // For @llvm.memset.* intrinsic cases with constant value and length arguments
101 // are emulated via "storing" a constant array to the destination. For other
102 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
103 // intrinsic to a loop via expandMemSetAsLoop().
104 if (auto *MSI = dyn_cast<MemSetInst>(Val: Intrinsic))
105 if (isa<Constant>(Val: MSI->getValue()) && isa<ConstantInt>(Val: MSI->getLength()))
106 return false; // It is handled later using OpCopyMemorySized.
107
108 Module *M = Intrinsic->getModule();
109 std::string FuncName = lowerLLVMIntrinsicName(II: Intrinsic);
110 if (Intrinsic->isVolatile())
111 FuncName += ".volatile";
112 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
113 Function *F = M->getFunction(Name: FuncName);
114 if (F) {
115 Intrinsic->setCalledFunction(F);
116 return true;
117 }
118 // TODO copy arguments attributes: nocapture writeonly.
119 FunctionCallee FC =
120 M->getOrInsertFunction(Name: FuncName, T: Intrinsic->getFunctionType());
121 auto IntrinsicID = Intrinsic->getIntrinsicID();
122 Intrinsic->setCalledFunction(FC);
123
124 F = dyn_cast<Function>(Val: FC.getCallee());
125 assert(F && "Callee must be a function");
126
127 switch (IntrinsicID) {
128 case Intrinsic::memset: {
129 auto *MSI = static_cast<MemSetInst *>(Intrinsic);
130 Argument *Dest = F->getArg(i: 0);
131 Argument *Val = F->getArg(i: 1);
132 Argument *Len = F->getArg(i: 2);
133 Argument *IsVolatile = F->getArg(i: 3);
134 Dest->setName("dest");
135 Val->setName("val");
136 Len->setName("len");
137 IsVolatile->setName("isvolatile");
138 BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry", Parent: F);
139 IRBuilder<> IRB(EntryBB);
140 auto *MemSet = IRB.CreateMemSet(Ptr: Dest, Val, Size: Len, Align: MSI->getDestAlign(),
141 isVolatile: MSI->isVolatile());
142 IRB.CreateRetVoid();
143 expandMemSetAsLoop(MemSet: cast<MemSetInst>(Val: MemSet));
144 MemSet->eraseFromParent();
145 break;
146 }
147 case Intrinsic::bswap: {
148 BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry", Parent: F);
149 IRBuilder<> IRB(EntryBB);
150 auto *BSwap = IRB.CreateIntrinsic(ID: Intrinsic::bswap, Types: Intrinsic->getType(),
151 Args: F->getArg(i: 0));
152 IRB.CreateRet(V: BSwap);
153 IntrinsicLowering IL(M->getDataLayout());
154 IL.LowerIntrinsicCall(CI: BSwap);
155 break;
156 }
157 default:
158 break;
159 }
160 return true;
161}
162
163static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) {
164 if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(Val: AnnoVal))
165 AnnoVal = Ref->getOperand(i_nocapture: 0);
166 if (auto *Ref = dyn_cast_or_null<BitCastInst>(Val: OptAnnoVal))
167 OptAnnoVal = Ref->getOperand(i_nocapture: 0);
168
169 std::string Anno;
170 if (auto *C = dyn_cast_or_null<Constant>(Val: AnnoVal)) {
171 StringRef Str;
172 if (getConstantStringInfo(V: C, Str))
173 Anno = Str;
174 }
175 // handle optional annotation parameter in a way that Khronos Translator do
176 // (collect integers wrapped in a struct)
177 if (auto *C = dyn_cast_or_null<Constant>(Val: OptAnnoVal);
178 C && C->getNumOperands()) {
179 Value *MaybeStruct = C->getOperand(i: 0);
180 if (auto *Struct = dyn_cast<ConstantStruct>(Val: MaybeStruct)) {
181 for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {
182 if (auto *CInt = dyn_cast<ConstantInt>(Val: Struct->getOperand(i_nocapture: I)))
183 Anno += (I == 0 ? ": " : ", ") +
184 std::to_string(val: CInt->getType()->getIntegerBitWidth() == 1
185 ? CInt->getZExtValue()
186 : CInt->getSExtValue());
187 }
188 } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(Val: MaybeStruct)) {
189 // { i32 i32 ... } zeroinitializer
190 for (unsigned I = 0, E = Struct->getType()->getStructNumElements();
191 I != E; ++I)
192 Anno += I == 0 ? ": 0" : ", 0";
193 }
194 }
195 return Anno;
196}
197
198static SmallVector<Metadata *> parseAnnotation(Value *I,
199 const std::string &Anno,
200 LLVMContext &Ctx,
201 Type *Int32Ty) {
202 // Try to parse the annotation string according to the following rules:
203 // annotation := ({kind} | {kind:value,value,...})+
204 // kind := number
205 // value := number | string
206 static const std::regex R(
207 "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
208 SmallVector<Metadata *> MDs;
209 int Pos = 0;
210 for (std::sregex_iterator
211 It = std::sregex_iterator(Anno.begin(), Anno.end(), R),
212 ItEnd = std::sregex_iterator();
213 It != ItEnd; ++It) {
214 if (It->position() != Pos)
215 return SmallVector<Metadata *>{};
216 Pos = It->position() + It->length();
217 std::smatch Match = *It;
218 SmallVector<Metadata *> MDsItem;
219 for (std::size_t i = 1; i < Match.size(); ++i) {
220 std::ssub_match SMatch = Match[i];
221 std::string Item = SMatch.str();
222 if (Item.length() == 0)
223 break;
224 if (Item[0] == '"') {
225 Item = Item.substr(pos: 1, n: Item.length() - 2);
226 // Acceptable format of the string snippet is:
227 static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");
228 if (std::smatch MatchStr; std::regex_match(s: Item, m&: MatchStr, re: RStr)) {
229 for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)
230 if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())
231 MDsItem.push_back(Elt: ConstantAsMetadata::get(
232 C: ConstantInt::get(Ty: Int32Ty, V: std::stoi(str: SubStr))));
233 } else {
234 MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item));
235 }
236 } else if (int32_t Num; llvm::to_integer(S: StringRef(Item), Num, Base: 10)) {
237 MDsItem.push_back(
238 Elt: ConstantAsMetadata::get(C: ConstantInt::get(Ty: Int32Ty, V: Num)));
239 } else {
240 MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item));
241 }
242 }
243 if (MDsItem.size() == 0)
244 return SmallVector<Metadata *>{};
245 MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: MDsItem));
246 }
247 return Pos == static_cast<int>(Anno.length()) ? std::move(MDs)
248 : SmallVector<Metadata *>{};
249}
250
251static void lowerPtrAnnotation(IntrinsicInst *II) {
252 LLVMContext &Ctx = II->getContext();
253 Type *Int32Ty = Type::getInt32Ty(C&: Ctx);
254
255 // Retrieve an annotation string from arguments.
256 Value *PtrArg = nullptr;
257 if (auto *BI = dyn_cast<BitCastInst>(Val: II->getArgOperand(i: 0)))
258 PtrArg = BI->getOperand(i_nocapture: 0);
259 else
260 PtrArg = II->getOperand(i_nocapture: 0);
261 std::string Anno =
262 getAnnotation(AnnoVal: II->getArgOperand(i: 1),
263 OptAnnoVal: 4 < II->arg_size() ? II->getArgOperand(i: 4) : nullptr);
264
265 // Parse the annotation.
266 SmallVector<Metadata *> MDs = parseAnnotation(I: II, Anno, Ctx, Int32Ty);
267
268 // If the annotation string is not parsed successfully we don't know the
269 // format used and output it as a general UserSemantic decoration.
270 // Otherwise MDs is a Metadata tuple (a decoration list) in the format
271 // expected by `spirv.Decorations`.
272 if (MDs.size() == 0) {
273 auto UserSemantic = ConstantAsMetadata::get(C: ConstantInt::get(
274 Ty: Int32Ty, V: static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));
275 MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: {UserSemantic, MDString::get(Context&: Ctx, Str: Anno)}));
276 }
277
278 // Build the internal intrinsic function.
279 IRBuilder<> IRB(II->getParent());
280 IRB.SetInsertPoint(II);
281 IRB.CreateIntrinsic(
282 ID: Intrinsic::spv_assign_decoration, Types: {PtrArg->getType()},
283 Args: {PtrArg, MetadataAsValue::get(Context&: Ctx, MD: MDNode::get(Context&: Ctx, MDs))});
284 II->replaceAllUsesWith(V: II->getOperand(i_nocapture: 0));
285}
286
287static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
288 // Get a separate function - otherwise, we'd have to rework the CFG of the
289 // current one. Then simply replace the intrinsic uses with a call to the new
290 // function.
291 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
292 Module *M = FSHIntrinsic->getModule();
293 FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
294 Type *FSHRetTy = FSHFuncTy->getReturnType();
295 const std::string FuncName = lowerLLVMIntrinsicName(II: FSHIntrinsic);
296 Function *FSHFunc =
297 getOrCreateFunction(M, RetTy: FSHRetTy, ArgTypes: FSHFuncTy->params(), Name: FuncName);
298
299 if (!FSHFunc->empty()) {
300 FSHIntrinsic->setCalledFunction(FSHFunc);
301 return;
302 }
303 BasicBlock *RotateBB = BasicBlock::Create(Context&: M->getContext(), Name: "rotate", Parent: FSHFunc);
304 IRBuilder<> IRB(RotateBB);
305 Type *Ty = FSHFunc->getReturnType();
306 // Build the actual funnel shift rotate logic.
307 // In the comments, "int" is used interchangeably with "vector of int
308 // elements".
309 FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Val: Ty);
310 Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
311 unsigned BitWidth = IntTy->getIntegerBitWidth();
312 ConstantInt *BitWidthConstant = IRB.getInt(AI: {BitWidth, BitWidth});
313 Value *BitWidthForInsts =
314 VectorTy
315 ? IRB.CreateVectorSplat(NumElts: VectorTy->getNumElements(), V: BitWidthConstant)
316 : BitWidthConstant;
317 Value *RotateModVal =
318 IRB.CreateURem(/*Rotate*/ LHS: FSHFunc->getArg(i: 2), RHS: BitWidthForInsts);
319 Value *FirstShift = nullptr, *SecShift = nullptr;
320 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
321 // Shift the less significant number right, the "rotate" number of bits
322 // will be 0-filled on the left as a result of this regular shift.
323 FirstShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: RotateModVal);
324 } else {
325 // Shift the more significant number left, the "rotate" number of bits
326 // will be 0-filled on the right as a result of this regular shift.
327 FirstShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: RotateModVal);
328 }
329 // We want the "rotate" number of the more significant int's LSBs (MSBs) to
330 // occupy the leftmost (rightmost) "0 space" left by the previous operation.
331 // Therefore, subtract the "rotate" number from the integer bitsize...
332 Value *SubRotateVal = IRB.CreateSub(LHS: BitWidthForInsts, RHS: RotateModVal);
333 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
334 // ...and left-shift the more significant int by this number, zero-filling
335 // the LSBs.
336 SecShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: SubRotateVal);
337 } else {
338 // ...and right-shift the less significant int by this number, zero-filling
339 // the MSBs.
340 SecShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: SubRotateVal);
341 }
342 // A simple binary addition of the shifted ints yields the final result.
343 IRB.CreateRet(V: IRB.CreateOr(LHS: FirstShift, RHS: SecShift));
344
345 FSHIntrinsic->setCalledFunction(FSHFunc);
346}
347
348static void lowerConstrainedFPCmpIntrinsic(
349 ConstrainedFPCmpIntrinsic *ConstrainedCmpIntrinsic,
350 SmallVector<Instruction *> &EraseFromParent) {
351 if (!ConstrainedCmpIntrinsic)
352 return;
353 // Extract the floating-point values being compared
354 Value *LHS = ConstrainedCmpIntrinsic->getArgOperand(i: 0);
355 Value *RHS = ConstrainedCmpIntrinsic->getArgOperand(i: 1);
356 FCmpInst::Predicate Pred = ConstrainedCmpIntrinsic->getPredicate();
357 IRBuilder<> Builder(ConstrainedCmpIntrinsic);
358 Value *FCmp = Builder.CreateFCmp(P: Pred, LHS, RHS);
359 ConstrainedCmpIntrinsic->replaceAllUsesWith(V: FCmp);
360 EraseFromParent.push_back(Elt: dyn_cast<Instruction>(Val: ConstrainedCmpIntrinsic));
361}
362
363static void lowerExpectAssume(IntrinsicInst *II) {
364 // If we cannot use the SPV_KHR_expect_assume extension, then we need to
365 // ignore the intrinsic and move on. It should be removed later on by LLVM.
366 // Otherwise we should lower the intrinsic to the corresponding SPIR-V
367 // instruction.
368 // For @llvm.assume we have OpAssumeTrueKHR.
369 // For @llvm.expect we have OpExpectKHR.
370 //
371 // We need to lower this into a builtin and then the builtin into a SPIR-V
372 // instruction.
373 if (II->getIntrinsicID() == Intrinsic::assume) {
374 Function *F = Intrinsic::getOrInsertDeclaration(
375 M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_assume);
376 II->setCalledFunction(F);
377 } else if (II->getIntrinsicID() == Intrinsic::expect) {
378 Function *F = Intrinsic::getOrInsertDeclaration(
379 M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_expect,
380 Tys: {II->getOperand(i_nocapture: 0)->getType()});
381 II->setCalledFunction(F);
382 } else {
383 llvm_unreachable("Unknown intrinsic");
384 }
385}
386
387static bool toSpvLifetimeIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID) {
388 auto *LifetimeArg0 = II->getArgOperand(i: 0);
389
390 // If the lifetime argument is a poison value, the intrinsic has no effect.
391 if (isa<PoisonValue>(Val: LifetimeArg0)) {
392 II->eraseFromParent();
393 return true;
394 }
395
396 IRBuilder<> Builder(II);
397 auto *Alloca = cast<AllocaInst>(Val: LifetimeArg0);
398 std::optional<TypeSize> Size =
399 Alloca->getAllocationSize(DL: Alloca->getDataLayout());
400 Value *SizeVal = Builder.getInt64(C: Size ? *Size : -1);
401 Builder.CreateIntrinsic(ID: NewID, Types: Alloca->getType(), Args: {SizeVal, LifetimeArg0});
402 II->eraseFromParent();
403 return true;
404}
405
406static void
407lowerConstrainedFmuladd(IntrinsicInst *II,
408 SmallVector<Instruction *> &EraseFromParent) {
409 auto *FPI = cast<ConstrainedFPIntrinsic>(Val: II);
410 Value *A = FPI->getArgOperand(i: 0);
411 Value *Mul = FPI->getArgOperand(i: 1);
412 Value *Add = FPI->getArgOperand(i: 2);
413 IRBuilder<> Builder(II->getParent());
414 Builder.SetInsertPoint(II);
415 std::optional<RoundingMode> Rounding = FPI->getRoundingMode();
416 Value *Product = Builder.CreateFMul(L: A, R: Mul, Name: II->getName() + ".mul");
417 Value *Result = Builder.CreateConstrainedFPBinOp(
418 ID: Intrinsic::experimental_constrained_fadd, L: Product, R: Add, FMFSource: {},
419 Name: II->getName() + ".add", FPMathTag: nullptr, Rounding);
420 II->replaceAllUsesWith(V: Result);
421 EraseFromParent.push_back(Elt: II);
422}
423
424// Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
425// or calls to proper generated functions. Returns True if F was modified.
426bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
427 bool Changed = false;
428 const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F: *F);
429 SmallVector<Instruction *> EraseFromParent;
430 for (BasicBlock &BB : *F) {
431 for (Instruction &I : make_early_inc_range(Range&: BB)) {
432 auto Call = dyn_cast<CallInst>(Val: &I);
433 if (!Call)
434 continue;
435 Function *CF = Call->getCalledFunction();
436 if (!CF || !CF->isIntrinsic())
437 continue;
438 auto *II = cast<IntrinsicInst>(Val: Call);
439 switch (II->getIntrinsicID()) {
440 case Intrinsic::memset:
441 case Intrinsic::bswap:
442 Changed |= lowerIntrinsicToFunction(Intrinsic: II);
443 break;
444 case Intrinsic::fshl:
445 case Intrinsic::fshr:
446 lowerFunnelShifts(FSHIntrinsic: II);
447 Changed = true;
448 break;
449 case Intrinsic::assume:
450 case Intrinsic::expect:
451 if (STI.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume))
452 lowerExpectAssume(II);
453 Changed = true;
454 break;
455 case Intrinsic::lifetime_start:
456 if (!STI.isShader()) {
457 Changed |= toSpvLifetimeIntrinsic(
458 II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_start);
459 } else {
460 II->eraseFromParent();
461 Changed = true;
462 }
463 break;
464 case Intrinsic::lifetime_end:
465 if (!STI.isShader()) {
466 Changed |= toSpvLifetimeIntrinsic(
467 II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_end);
468 } else {
469 II->eraseFromParent();
470 Changed = true;
471 }
472 break;
473 case Intrinsic::ptr_annotation:
474 lowerPtrAnnotation(II);
475 Changed = true;
476 break;
477 case Intrinsic::experimental_constrained_fmuladd:
478 lowerConstrainedFmuladd(II, EraseFromParent);
479 Changed = true;
480 break;
481 case Intrinsic::experimental_constrained_fcmp:
482 case Intrinsic::experimental_constrained_fcmps:
483 lowerConstrainedFPCmpIntrinsic(ConstrainedCmpIntrinsic: dyn_cast<ConstrainedFPCmpIntrinsic>(Val: II),
484 EraseFromParent);
485 Changed = true;
486 break;
487 default:
488 if (TM.getTargetTriple().getVendor() == Triple::AMD ||
489 any_of(Range&: SPVAllowUnknownIntrinsics, P: [II](auto &&Prefix) {
490 if (Prefix.empty())
491 return false;
492 return II->getCalledFunction()->getName().starts_with(Prefix);
493 }))
494 Changed |= lowerIntrinsicToFunction(Intrinsic: II);
495 break;
496 }
497 }
498 }
499 for (auto *I : EraseFromParent)
500 I->eraseFromParent();
501 return Changed;
502}
503
504static void
505addFunctionTypeMutation(NamedMDNode *NMD,
506 SmallVector<std::pair<int, Type *>> ChangedTys,
507 StringRef Name) {
508
509 LLVMContext &Ctx = NMD->getParent()->getContext();
510 Type *I32Ty = IntegerType::getInt32Ty(C&: Ctx);
511
512 SmallVector<Metadata *> MDArgs;
513 MDArgs.push_back(Elt: MDString::get(Context&: Ctx, Str: Name));
514 transform(Range&: ChangedTys, d_first: std::back_inserter(x&: MDArgs), F: [=, &Ctx](auto &&CTy) {
515 return MDNode::get(
516 Context&: Ctx, MDs: {ConstantAsMetadata::get(C: ConstantInt::get(I32Ty, CTy.first, true)),
517 ValueAsMetadata::get(V: Constant::getNullValue(Ty: CTy.second))});
518 });
519 NMD->addOperand(M: MDNode::get(Context&: Ctx, MDs: MDArgs));
520}
521// Returns F if aggregate argument/return types are not present or cloned F
522// function with the types replaced by i32 types. The change in types is
523// noted in 'spv.cloned_funcs' metadata for later restoration.
524Function *
525SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
526 bool IsRetAggr = F->getReturnType()->isAggregateType();
527 // Allow intrinsics with aggregate return type to reach GlobalISel
528 if (F->isIntrinsic() && IsRetAggr)
529 return F;
530
531 IRBuilder<> B(F->getContext());
532
533 bool HasAggrArg = llvm::any_of(Range: F->args(), P: [](Argument &Arg) {
534 return Arg.getType()->isAggregateType();
535 });
536 bool DoClone = IsRetAggr || HasAggrArg;
537 if (!DoClone)
538 return F;
539 SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
540 Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
541 if (IsRetAggr)
542 ChangedTypes.push_back(Elt: std::pair<int, Type *>(-1, F->getReturnType()));
543 SmallVector<Type *, 4> ArgTypes;
544 for (const auto &Arg : F->args()) {
545 if (Arg.getType()->isAggregateType()) {
546 ArgTypes.push_back(Elt: B.getInt32Ty());
547 ChangedTypes.push_back(
548 Elt: std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
549 } else
550 ArgTypes.push_back(Elt: Arg.getType());
551 }
552 FunctionType *NewFTy =
553 FunctionType::get(Result: RetType, Params: ArgTypes, isVarArg: F->getFunctionType()->isVarArg());
554 Function *NewF =
555 Function::Create(Ty: NewFTy, Linkage: F->getLinkage(), AddrSpace: F->getAddressSpace(),
556 N: F->getName(), M: F->getParent());
557
558 ValueToValueMapTy VMap;
559 auto NewFArgIt = NewF->arg_begin();
560 for (auto &Arg : F->args()) {
561 StringRef ArgName = Arg.getName();
562 NewFArgIt->setName(ArgName);
563 VMap[&Arg] = &(*NewFArgIt++);
564 }
565 SmallVector<ReturnInst *, 8> Returns;
566
567 CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly,
568 Returns);
569 NewF->takeName(V: F);
570
571 addFunctionTypeMutation(
572 NMD: NewF->getParent()->getOrInsertNamedMetadata(Name: "spv.cloned_funcs"),
573 ChangedTys: std::move(ChangedTypes), Name: NewF->getName());
574
575 for (auto *U : make_early_inc_range(Range: F->users())) {
576 if (CallInst *CI;
577 (CI = dyn_cast<CallInst>(Val: U)) && CI->getCalledFunction() == F)
578 CI->mutateFunctionType(FTy: NewF->getFunctionType());
579 if (auto *C = dyn_cast<Constant>(Val: U))
580 C->handleOperandChange(F, NewF);
581 else
582 U->replaceUsesOfWith(From: F, To: NewF);
583 }
584
585 // register the mutation
586 if (RetType != F->getReturnType())
587 TM.getSubtarget<SPIRVSubtarget>(F: *F).getSPIRVGlobalRegistry()->addMutated(
588 Val: NewF, Ty: F->getReturnType());
589 return NewF;
590}
591
592// Mutates indirect callsites iff if aggregate argument/return types are present
593// with the types replaced by i32 types. The change in types is noted in
594// 'spv.mutated_callsites' metadata for later restoration.
595bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
596 if (F->isDeclaration() || F->isIntrinsic())
597 return false;
598
599 SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
600 for (auto &&I : instructions(F)) {
601 if (auto *CB = dyn_cast<CallBase>(Val: &I)) {
602 if (!CB->getCalledOperand() || CB->getCalledFunction())
603 continue;
604 if (CB->getType()->isAggregateType() ||
605 any_of(Range: CB->args(),
606 P: [](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
607 Calls.emplace_back(Args&: CB, Args: nullptr);
608 }
609 }
610
611 if (Calls.empty())
612 return false;
613
614 IRBuilder<> B(F->getContext());
615
616 for (auto &&[CB, NewFnTy] : Calls) {
617 SmallVector<std::pair<int, Type *>> ChangedTypes;
618 SmallVector<Type *> NewArgTypes;
619
620 Type *RetTy = CB->getType();
621 if (RetTy->isAggregateType()) {
622 ChangedTypes.emplace_back(Args: -1, Args&: RetTy);
623 RetTy = B.getInt32Ty();
624 }
625
626 for (auto &&Arg : CB->args()) {
627 if (Arg->getType()->isAggregateType()) {
628 NewArgTypes.push_back(Elt: B.getInt32Ty());
629 ChangedTypes.emplace_back(Args: Arg.getOperandNo(), Args: Arg->getType());
630 } else {
631 NewArgTypes.push_back(Elt: Arg->getType());
632 }
633 }
634 NewFnTy = FunctionType::get(Result: RetTy, Params: NewArgTypes,
635 isVarArg: CB->getFunctionType()->isVarArg());
636
637 if (!CB->hasName())
638 CB->setName("spv.mutated_callsite." + F->getName());
639 else
640 CB->setName("spv.named_mutated_callsite." + F->getName() + "." +
641 CB->getName());
642
643 addFunctionTypeMutation(
644 NMD: F->getParent()->getOrInsertNamedMetadata(Name: "spv.mutated_callsites"),
645 ChangedTys: std::move(ChangedTypes), Name: CB->getName());
646 }
647
648 for (auto &&[CB, NewFTy] : Calls) {
649 if (NewFTy->getReturnType() != CB->getType())
650 TM.getSubtarget<SPIRVSubtarget>(F: *F).getSPIRVGlobalRegistry()->addMutated(
651 Val: CB, Ty: CB->getType());
652 CB->mutateFunctionType(FTy: NewFTy);
653 }
654
655 return true;
656}
657
658bool SPIRVPrepareFunctions::runOnModule(Module &M) {
659 bool Changed = false;
660 for (Function &F : M) {
661 Changed |= substituteIntrinsicCalls(F: &F);
662 Changed |= sortBlocks(F);
663 Changed |= removeAggregateTypesFromCalls(F: &F);
664 }
665
666 std::vector<Function *> FuncsWorklist;
667 for (auto &F : M)
668 FuncsWorklist.push_back(x: &F);
669
670 for (auto *F : FuncsWorklist) {
671 Function *NewF = removeAggregateTypesFromSignature(F);
672
673 if (NewF != F) {
674 F->eraseFromParent();
675 Changed = true;
676 }
677 }
678 return Changed;
679}
680
681ModulePass *
682llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
683 return new SPIRVPrepareFunctions(TM);
684}
685