//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies function signatures containing aggregate arguments
// and/or return value before IRTranslator. Information about the original
// signatures is stored in metadata. It is used during call lowering to
// restore correct SPIR-V types of function arguments and return values.
// This pass also substitutes some llvm intrinsic calls with calls to newly
// generated functions (as the Khronos LLVM/SPIR-V Translator does).
//
// NOTE: this pass is a module-level one due to the necessity to modify
// GVs/functions.
//
//===----------------------------------------------------------------------===//
20
21#include "SPIRV.h"
22#include "SPIRVSubtarget.h"
23#include "SPIRVTargetMachine.h"
24#include "SPIRVUtils.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/Analysis/TargetTransformInfo.h"
27#include "llvm/Analysis/ValueTracking.h"
28#include "llvm/CodeGen/IntrinsicLowering.h"
29#include "llvm/IR/IRBuilder.h"
30#include "llvm/IR/InstIterator.h"
31#include "llvm/IR/Instructions.h"
32#include "llvm/IR/IntrinsicInst.h"
33#include "llvm/IR/Intrinsics.h"
34#include "llvm/IR/IntrinsicsSPIRV.h"
35#include "llvm/Transforms/Utils/Cloning.h"
36#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
37#include <regex>
38
39using namespace llvm;
40
41namespace {
42
43class SPIRVPrepareFunctions : public ModulePass {
44 const SPIRVTargetMachine &TM;
45 bool substituteIntrinsicCalls(Function *F);
46 Function *removeAggregateTypesFromSignature(Function *F);
47 bool removeAggregateTypesFromCalls(Function *F);
48
49public:
50 static char ID;
51 SPIRVPrepareFunctions(const SPIRVTargetMachine &TM)
52 : ModulePass(ID), TM(TM) {}
53
54 bool runOnModule(Module &M) override;
55
56 StringRef getPassName() const override { return "SPIRV prepare functions"; }
57
58 void getAnalysisUsage(AnalysisUsage &AU) const override {
59 ModulePass::getAnalysisUsage(AU);
60 }
61};
62
63static cl::list<std::string> SPVAllowUnknownIntrinsics(
64 "spv-allow-unknown-intrinsics", cl::CommaSeparated,
65 cl::desc("Emit unknown intrinsics as calls to external functions. A "
66 "comma-separated input list of intrinsic prefixes must be "
67 "provided, and only intrinsics carrying a listed prefix get "
68 "emitted as described."),
69 cl::value_desc("intrinsic_prefix_0,intrinsic_prefix_1"), cl::ValueOptional);
70} // namespace
71
72char SPIRVPrepareFunctions::ID = 0;
73
74INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
75 "SPIRV prepare functions", false, false)
76
77static std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
78 Function *IntrinsicFunc = II->getCalledFunction();
79 assert(IntrinsicFunc && "Missing function");
80 std::string FuncName = IntrinsicFunc->getName().str();
81 llvm::replace(Range&: FuncName, OldValue: '.', NewValue: '_');
82 FuncName = "spirv." + FuncName;
83 return FuncName;
84}
85
86static Function *getOrCreateFunction(Module *M, Type *RetTy,
87 ArrayRef<Type *> ArgTypes,
88 StringRef Name) {
89 FunctionType *FT = FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: false);
90 Function *F = M->getFunction(Name);
91 if (F && F->getFunctionType() == FT)
92 return F;
93 Function *NewF = Function::Create(Ty: FT, Linkage: GlobalValue::ExternalLinkage, N: Name, M);
94 if (F)
95 NewF->setDSOLocal(F->isDSOLocal());
96 NewF->setCallingConv(CallingConv::SPIR_FUNC);
97 return NewF;
98}
99
100static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic,
101 const TargetTransformInfo &TTI) {
102 // For @llvm.memset.* intrinsic cases with constant value and length arguments
103 // are emulated via "storing" a constant array to the destination. For other
104 // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
105 // intrinsic to a loop via expandMemSetAsLoop().
106 if (auto *MSI = dyn_cast<MemSetInst>(Val: Intrinsic))
107 if (isa<Constant>(Val: MSI->getValue()) && isa<ConstantInt>(Val: MSI->getLength()))
108 return false; // It is handled later using OpCopyMemorySized.
109
110 Module *M = Intrinsic->getModule();
111 std::string FuncName = lowerLLVMIntrinsicName(II: Intrinsic);
112 if (Intrinsic->isVolatile())
113 FuncName += ".volatile";
114 // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
115 Function *F = M->getFunction(Name: FuncName);
116 if (F) {
117 Intrinsic->setCalledFunction(F);
118 return true;
119 }
120 // TODO copy arguments attributes: nocapture writeonly.
121 FunctionCallee FC =
122 M->getOrInsertFunction(Name: FuncName, T: Intrinsic->getFunctionType());
123 auto IntrinsicID = Intrinsic->getIntrinsicID();
124 Intrinsic->setCalledFunction(FC);
125
126 F = dyn_cast<Function>(Val: FC.getCallee());
127 assert(F && "Callee must be a function");
128
129 switch (IntrinsicID) {
130 case Intrinsic::memset: {
131 auto *MSI = static_cast<MemSetInst *>(Intrinsic);
132 Argument *Dest = F->getArg(i: 0);
133 Argument *Val = F->getArg(i: 1);
134 Argument *Len = F->getArg(i: 2);
135 Argument *IsVolatile = F->getArg(i: 3);
136 Dest->setName("dest");
137 Val->setName("val");
138 Len->setName("len");
139 IsVolatile->setName("isvolatile");
140 BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry", Parent: F);
141 IRBuilder<> IRB(EntryBB);
142 auto *MemSet = IRB.CreateMemSet(Ptr: Dest, Val, Size: Len, Align: MSI->getDestAlign(),
143 isVolatile: MSI->isVolatile());
144 IRB.CreateRetVoid();
145 expandMemSetAsLoop(MemSet: cast<MemSetInst>(Val: MemSet), TTI);
146 MemSet->eraseFromParent();
147 break;
148 }
149 case Intrinsic::bswap: {
150 BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry", Parent: F);
151 IRBuilder<> IRB(EntryBB);
152 auto *BSwap = IRB.CreateIntrinsic(ID: Intrinsic::bswap, Types: Intrinsic->getType(),
153 Args: F->getArg(i: 0));
154 IRB.CreateRet(V: BSwap);
155 IntrinsicLowering IL(M->getDataLayout());
156 IL.LowerIntrinsicCall(CI: BSwap);
157 break;
158 }
159 default:
160 break;
161 }
162 return true;
163}
164
165static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) {
166 if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(Val: AnnoVal))
167 AnnoVal = Ref->getOperand(i_nocapture: 0);
168 if (auto *Ref = dyn_cast_or_null<BitCastInst>(Val: OptAnnoVal))
169 OptAnnoVal = Ref->getOperand(i_nocapture: 0);
170
171 std::string Anno;
172 if (auto *C = dyn_cast_or_null<Constant>(Val: AnnoVal)) {
173 StringRef Str;
174 if (getConstantStringInfo(V: C, Str))
175 Anno = Str;
176 }
177 // handle optional annotation parameter in a way that Khronos Translator do
178 // (collect integers wrapped in a struct)
179 if (auto *C = dyn_cast_or_null<Constant>(Val: OptAnnoVal);
180 C && C->getNumOperands()) {
181 Value *MaybeStruct = C->getOperand(i: 0);
182 if (auto *Struct = dyn_cast<ConstantStruct>(Val: MaybeStruct)) {
183 for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) {
184 if (auto *CInt = dyn_cast<ConstantInt>(Val: Struct->getOperand(i_nocapture: I)))
185 Anno += (I == 0 ? ": " : ", ") +
186 std::to_string(val: CInt->getType()->getIntegerBitWidth() == 1
187 ? CInt->getZExtValue()
188 : CInt->getSExtValue());
189 }
190 } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(Val: MaybeStruct)) {
191 // { i32 i32 ... } zeroinitializer
192 for (unsigned I = 0, E = Struct->getType()->getStructNumElements();
193 I != E; ++I)
194 Anno += I == 0 ? ": 0" : ", 0";
195 }
196 }
197 return Anno;
198}
199
200static SmallVector<Metadata *> parseAnnotation(Value *I,
201 const std::string &Anno,
202 LLVMContext &Ctx,
203 Type *Int32Ty) {
204 // Try to parse the annotation string according to the following rules:
205 // annotation := ({kind} | {kind:value,value,...})+
206 // kind := number
207 // value := number | string
208 static const std::regex R(
209 "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}");
210 SmallVector<Metadata *> MDs;
211 int Pos = 0;
212 for (std::sregex_iterator
213 It = std::sregex_iterator(Anno.begin(), Anno.end(), R),
214 ItEnd = std::sregex_iterator();
215 It != ItEnd; ++It) {
216 if (It->position() != Pos)
217 return SmallVector<Metadata *>{};
218 Pos = It->position() + It->length();
219 std::smatch Match = *It;
220 SmallVector<Metadata *> MDsItem;
221 for (std::size_t i = 1; i < Match.size(); ++i) {
222 std::ssub_match SMatch = Match[i];
223 std::string Item = SMatch.str();
224 if (Item.length() == 0)
225 break;
226 if (Item[0] == '"') {
227 Item = Item.substr(pos: 1, n: Item.length() - 2);
228 // Acceptable format of the string snippet is:
229 static const std::regex RStr("^(\\d+)(?:,(\\d+))*$");
230 if (std::smatch MatchStr; std::regex_match(s: Item, m&: MatchStr, re: RStr)) {
231 for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx)
232 if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length())
233 MDsItem.push_back(Elt: ConstantAsMetadata::get(
234 C: ConstantInt::get(Ty: Int32Ty, V: std::stoi(str: SubStr))));
235 } else {
236 MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item));
237 }
238 } else if (int32_t Num; llvm::to_integer(S: StringRef(Item), Num, Base: 10)) {
239 MDsItem.push_back(
240 Elt: ConstantAsMetadata::get(C: ConstantInt::get(Ty: Int32Ty, V: Num)));
241 } else {
242 MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item));
243 }
244 }
245 if (MDsItem.size() == 0)
246 return SmallVector<Metadata *>{};
247 MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: MDsItem));
248 }
249 return Pos == static_cast<int>(Anno.length()) ? std::move(MDs)
250 : SmallVector<Metadata *>{};
251}
252
253static void lowerPtrAnnotation(IntrinsicInst *II) {
254 LLVMContext &Ctx = II->getContext();
255 Type *Int32Ty = Type::getInt32Ty(C&: Ctx);
256
257 // Retrieve an annotation string from arguments.
258 Value *PtrArg = nullptr;
259 if (auto *BI = dyn_cast<BitCastInst>(Val: II->getArgOperand(i: 0)))
260 PtrArg = BI->getOperand(i_nocapture: 0);
261 else
262 PtrArg = II->getOperand(i_nocapture: 0);
263 std::string Anno =
264 getAnnotation(AnnoVal: II->getArgOperand(i: 1),
265 OptAnnoVal: 4 < II->arg_size() ? II->getArgOperand(i: 4) : nullptr);
266
267 // Parse the annotation.
268 SmallVector<Metadata *> MDs = parseAnnotation(I: II, Anno, Ctx, Int32Ty);
269
270 // If the annotation string is not parsed successfully we don't know the
271 // format used and output it as a general UserSemantic decoration.
272 // Otherwise MDs is a Metadata tuple (a decoration list) in the format
273 // expected by `spirv.Decorations`.
274 if (MDs.size() == 0) {
275 auto UserSemantic = ConstantAsMetadata::get(C: ConstantInt::get(
276 Ty: Int32Ty, V: static_cast<uint32_t>(SPIRV::Decoration::UserSemantic)));
277 MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: {UserSemantic, MDString::get(Context&: Ctx, Str: Anno)}));
278 }
279
280 // Build the internal intrinsic function.
281 IRBuilder<> IRB(II->getParent());
282 IRB.SetInsertPoint(II);
283 IRB.CreateIntrinsic(
284 ID: Intrinsic::spv_assign_decoration, Types: {PtrArg->getType()},
285 Args: {PtrArg, MetadataAsValue::get(Context&: Ctx, MD: MDNode::get(Context&: Ctx, MDs))});
286 II->replaceAllUsesWith(V: II->getOperand(i_nocapture: 0));
287}
288
289static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
290 // Get a separate function - otherwise, we'd have to rework the CFG of the
291 // current one. Then simply replace the intrinsic uses with a call to the new
292 // function.
293 // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
294 Module *M = FSHIntrinsic->getModule();
295 FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
296 Type *FSHRetTy = FSHFuncTy->getReturnType();
297 const std::string FuncName = lowerLLVMIntrinsicName(II: FSHIntrinsic);
298 Function *FSHFunc =
299 getOrCreateFunction(M, RetTy: FSHRetTy, ArgTypes: FSHFuncTy->params(), Name: FuncName);
300
301 if (!FSHFunc->empty()) {
302 FSHIntrinsic->setCalledFunction(FSHFunc);
303 return;
304 }
305 BasicBlock *RotateBB = BasicBlock::Create(Context&: M->getContext(), Name: "rotate", Parent: FSHFunc);
306 IRBuilder<> IRB(RotateBB);
307 Type *Ty = FSHFunc->getReturnType();
308 // Build the actual funnel shift rotate logic.
309 // In the comments, "int" is used interchangeably with "vector of int
310 // elements".
311 FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Val: Ty);
312 Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
313 unsigned BitWidth = IntTy->getIntegerBitWidth();
314 ConstantInt *BitWidthConstant = IRB.getInt(AI: {BitWidth, BitWidth});
315 Value *BitWidthForInsts =
316 VectorTy
317 ? IRB.CreateVectorSplat(NumElts: VectorTy->getNumElements(), V: BitWidthConstant)
318 : BitWidthConstant;
319 Value *RotateModVal =
320 IRB.CreateURem(/*Rotate*/ LHS: FSHFunc->getArg(i: 2), RHS: BitWidthForInsts);
321 Value *FirstShift = nullptr, *SecShift = nullptr;
322 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
323 // Shift the less significant number right, the "rotate" number of bits
324 // will be 0-filled on the left as a result of this regular shift.
325 FirstShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: RotateModVal);
326 } else {
327 // Shift the more significant number left, the "rotate" number of bits
328 // will be 0-filled on the right as a result of this regular shift.
329 FirstShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: RotateModVal);
330 }
331 // We want the "rotate" number of the more significant int's LSBs (MSBs) to
332 // occupy the leftmost (rightmost) "0 space" left by the previous operation.
333 // Therefore, subtract the "rotate" number from the integer bitsize...
334 Value *SubRotateVal = IRB.CreateSub(LHS: BitWidthForInsts, RHS: RotateModVal);
335 if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
336 // ...and left-shift the more significant int by this number, zero-filling
337 // the LSBs.
338 SecShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: SubRotateVal);
339 } else {
340 // ...and right-shift the less significant int by this number, zero-filling
341 // the MSBs.
342 SecShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: SubRotateVal);
343 }
344 // A simple binary addition of the shifted ints yields the final result.
345 IRB.CreateRet(V: IRB.CreateOr(LHS: FirstShift, RHS: SecShift));
346
347 FSHIntrinsic->setCalledFunction(FSHFunc);
348}
349
350static void lowerConstrainedFPCmpIntrinsic(
351 ConstrainedFPCmpIntrinsic *ConstrainedCmpIntrinsic,
352 SmallVector<Instruction *> &EraseFromParent) {
353 if (!ConstrainedCmpIntrinsic)
354 return;
355 // Extract the floating-point values being compared
356 Value *LHS = ConstrainedCmpIntrinsic->getArgOperand(i: 0);
357 Value *RHS = ConstrainedCmpIntrinsic->getArgOperand(i: 1);
358 FCmpInst::Predicate Pred = ConstrainedCmpIntrinsic->getPredicate();
359 IRBuilder<> Builder(ConstrainedCmpIntrinsic);
360 Value *FCmp = Builder.CreateFCmp(P: Pred, LHS, RHS);
361 ConstrainedCmpIntrinsic->replaceAllUsesWith(V: FCmp);
362 EraseFromParent.push_back(Elt: dyn_cast<Instruction>(Val: ConstrainedCmpIntrinsic));
363}
364
365static void lowerExpectAssume(IntrinsicInst *II) {
366 // If we cannot use the SPV_KHR_expect_assume extension, then we need to
367 // ignore the intrinsic and move on. It should be removed later on by LLVM.
368 // Otherwise we should lower the intrinsic to the corresponding SPIR-V
369 // instruction.
370 // For @llvm.assume we have OpAssumeTrueKHR.
371 // For @llvm.expect we have OpExpectKHR.
372 //
373 // We need to lower this into a builtin and then the builtin into a SPIR-V
374 // instruction.
375 if (II->getIntrinsicID() == Intrinsic::assume) {
376 Function *F = Intrinsic::getOrInsertDeclaration(
377 M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_assume);
378 II->setCalledFunction(F);
379 } else if (II->getIntrinsicID() == Intrinsic::expect) {
380 Function *F = Intrinsic::getOrInsertDeclaration(
381 M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_expect,
382 Tys: {II->getOperand(i_nocapture: 0)->getType()});
383 II->setCalledFunction(F);
384 } else {
385 llvm_unreachable("Unknown intrinsic");
386 }
387}
388
389static bool toSpvLifetimeIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID) {
390 auto *LifetimeArg0 = II->getArgOperand(i: 0);
391
392 // If the lifetime argument is a poison value, the intrinsic has no effect.
393 if (isa<PoisonValue>(Val: LifetimeArg0)) {
394 II->eraseFromParent();
395 return true;
396 }
397
398 IRBuilder<> Builder(II);
399 auto *Alloca = cast<AllocaInst>(Val: LifetimeArg0);
400 std::optional<TypeSize> Size =
401 Alloca->getAllocationSize(DL: Alloca->getDataLayout());
402 Value *SizeVal = Builder.getInt64(C: Size ? *Size : -1);
403 Builder.CreateIntrinsic(ID: NewID, Types: Alloca->getType(), Args: {SizeVal, LifetimeArg0});
404 II->eraseFromParent();
405 return true;
406}
407
408static void
409lowerConstrainedFmuladd(IntrinsicInst *II,
410 SmallVector<Instruction *> &EraseFromParent) {
411 auto *FPI = cast<ConstrainedFPIntrinsic>(Val: II);
412 Value *A = FPI->getArgOperand(i: 0);
413 Value *Mul = FPI->getArgOperand(i: 1);
414 Value *Add = FPI->getArgOperand(i: 2);
415 IRBuilder<> Builder(II->getParent());
416 Builder.SetInsertPoint(II);
417 std::optional<RoundingMode> Rounding = FPI->getRoundingMode();
418 Value *Product = Builder.CreateFMul(L: A, R: Mul, Name: II->getName() + ".mul");
419 Value *Result = Builder.CreateConstrainedFPBinOp(
420 ID: Intrinsic::experimental_constrained_fadd, L: Product, R: Add, FMFSource: {},
421 Name: II->getName() + ".add", FPMathTag: nullptr, Rounding);
422 II->replaceAllUsesWith(V: Result);
423 EraseFromParent.push_back(Elt: II);
424}
425
426// Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
427// or calls to proper generated functions. Returns True if F was modified.
428bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
429 bool Changed = false;
430 const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F: *F);
431 SmallVector<Instruction *> EraseFromParent;
432 const TargetTransformInfo &TTI = TM.getTargetTransformInfo(F: *F);
433 for (BasicBlock &BB : *F) {
434 for (Instruction &I : make_early_inc_range(Range&: BB)) {
435 auto Call = dyn_cast<CallInst>(Val: &I);
436 if (!Call)
437 continue;
438 Function *CF = Call->getCalledFunction();
439 if (!CF || !CF->isIntrinsic())
440 continue;
441 auto *II = cast<IntrinsicInst>(Val: Call);
442 switch (II->getIntrinsicID()) {
443 case Intrinsic::memset:
444 case Intrinsic::bswap:
445 Changed |= lowerIntrinsicToFunction(Intrinsic: II, TTI);
446 break;
447 case Intrinsic::fshl:
448 case Intrinsic::fshr:
449 lowerFunnelShifts(FSHIntrinsic: II);
450 Changed = true;
451 break;
452 case Intrinsic::assume:
453 case Intrinsic::expect:
454 if (STI.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume))
455 lowerExpectAssume(II);
456 Changed = true;
457 break;
458 case Intrinsic::lifetime_start:
459 if (!STI.isShader()) {
460 Changed |= toSpvLifetimeIntrinsic(
461 II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_start);
462 } else {
463 II->eraseFromParent();
464 Changed = true;
465 }
466 break;
467 case Intrinsic::lifetime_end:
468 if (!STI.isShader()) {
469 Changed |= toSpvLifetimeIntrinsic(
470 II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_end);
471 } else {
472 II->eraseFromParent();
473 Changed = true;
474 }
475 break;
476 case Intrinsic::ptr_annotation:
477 lowerPtrAnnotation(II);
478 Changed = true;
479 break;
480 case Intrinsic::experimental_constrained_fmuladd:
481 lowerConstrainedFmuladd(II, EraseFromParent);
482 Changed = true;
483 break;
484 case Intrinsic::experimental_constrained_fcmp:
485 case Intrinsic::experimental_constrained_fcmps:
486 lowerConstrainedFPCmpIntrinsic(ConstrainedCmpIntrinsic: dyn_cast<ConstrainedFPCmpIntrinsic>(Val: II),
487 EraseFromParent);
488 Changed = true;
489 break;
490 default:
491 if (TM.getTargetTriple().getVendor() == Triple::AMD ||
492 any_of(Range&: SPVAllowUnknownIntrinsics, P: [II](auto &&Prefix) {
493 if (Prefix.empty())
494 return false;
495 return II->getCalledFunction()->getName().starts_with(Prefix);
496 }))
497 Changed |= lowerIntrinsicToFunction(Intrinsic: II, TTI);
498 break;
499 }
500 }
501 }
502 for (auto *I : EraseFromParent)
503 I->eraseFromParent();
504 return Changed;
505}
506
507static void
508addFunctionTypeMutation(NamedMDNode *NMD,
509 SmallVector<std::pair<int, Type *>> ChangedTys,
510 StringRef Name) {
511
512 LLVMContext &Ctx = NMD->getParent()->getContext();
513 Type *I32Ty = IntegerType::getInt32Ty(C&: Ctx);
514
515 SmallVector<Metadata *> MDArgs;
516 MDArgs.push_back(Elt: MDString::get(Context&: Ctx, Str: Name));
517 transform(Range&: ChangedTys, d_first: std::back_inserter(x&: MDArgs), F: [=, &Ctx](auto &&CTy) {
518 return MDNode::get(
519 Context&: Ctx, MDs: {ConstantAsMetadata::get(C: ConstantInt::get(I32Ty, CTy.first, true)),
520 ValueAsMetadata::get(V: Constant::getNullValue(Ty: CTy.second))});
521 });
522 NMD->addOperand(M: MDNode::get(Context&: Ctx, MDs: MDArgs));
523}
524// Returns F if aggregate argument/return types are not present or cloned F
525// function with the types replaced by i32 types. The change in types is
526// noted in 'spv.cloned_funcs' metadata for later restoration.
527Function *
528SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
529 bool IsRetAggr = F->getReturnType()->isAggregateType();
530 // Allow intrinsics with aggregate return type to reach GlobalISel
531 if (F->isIntrinsic() && IsRetAggr)
532 return F;
533
534 IRBuilder<> B(F->getContext());
535
536 bool HasAggrArg = llvm::any_of(Range: F->args(), P: [](Argument &Arg) {
537 return Arg.getType()->isAggregateType();
538 });
539 bool DoClone = IsRetAggr || HasAggrArg;
540 if (!DoClone)
541 return F;
542 SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
543 Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
544 if (IsRetAggr)
545 ChangedTypes.push_back(Elt: std::pair<int, Type *>(-1, F->getReturnType()));
546 SmallVector<Type *, 4> ArgTypes;
547 for (const auto &Arg : F->args()) {
548 if (Arg.getType()->isAggregateType()) {
549 ArgTypes.push_back(Elt: B.getInt32Ty());
550 ChangedTypes.push_back(
551 Elt: std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
552 } else
553 ArgTypes.push_back(Elt: Arg.getType());
554 }
555 FunctionType *NewFTy =
556 FunctionType::get(Result: RetType, Params: ArgTypes, isVarArg: F->getFunctionType()->isVarArg());
557 Function *NewF =
558 Function::Create(Ty: NewFTy, Linkage: F->getLinkage(), AddrSpace: F->getAddressSpace(),
559 N: F->getName(), M: F->getParent());
560
561 ValueToValueMapTy VMap;
562 auto NewFArgIt = NewF->arg_begin();
563 for (auto &Arg : F->args()) {
564 StringRef ArgName = Arg.getName();
565 NewFArgIt->setName(ArgName);
566 VMap[&Arg] = &(*NewFArgIt++);
567 }
568 SmallVector<ReturnInst *, 8> Returns;
569
570 CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly,
571 Returns);
572 NewF->takeName(V: F);
573
574 addFunctionTypeMutation(
575 NMD: NewF->getParent()->getOrInsertNamedMetadata(Name: "spv.cloned_funcs"),
576 ChangedTys: std::move(ChangedTypes), Name: NewF->getName());
577
578 for (auto *U : make_early_inc_range(Range: F->users())) {
579 if (CallInst *CI;
580 (CI = dyn_cast<CallInst>(Val: U)) && CI->getCalledFunction() == F)
581 CI->mutateFunctionType(FTy: NewF->getFunctionType());
582 if (auto *C = dyn_cast<Constant>(Val: U))
583 C->handleOperandChange(F, NewF);
584 else
585 U->replaceUsesOfWith(From: F, To: NewF);
586 }
587
588 // register the mutation
589 if (RetType != F->getReturnType())
590 TM.getSubtarget<SPIRVSubtarget>(F: *F).getSPIRVGlobalRegistry()->addMutated(
591 Val: NewF, Ty: F->getReturnType());
592 return NewF;
593}
594
595// Mutates indirect callsites iff if aggregate argument/return types are present
596// with the types replaced by i32 types. The change in types is noted in
597// 'spv.mutated_callsites' metadata for later restoration.
598bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) {
599 if (F->isDeclaration() || F->isIntrinsic())
600 return false;
601
602 SmallVector<std::pair<CallBase *, FunctionType *>> Calls;
603 for (auto &&I : instructions(F)) {
604 if (auto *CB = dyn_cast<CallBase>(Val: &I)) {
605 if (!CB->getCalledOperand() || CB->getCalledFunction())
606 continue;
607 if (CB->getType()->isAggregateType() ||
608 any_of(Range: CB->args(),
609 P: [](auto &&Arg) { return Arg->getType()->isAggregateType(); }))
610 Calls.emplace_back(Args&: CB, Args: nullptr);
611 }
612 }
613
614 if (Calls.empty())
615 return false;
616
617 IRBuilder<> B(F->getContext());
618
619 for (auto &&[CB, NewFnTy] : Calls) {
620 SmallVector<std::pair<int, Type *>> ChangedTypes;
621 SmallVector<Type *> NewArgTypes;
622
623 Type *RetTy = CB->getType();
624 if (RetTy->isAggregateType()) {
625 ChangedTypes.emplace_back(Args: -1, Args&: RetTy);
626 RetTy = B.getInt32Ty();
627 }
628
629 for (auto &&Arg : CB->args()) {
630 if (Arg->getType()->isAggregateType()) {
631 NewArgTypes.push_back(Elt: B.getInt32Ty());
632 ChangedTypes.emplace_back(Args: Arg.getOperandNo(), Args: Arg->getType());
633 } else {
634 NewArgTypes.push_back(Elt: Arg->getType());
635 }
636 }
637 NewFnTy = FunctionType::get(Result: RetTy, Params: NewArgTypes,
638 isVarArg: CB->getFunctionType()->isVarArg());
639
640 if (!CB->hasName())
641 CB->setName("spv.mutated_callsite." + F->getName());
642 else
643 CB->setName("spv.named_mutated_callsite." + F->getName() + "." +
644 CB->getName());
645
646 addFunctionTypeMutation(
647 NMD: F->getParent()->getOrInsertNamedMetadata(Name: "spv.mutated_callsites"),
648 ChangedTys: std::move(ChangedTypes), Name: CB->getName());
649 }
650
651 for (auto &&[CB, NewFTy] : Calls) {
652 if (NewFTy->getReturnType() != CB->getType())
653 TM.getSubtarget<SPIRVSubtarget>(F: *F).getSPIRVGlobalRegistry()->addMutated(
654 Val: CB, Ty: CB->getType());
655 CB->mutateFunctionType(FTy: NewFTy);
656 }
657
658 return true;
659}
660
661bool SPIRVPrepareFunctions::runOnModule(Module &M) {
662 // Resolve the SPIR-V environment from module content before any
663 // function-level processing. This must happen before legalization so that
664 // isShader()/isKernel() return correct values.
665 const_cast<SPIRVTargetMachine &>(TM)
666 .getMutableSubtargetImpl()
667 ->resolveEnvFromModule(M);
668
669 bool Changed = false;
670 for (Function &F : M) {
671 Changed |= substituteIntrinsicCalls(F: &F);
672 Changed |= sortBlocks(F);
673 Changed |= removeAggregateTypesFromCalls(F: &F);
674 }
675
676 std::vector<Function *> FuncsWorklist;
677 for (auto &F : M)
678 FuncsWorklist.push_back(x: &F);
679
680 for (auto *F : FuncsWorklist) {
681 Function *NewF = removeAggregateTypesFromSignature(F);
682
683 if (NewF != F) {
684 F->eraseFromParent();
685 Changed = true;
686 }
687 }
688 return Changed;
689}
690
691ModulePass *
692llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
693 return new SPIRVPrepareFunctions(TM);
694}
695