| 1 | //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass modifies function signatures containing aggregate arguments |
| 10 | // and/or return value before IRTranslator. Information about the original |
| 11 | // signatures is stored in metadata. It is used during call lowering to |
| 12 | // restore correct SPIR-V types of function arguments and return values. |
| 13 | // This pass also substitutes some llvm intrinsic calls with calls to newly |
| 14 | // generated functions (as the Khronos LLVM/SPIR-V Translator does). |
| 15 | // |
| 16 | // NOTE: this pass is a module-level one due to the necessity to modify |
| 17 | // GVs/functions. |
| 18 | // |
| 19 | //===----------------------------------------------------------------------===// |
| 20 | |
| 21 | #include "SPIRV.h" |
| 22 | #include "SPIRVSubtarget.h" |
| 23 | #include "SPIRVTargetMachine.h" |
| 24 | #include "SPIRVUtils.h" |
| 25 | #include "llvm/ADT/StringExtras.h" |
| 26 | #include "llvm/Analysis/ValueTracking.h" |
| 27 | #include "llvm/CodeGen/IntrinsicLowering.h" |
| 28 | #include "llvm/IR/IRBuilder.h" |
| 29 | #include "llvm/IR/InstIterator.h" |
| 30 | #include "llvm/IR/Instructions.h" |
| 31 | #include "llvm/IR/IntrinsicInst.h" |
| 32 | #include "llvm/IR/Intrinsics.h" |
| 33 | #include "llvm/IR/IntrinsicsSPIRV.h" |
| 34 | #include "llvm/Transforms/Utils/Cloning.h" |
| 35 | #include "llvm/Transforms/Utils/LowerMemIntrinsics.h" |
| 36 | #include <regex> |
| 37 | |
| 38 | using namespace llvm; |
| 39 | |
| 40 | namespace { |
| 41 | |
| 42 | class SPIRVPrepareFunctions : public ModulePass { |
| 43 | const SPIRVTargetMachine &TM; |
| 44 | bool substituteIntrinsicCalls(Function *F); |
| 45 | Function *removeAggregateTypesFromSignature(Function *F); |
| 46 | bool removeAggregateTypesFromCalls(Function *F); |
| 47 | |
| 48 | public: |
| 49 | static char ID; |
| 50 | SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) |
| 51 | : ModulePass(ID), TM(TM) {} |
| 52 | |
| 53 | bool runOnModule(Module &M) override; |
| 54 | |
| 55 | StringRef getPassName() const override { return "SPIRV prepare functions" ; } |
| 56 | |
| 57 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 58 | ModulePass::getAnalysisUsage(AU); |
| 59 | } |
| 60 | }; |
| 61 | |
| 62 | static cl::list<std::string> SPVAllowUnknownIntrinsics( |
| 63 | "spv-allow-unknown-intrinsics" , cl::CommaSeparated, |
| 64 | cl::desc("Emit unknown intrinsics as calls to external functions. A " |
| 65 | "comma-separated input list of intrinsic prefixes must be " |
| 66 | "provided, and only intrinsics carrying a listed prefix get " |
| 67 | "emitted as described." ), |
| 68 | cl::value_desc("intrinsic_prefix_0,intrinsic_prefix_1" ), cl::ValueOptional); |
| 69 | } // namespace |
| 70 | |
// Pass identifier; it is passed by reference to ModulePass in the pass
// constructor.
char SPIRVPrepareFunctions::ID = 0;

// Registers the pass with the legacy pass registry under the argument name
// "prepare-functions".
INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
| 75 | |
| 76 | static std::string lowerLLVMIntrinsicName(IntrinsicInst *II) { |
| 77 | Function *IntrinsicFunc = II->getCalledFunction(); |
| 78 | assert(IntrinsicFunc && "Missing function" ); |
| 79 | std::string FuncName = IntrinsicFunc->getName().str(); |
| 80 | llvm::replace(Range&: FuncName, OldValue: '.', NewValue: '_'); |
| 81 | FuncName = "spirv." + FuncName; |
| 82 | return FuncName; |
| 83 | } |
| 84 | |
| 85 | static Function *getOrCreateFunction(Module *M, Type *RetTy, |
| 86 | ArrayRef<Type *> ArgTypes, |
| 87 | StringRef Name) { |
| 88 | FunctionType *FT = FunctionType::get(Result: RetTy, Params: ArgTypes, isVarArg: false); |
| 89 | Function *F = M->getFunction(Name); |
| 90 | if (F && F->getFunctionType() == FT) |
| 91 | return F; |
| 92 | Function *NewF = Function::Create(Ty: FT, Linkage: GlobalValue::ExternalLinkage, N: Name, M); |
| 93 | if (F) |
| 94 | NewF->setDSOLocal(F->isDSOLocal()); |
| 95 | NewF->setCallingConv(CallingConv::SPIR_FUNC); |
| 96 | return NewF; |
| 97 | } |
| 98 | |
| 99 | static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) { |
| 100 | // For @llvm.memset.* intrinsic cases with constant value and length arguments |
| 101 | // are emulated via "storing" a constant array to the destination. For other |
| 102 | // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the |
| 103 | // intrinsic to a loop via expandMemSetAsLoop(). |
| 104 | if (auto *MSI = dyn_cast<MemSetInst>(Val: Intrinsic)) |
| 105 | if (isa<Constant>(Val: MSI->getValue()) && isa<ConstantInt>(Val: MSI->getLength())) |
| 106 | return false; // It is handled later using OpCopyMemorySized. |
| 107 | |
| 108 | Module *M = Intrinsic->getModule(); |
| 109 | std::string FuncName = lowerLLVMIntrinsicName(II: Intrinsic); |
| 110 | if (Intrinsic->isVolatile()) |
| 111 | FuncName += ".volatile" ; |
| 112 | // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_* |
| 113 | Function *F = M->getFunction(Name: FuncName); |
| 114 | if (F) { |
| 115 | Intrinsic->setCalledFunction(F); |
| 116 | return true; |
| 117 | } |
| 118 | // TODO copy arguments attributes: nocapture writeonly. |
| 119 | FunctionCallee FC = |
| 120 | M->getOrInsertFunction(Name: FuncName, T: Intrinsic->getFunctionType()); |
| 121 | auto IntrinsicID = Intrinsic->getIntrinsicID(); |
| 122 | Intrinsic->setCalledFunction(FC); |
| 123 | |
| 124 | F = dyn_cast<Function>(Val: FC.getCallee()); |
| 125 | assert(F && "Callee must be a function" ); |
| 126 | |
| 127 | switch (IntrinsicID) { |
| 128 | case Intrinsic::memset: { |
| 129 | auto *MSI = static_cast<MemSetInst *>(Intrinsic); |
| 130 | Argument *Dest = F->getArg(i: 0); |
| 131 | Argument *Val = F->getArg(i: 1); |
| 132 | Argument *Len = F->getArg(i: 2); |
| 133 | Argument *IsVolatile = F->getArg(i: 3); |
| 134 | Dest->setName("dest" ); |
| 135 | Val->setName("val" ); |
| 136 | Len->setName("len" ); |
| 137 | IsVolatile->setName("isvolatile" ); |
| 138 | BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry" , Parent: F); |
| 139 | IRBuilder<> IRB(EntryBB); |
| 140 | auto *MemSet = IRB.CreateMemSet(Ptr: Dest, Val, Size: Len, Align: MSI->getDestAlign(), |
| 141 | isVolatile: MSI->isVolatile()); |
| 142 | IRB.CreateRetVoid(); |
| 143 | expandMemSetAsLoop(MemSet: cast<MemSetInst>(Val: MemSet)); |
| 144 | MemSet->eraseFromParent(); |
| 145 | break; |
| 146 | } |
| 147 | case Intrinsic::bswap: { |
| 148 | BasicBlock *EntryBB = BasicBlock::Create(Context&: M->getContext(), Name: "entry" , Parent: F); |
| 149 | IRBuilder<> IRB(EntryBB); |
| 150 | auto *BSwap = IRB.CreateIntrinsic(ID: Intrinsic::bswap, Types: Intrinsic->getType(), |
| 151 | Args: F->getArg(i: 0)); |
| 152 | IRB.CreateRet(V: BSwap); |
| 153 | IntrinsicLowering IL(M->getDataLayout()); |
| 154 | IL.LowerIntrinsicCall(CI: BSwap); |
| 155 | break; |
| 156 | } |
| 157 | default: |
| 158 | break; |
| 159 | } |
| 160 | return true; |
| 161 | } |
| 162 | |
| 163 | static std::string getAnnotation(Value *AnnoVal, Value *OptAnnoVal) { |
| 164 | if (auto *Ref = dyn_cast_or_null<GetElementPtrInst>(Val: AnnoVal)) |
| 165 | AnnoVal = Ref->getOperand(i_nocapture: 0); |
| 166 | if (auto *Ref = dyn_cast_or_null<BitCastInst>(Val: OptAnnoVal)) |
| 167 | OptAnnoVal = Ref->getOperand(i_nocapture: 0); |
| 168 | |
| 169 | std::string Anno; |
| 170 | if (auto *C = dyn_cast_or_null<Constant>(Val: AnnoVal)) { |
| 171 | StringRef Str; |
| 172 | if (getConstantStringInfo(V: C, Str)) |
| 173 | Anno = Str; |
| 174 | } |
| 175 | // handle optional annotation parameter in a way that Khronos Translator do |
| 176 | // (collect integers wrapped in a struct) |
| 177 | if (auto *C = dyn_cast_or_null<Constant>(Val: OptAnnoVal); |
| 178 | C && C->getNumOperands()) { |
| 179 | Value *MaybeStruct = C->getOperand(i: 0); |
| 180 | if (auto *Struct = dyn_cast<ConstantStruct>(Val: MaybeStruct)) { |
| 181 | for (unsigned I = 0, E = Struct->getNumOperands(); I != E; ++I) { |
| 182 | if (auto *CInt = dyn_cast<ConstantInt>(Val: Struct->getOperand(i_nocapture: I))) |
| 183 | Anno += (I == 0 ? ": " : ", " ) + |
| 184 | std::to_string(val: CInt->getType()->getIntegerBitWidth() == 1 |
| 185 | ? CInt->getZExtValue() |
| 186 | : CInt->getSExtValue()); |
| 187 | } |
| 188 | } else if (auto *Struct = dyn_cast<ConstantAggregateZero>(Val: MaybeStruct)) { |
| 189 | // { i32 i32 ... } zeroinitializer |
| 190 | for (unsigned I = 0, E = Struct->getType()->getStructNumElements(); |
| 191 | I != E; ++I) |
| 192 | Anno += I == 0 ? ": 0" : ", 0" ; |
| 193 | } |
| 194 | } |
| 195 | return Anno; |
| 196 | } |
| 197 | |
| 198 | static SmallVector<Metadata *> parseAnnotation(Value *I, |
| 199 | const std::string &Anno, |
| 200 | LLVMContext &Ctx, |
| 201 | Type *Int32Ty) { |
| 202 | // Try to parse the annotation string according to the following rules: |
| 203 | // annotation := ({kind} | {kind:value,value,...})+ |
| 204 | // kind := number |
| 205 | // value := number | string |
| 206 | static const std::regex R( |
| 207 | "\\{(\\d+)(?:[:,](\\d+|\"[^\"]*\")(?:,(\\d+|\"[^\"]*\"))*)?\\}" ); |
| 208 | SmallVector<Metadata *> MDs; |
| 209 | int Pos = 0; |
| 210 | for (std::sregex_iterator |
| 211 | It = std::sregex_iterator(Anno.begin(), Anno.end(), R), |
| 212 | ItEnd = std::sregex_iterator(); |
| 213 | It != ItEnd; ++It) { |
| 214 | if (It->position() != Pos) |
| 215 | return SmallVector<Metadata *>{}; |
| 216 | Pos = It->position() + It->length(); |
| 217 | std::smatch Match = *It; |
| 218 | SmallVector<Metadata *> MDsItem; |
| 219 | for (std::size_t i = 1; i < Match.size(); ++i) { |
| 220 | std::ssub_match SMatch = Match[i]; |
| 221 | std::string Item = SMatch.str(); |
| 222 | if (Item.length() == 0) |
| 223 | break; |
| 224 | if (Item[0] == '"') { |
| 225 | Item = Item.substr(pos: 1, n: Item.length() - 2); |
| 226 | // Acceptable format of the string snippet is: |
| 227 | static const std::regex RStr("^(\\d+)(?:,(\\d+))*$" ); |
| 228 | if (std::smatch MatchStr; std::regex_match(s: Item, m&: MatchStr, re: RStr)) { |
| 229 | for (std::size_t SubIdx = 1; SubIdx < MatchStr.size(); ++SubIdx) |
| 230 | if (std::string SubStr = MatchStr[SubIdx].str(); SubStr.length()) |
| 231 | MDsItem.push_back(Elt: ConstantAsMetadata::get( |
| 232 | C: ConstantInt::get(Ty: Int32Ty, V: std::stoi(str: SubStr)))); |
| 233 | } else { |
| 234 | MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item)); |
| 235 | } |
| 236 | } else if (int32_t Num; llvm::to_integer(S: StringRef(Item), Num, Base: 10)) { |
| 237 | MDsItem.push_back( |
| 238 | Elt: ConstantAsMetadata::get(C: ConstantInt::get(Ty: Int32Ty, V: Num))); |
| 239 | } else { |
| 240 | MDsItem.push_back(Elt: MDString::get(Context&: Ctx, Str: Item)); |
| 241 | } |
| 242 | } |
| 243 | if (MDsItem.size() == 0) |
| 244 | return SmallVector<Metadata *>{}; |
| 245 | MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: MDsItem)); |
| 246 | } |
| 247 | return Pos == static_cast<int>(Anno.length()) ? std::move(MDs) |
| 248 | : SmallVector<Metadata *>{}; |
| 249 | } |
| 250 | |
| 251 | static void lowerPtrAnnotation(IntrinsicInst *II) { |
| 252 | LLVMContext &Ctx = II->getContext(); |
| 253 | Type *Int32Ty = Type::getInt32Ty(C&: Ctx); |
| 254 | |
| 255 | // Retrieve an annotation string from arguments. |
| 256 | Value *PtrArg = nullptr; |
| 257 | if (auto *BI = dyn_cast<BitCastInst>(Val: II->getArgOperand(i: 0))) |
| 258 | PtrArg = BI->getOperand(i_nocapture: 0); |
| 259 | else |
| 260 | PtrArg = II->getOperand(i_nocapture: 0); |
| 261 | std::string Anno = |
| 262 | getAnnotation(AnnoVal: II->getArgOperand(i: 1), |
| 263 | OptAnnoVal: 4 < II->arg_size() ? II->getArgOperand(i: 4) : nullptr); |
| 264 | |
| 265 | // Parse the annotation. |
| 266 | SmallVector<Metadata *> MDs = parseAnnotation(I: II, Anno, Ctx, Int32Ty); |
| 267 | |
| 268 | // If the annotation string is not parsed successfully we don't know the |
| 269 | // format used and output it as a general UserSemantic decoration. |
| 270 | // Otherwise MDs is a Metadata tuple (a decoration list) in the format |
| 271 | // expected by `spirv.Decorations`. |
| 272 | if (MDs.size() == 0) { |
| 273 | auto UserSemantic = ConstantAsMetadata::get(C: ConstantInt::get( |
| 274 | Ty: Int32Ty, V: static_cast<uint32_t>(SPIRV::Decoration::UserSemantic))); |
| 275 | MDs.push_back(Elt: MDNode::get(Context&: Ctx, MDs: {UserSemantic, MDString::get(Context&: Ctx, Str: Anno)})); |
| 276 | } |
| 277 | |
| 278 | // Build the internal intrinsic function. |
| 279 | IRBuilder<> IRB(II->getParent()); |
| 280 | IRB.SetInsertPoint(II); |
| 281 | IRB.CreateIntrinsic( |
| 282 | ID: Intrinsic::spv_assign_decoration, Types: {PtrArg->getType()}, |
| 283 | Args: {PtrArg, MetadataAsValue::get(Context&: Ctx, MD: MDNode::get(Context&: Ctx, MDs))}); |
| 284 | II->replaceAllUsesWith(V: II->getOperand(i_nocapture: 0)); |
| 285 | } |
| 286 | |
| 287 | static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) { |
| 288 | // Get a separate function - otherwise, we'd have to rework the CFG of the |
| 289 | // current one. Then simply replace the intrinsic uses with a call to the new |
| 290 | // function. |
| 291 | // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c) |
| 292 | Module *M = FSHIntrinsic->getModule(); |
| 293 | FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType(); |
| 294 | Type *FSHRetTy = FSHFuncTy->getReturnType(); |
| 295 | const std::string FuncName = lowerLLVMIntrinsicName(II: FSHIntrinsic); |
| 296 | Function *FSHFunc = |
| 297 | getOrCreateFunction(M, RetTy: FSHRetTy, ArgTypes: FSHFuncTy->params(), Name: FuncName); |
| 298 | |
| 299 | if (!FSHFunc->empty()) { |
| 300 | FSHIntrinsic->setCalledFunction(FSHFunc); |
| 301 | return; |
| 302 | } |
| 303 | BasicBlock *RotateBB = BasicBlock::Create(Context&: M->getContext(), Name: "rotate" , Parent: FSHFunc); |
| 304 | IRBuilder<> IRB(RotateBB); |
| 305 | Type *Ty = FSHFunc->getReturnType(); |
| 306 | // Build the actual funnel shift rotate logic. |
| 307 | // In the comments, "int" is used interchangeably with "vector of int |
| 308 | // elements". |
| 309 | FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Val: Ty); |
| 310 | Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty; |
| 311 | unsigned BitWidth = IntTy->getIntegerBitWidth(); |
| 312 | ConstantInt *BitWidthConstant = IRB.getInt(AI: {BitWidth, BitWidth}); |
| 313 | Value *BitWidthForInsts = |
| 314 | VectorTy |
| 315 | ? IRB.CreateVectorSplat(NumElts: VectorTy->getNumElements(), V: BitWidthConstant) |
| 316 | : BitWidthConstant; |
| 317 | Value *RotateModVal = |
| 318 | IRB.CreateURem(/*Rotate*/ LHS: FSHFunc->getArg(i: 2), RHS: BitWidthForInsts); |
| 319 | Value *FirstShift = nullptr, *SecShift = nullptr; |
| 320 | if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { |
| 321 | // Shift the less significant number right, the "rotate" number of bits |
| 322 | // will be 0-filled on the left as a result of this regular shift. |
| 323 | FirstShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: RotateModVal); |
| 324 | } else { |
| 325 | // Shift the more significant number left, the "rotate" number of bits |
| 326 | // will be 0-filled on the right as a result of this regular shift. |
| 327 | FirstShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: RotateModVal); |
| 328 | } |
| 329 | // We want the "rotate" number of the more significant int's LSBs (MSBs) to |
| 330 | // occupy the leftmost (rightmost) "0 space" left by the previous operation. |
| 331 | // Therefore, subtract the "rotate" number from the integer bitsize... |
| 332 | Value *SubRotateVal = IRB.CreateSub(LHS: BitWidthForInsts, RHS: RotateModVal); |
| 333 | if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) { |
| 334 | // ...and left-shift the more significant int by this number, zero-filling |
| 335 | // the LSBs. |
| 336 | SecShift = IRB.CreateShl(LHS: FSHFunc->getArg(i: 0), RHS: SubRotateVal); |
| 337 | } else { |
| 338 | // ...and right-shift the less significant int by this number, zero-filling |
| 339 | // the MSBs. |
| 340 | SecShift = IRB.CreateLShr(LHS: FSHFunc->getArg(i: 1), RHS: SubRotateVal); |
| 341 | } |
| 342 | // A simple binary addition of the shifted ints yields the final result. |
| 343 | IRB.CreateRet(V: IRB.CreateOr(LHS: FirstShift, RHS: SecShift)); |
| 344 | |
| 345 | FSHIntrinsic->setCalledFunction(FSHFunc); |
| 346 | } |
| 347 | |
| 348 | static void lowerConstrainedFPCmpIntrinsic( |
| 349 | ConstrainedFPCmpIntrinsic *ConstrainedCmpIntrinsic, |
| 350 | SmallVector<Instruction *> &EraseFromParent) { |
| 351 | if (!ConstrainedCmpIntrinsic) |
| 352 | return; |
| 353 | // Extract the floating-point values being compared |
| 354 | Value *LHS = ConstrainedCmpIntrinsic->getArgOperand(i: 0); |
| 355 | Value *RHS = ConstrainedCmpIntrinsic->getArgOperand(i: 1); |
| 356 | FCmpInst::Predicate Pred = ConstrainedCmpIntrinsic->getPredicate(); |
| 357 | IRBuilder<> Builder(ConstrainedCmpIntrinsic); |
| 358 | Value *FCmp = Builder.CreateFCmp(P: Pred, LHS, RHS); |
| 359 | ConstrainedCmpIntrinsic->replaceAllUsesWith(V: FCmp); |
| 360 | EraseFromParent.push_back(Elt: dyn_cast<Instruction>(Val: ConstrainedCmpIntrinsic)); |
| 361 | } |
| 362 | |
| 363 | static void lowerExpectAssume(IntrinsicInst *II) { |
| 364 | // If we cannot use the SPV_KHR_expect_assume extension, then we need to |
| 365 | // ignore the intrinsic and move on. It should be removed later on by LLVM. |
| 366 | // Otherwise we should lower the intrinsic to the corresponding SPIR-V |
| 367 | // instruction. |
| 368 | // For @llvm.assume we have OpAssumeTrueKHR. |
| 369 | // For @llvm.expect we have OpExpectKHR. |
| 370 | // |
| 371 | // We need to lower this into a builtin and then the builtin into a SPIR-V |
| 372 | // instruction. |
| 373 | if (II->getIntrinsicID() == Intrinsic::assume) { |
| 374 | Function *F = Intrinsic::getOrInsertDeclaration( |
| 375 | M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_assume); |
| 376 | II->setCalledFunction(F); |
| 377 | } else if (II->getIntrinsicID() == Intrinsic::expect) { |
| 378 | Function *F = Intrinsic::getOrInsertDeclaration( |
| 379 | M: II->getModule(), id: Intrinsic::SPVIntrinsics::spv_expect, |
| 380 | Tys: {II->getOperand(i_nocapture: 0)->getType()}); |
| 381 | II->setCalledFunction(F); |
| 382 | } else { |
| 383 | llvm_unreachable("Unknown intrinsic" ); |
| 384 | } |
| 385 | } |
| 386 | |
| 387 | static bool toSpvLifetimeIntrinsic(IntrinsicInst *II, Intrinsic::ID NewID) { |
| 388 | auto *LifetimeArg0 = II->getArgOperand(i: 0); |
| 389 | |
| 390 | // If the lifetime argument is a poison value, the intrinsic has no effect. |
| 391 | if (isa<PoisonValue>(Val: LifetimeArg0)) { |
| 392 | II->eraseFromParent(); |
| 393 | return true; |
| 394 | } |
| 395 | |
| 396 | IRBuilder<> Builder(II); |
| 397 | auto *Alloca = cast<AllocaInst>(Val: LifetimeArg0); |
| 398 | std::optional<TypeSize> Size = |
| 399 | Alloca->getAllocationSize(DL: Alloca->getDataLayout()); |
| 400 | Value *SizeVal = Builder.getInt64(C: Size ? *Size : -1); |
| 401 | Builder.CreateIntrinsic(ID: NewID, Types: Alloca->getType(), Args: {SizeVal, LifetimeArg0}); |
| 402 | II->eraseFromParent(); |
| 403 | return true; |
| 404 | } |
| 405 | |
| 406 | static void |
| 407 | lowerConstrainedFmuladd(IntrinsicInst *II, |
| 408 | SmallVector<Instruction *> &EraseFromParent) { |
| 409 | auto *FPI = cast<ConstrainedFPIntrinsic>(Val: II); |
| 410 | Value *A = FPI->getArgOperand(i: 0); |
| 411 | Value *Mul = FPI->getArgOperand(i: 1); |
| 412 | Value *Add = FPI->getArgOperand(i: 2); |
| 413 | IRBuilder<> Builder(II->getParent()); |
| 414 | Builder.SetInsertPoint(II); |
| 415 | std::optional<RoundingMode> Rounding = FPI->getRoundingMode(); |
| 416 | Value *Product = Builder.CreateFMul(L: A, R: Mul, Name: II->getName() + ".mul" ); |
| 417 | Value *Result = Builder.CreateConstrainedFPBinOp( |
| 418 | ID: Intrinsic::experimental_constrained_fadd, L: Product, R: Add, FMFSource: {}, |
| 419 | Name: II->getName() + ".add" , FPMathTag: nullptr, Rounding); |
| 420 | II->replaceAllUsesWith(V: Result); |
| 421 | EraseFromParent.push_back(Elt: II); |
| 422 | } |
| 423 | |
| 424 | // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics |
| 425 | // or calls to proper generated functions. Returns True if F was modified. |
| 426 | bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) { |
| 427 | bool Changed = false; |
| 428 | const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(F: *F); |
| 429 | SmallVector<Instruction *> EraseFromParent; |
| 430 | for (BasicBlock &BB : *F) { |
| 431 | for (Instruction &I : make_early_inc_range(Range&: BB)) { |
| 432 | auto Call = dyn_cast<CallInst>(Val: &I); |
| 433 | if (!Call) |
| 434 | continue; |
| 435 | Function *CF = Call->getCalledFunction(); |
| 436 | if (!CF || !CF->isIntrinsic()) |
| 437 | continue; |
| 438 | auto *II = cast<IntrinsicInst>(Val: Call); |
| 439 | switch (II->getIntrinsicID()) { |
| 440 | case Intrinsic::memset: |
| 441 | case Intrinsic::bswap: |
| 442 | Changed |= lowerIntrinsicToFunction(Intrinsic: II); |
| 443 | break; |
| 444 | case Intrinsic::fshl: |
| 445 | case Intrinsic::fshr: |
| 446 | lowerFunnelShifts(FSHIntrinsic: II); |
| 447 | Changed = true; |
| 448 | break; |
| 449 | case Intrinsic::assume: |
| 450 | case Intrinsic::expect: |
| 451 | if (STI.canUseExtension(E: SPIRV::Extension::SPV_KHR_expect_assume)) |
| 452 | lowerExpectAssume(II); |
| 453 | Changed = true; |
| 454 | break; |
| 455 | case Intrinsic::lifetime_start: |
| 456 | if (!STI.isShader()) { |
| 457 | Changed |= toSpvLifetimeIntrinsic( |
| 458 | II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_start); |
| 459 | } else { |
| 460 | II->eraseFromParent(); |
| 461 | Changed = true; |
| 462 | } |
| 463 | break; |
| 464 | case Intrinsic::lifetime_end: |
| 465 | if (!STI.isShader()) { |
| 466 | Changed |= toSpvLifetimeIntrinsic( |
| 467 | II, NewID: Intrinsic::SPVIntrinsics::spv_lifetime_end); |
| 468 | } else { |
| 469 | II->eraseFromParent(); |
| 470 | Changed = true; |
| 471 | } |
| 472 | break; |
| 473 | case Intrinsic::ptr_annotation: |
| 474 | lowerPtrAnnotation(II); |
| 475 | Changed = true; |
| 476 | break; |
| 477 | case Intrinsic::experimental_constrained_fmuladd: |
| 478 | lowerConstrainedFmuladd(II, EraseFromParent); |
| 479 | Changed = true; |
| 480 | break; |
| 481 | case Intrinsic::experimental_constrained_fcmp: |
| 482 | case Intrinsic::experimental_constrained_fcmps: |
| 483 | lowerConstrainedFPCmpIntrinsic(ConstrainedCmpIntrinsic: dyn_cast<ConstrainedFPCmpIntrinsic>(Val: II), |
| 484 | EraseFromParent); |
| 485 | Changed = true; |
| 486 | break; |
| 487 | default: |
| 488 | if (TM.getTargetTriple().getVendor() == Triple::AMD || |
| 489 | any_of(Range&: SPVAllowUnknownIntrinsics, P: [II](auto &&Prefix) { |
| 490 | if (Prefix.empty()) |
| 491 | return false; |
| 492 | return II->getCalledFunction()->getName().starts_with(Prefix); |
| 493 | })) |
| 494 | Changed |= lowerIntrinsicToFunction(Intrinsic: II); |
| 495 | break; |
| 496 | } |
| 497 | } |
| 498 | } |
| 499 | for (auto *I : EraseFromParent) |
| 500 | I->eraseFromParent(); |
| 501 | return Changed; |
| 502 | } |
| 503 | |
| 504 | static void |
| 505 | addFunctionTypeMutation(NamedMDNode *NMD, |
| 506 | SmallVector<std::pair<int, Type *>> ChangedTys, |
| 507 | StringRef Name) { |
| 508 | |
| 509 | LLVMContext &Ctx = NMD->getParent()->getContext(); |
| 510 | Type *I32Ty = IntegerType::getInt32Ty(C&: Ctx); |
| 511 | |
| 512 | SmallVector<Metadata *> MDArgs; |
| 513 | MDArgs.push_back(Elt: MDString::get(Context&: Ctx, Str: Name)); |
| 514 | transform(Range&: ChangedTys, d_first: std::back_inserter(x&: MDArgs), F: [=, &Ctx](auto &&CTy) { |
| 515 | return MDNode::get( |
| 516 | Context&: Ctx, MDs: {ConstantAsMetadata::get(C: ConstantInt::get(I32Ty, CTy.first, true)), |
| 517 | ValueAsMetadata::get(V: Constant::getNullValue(Ty: CTy.second))}); |
| 518 | }); |
| 519 | NMD->addOperand(M: MDNode::get(Context&: Ctx, MDs: MDArgs)); |
| 520 | } |
| 521 | // Returns F if aggregate argument/return types are not present or cloned F |
| 522 | // function with the types replaced by i32 types. The change in types is |
| 523 | // noted in 'spv.cloned_funcs' metadata for later restoration. |
| 524 | Function * |
| 525 | SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) { |
| 526 | bool IsRetAggr = F->getReturnType()->isAggregateType(); |
| 527 | // Allow intrinsics with aggregate return type to reach GlobalISel |
| 528 | if (F->isIntrinsic() && IsRetAggr) |
| 529 | return F; |
| 530 | |
| 531 | IRBuilder<> B(F->getContext()); |
| 532 | |
| 533 | bool HasAggrArg = llvm::any_of(Range: F->args(), P: [](Argument &Arg) { |
| 534 | return Arg.getType()->isAggregateType(); |
| 535 | }); |
| 536 | bool DoClone = IsRetAggr || HasAggrArg; |
| 537 | if (!DoClone) |
| 538 | return F; |
| 539 | SmallVector<std::pair<int, Type *>, 4> ChangedTypes; |
| 540 | Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType(); |
| 541 | if (IsRetAggr) |
| 542 | ChangedTypes.push_back(Elt: std::pair<int, Type *>(-1, F->getReturnType())); |
| 543 | SmallVector<Type *, 4> ArgTypes; |
| 544 | for (const auto &Arg : F->args()) { |
| 545 | if (Arg.getType()->isAggregateType()) { |
| 546 | ArgTypes.push_back(Elt: B.getInt32Ty()); |
| 547 | ChangedTypes.push_back( |
| 548 | Elt: std::pair<int, Type *>(Arg.getArgNo(), Arg.getType())); |
| 549 | } else |
| 550 | ArgTypes.push_back(Elt: Arg.getType()); |
| 551 | } |
| 552 | FunctionType *NewFTy = |
| 553 | FunctionType::get(Result: RetType, Params: ArgTypes, isVarArg: F->getFunctionType()->isVarArg()); |
| 554 | Function *NewF = |
| 555 | Function::Create(Ty: NewFTy, Linkage: F->getLinkage(), AddrSpace: F->getAddressSpace(), |
| 556 | N: F->getName(), M: F->getParent()); |
| 557 | |
| 558 | ValueToValueMapTy VMap; |
| 559 | auto NewFArgIt = NewF->arg_begin(); |
| 560 | for (auto &Arg : F->args()) { |
| 561 | StringRef ArgName = Arg.getName(); |
| 562 | NewFArgIt->setName(ArgName); |
| 563 | VMap[&Arg] = &(*NewFArgIt++); |
| 564 | } |
| 565 | SmallVector<ReturnInst *, 8> Returns; |
| 566 | |
| 567 | CloneFunctionInto(NewFunc: NewF, OldFunc: F, VMap, Changes: CloneFunctionChangeType::LocalChangesOnly, |
| 568 | Returns); |
| 569 | NewF->takeName(V: F); |
| 570 | |
| 571 | addFunctionTypeMutation( |
| 572 | NMD: NewF->getParent()->getOrInsertNamedMetadata(Name: "spv.cloned_funcs" ), |
| 573 | ChangedTys: std::move(ChangedTypes), Name: NewF->getName()); |
| 574 | |
| 575 | for (auto *U : make_early_inc_range(Range: F->users())) { |
| 576 | if (CallInst *CI; |
| 577 | (CI = dyn_cast<CallInst>(Val: U)) && CI->getCalledFunction() == F) |
| 578 | CI->mutateFunctionType(FTy: NewF->getFunctionType()); |
| 579 | if (auto *C = dyn_cast<Constant>(Val: U)) |
| 580 | C->handleOperandChange(F, NewF); |
| 581 | else |
| 582 | U->replaceUsesOfWith(From: F, To: NewF); |
| 583 | } |
| 584 | |
| 585 | // register the mutation |
| 586 | if (RetType != F->getReturnType()) |
| 587 | TM.getSubtarget<SPIRVSubtarget>(F: *F).getSPIRVGlobalRegistry()->addMutated( |
| 588 | Val: NewF, Ty: F->getReturnType()); |
| 589 | return NewF; |
| 590 | } |
| 591 | |
| 592 | // Mutates indirect callsites iff if aggregate argument/return types are present |
| 593 | // with the types replaced by i32 types. The change in types is noted in |
| 594 | // 'spv.mutated_callsites' metadata for later restoration. |
| 595 | bool SPIRVPrepareFunctions::removeAggregateTypesFromCalls(Function *F) { |
| 596 | if (F->isDeclaration() || F->isIntrinsic()) |
| 597 | return false; |
| 598 | |
| 599 | SmallVector<std::pair<CallBase *, FunctionType *>> Calls; |
| 600 | for (auto &&I : instructions(F)) { |
| 601 | if (auto *CB = dyn_cast<CallBase>(Val: &I)) { |
| 602 | if (!CB->getCalledOperand() || CB->getCalledFunction()) |
| 603 | continue; |
| 604 | if (CB->getType()->isAggregateType() || |
| 605 | any_of(Range: CB->args(), |
| 606 | P: [](auto &&Arg) { return Arg->getType()->isAggregateType(); })) |
| 607 | Calls.emplace_back(Args&: CB, Args: nullptr); |
| 608 | } |
| 609 | } |
| 610 | |
| 611 | if (Calls.empty()) |
| 612 | return false; |
| 613 | |
| 614 | IRBuilder<> B(F->getContext()); |
| 615 | |
| 616 | for (auto &&[CB, NewFnTy] : Calls) { |
| 617 | SmallVector<std::pair<int, Type *>> ChangedTypes; |
| 618 | SmallVector<Type *> NewArgTypes; |
| 619 | |
| 620 | Type *RetTy = CB->getType(); |
| 621 | if (RetTy->isAggregateType()) { |
| 622 | ChangedTypes.emplace_back(Args: -1, Args&: RetTy); |
| 623 | RetTy = B.getInt32Ty(); |
| 624 | } |
| 625 | |
| 626 | for (auto &&Arg : CB->args()) { |
| 627 | if (Arg->getType()->isAggregateType()) { |
| 628 | NewArgTypes.push_back(Elt: B.getInt32Ty()); |
| 629 | ChangedTypes.emplace_back(Args: Arg.getOperandNo(), Args: Arg->getType()); |
| 630 | } else { |
| 631 | NewArgTypes.push_back(Elt: Arg->getType()); |
| 632 | } |
| 633 | } |
| 634 | NewFnTy = FunctionType::get(Result: RetTy, Params: NewArgTypes, |
| 635 | isVarArg: CB->getFunctionType()->isVarArg()); |
| 636 | |
| 637 | if (!CB->hasName()) |
| 638 | CB->setName("spv.mutated_callsite." + F->getName()); |
| 639 | else |
| 640 | CB->setName("spv.named_mutated_callsite." + F->getName() + "." + |
| 641 | CB->getName()); |
| 642 | |
| 643 | addFunctionTypeMutation( |
| 644 | NMD: F->getParent()->getOrInsertNamedMetadata(Name: "spv.mutated_callsites" ), |
| 645 | ChangedTys: std::move(ChangedTypes), Name: CB->getName()); |
| 646 | } |
| 647 | |
| 648 | for (auto &&[CB, NewFTy] : Calls) { |
| 649 | if (NewFTy->getReturnType() != CB->getType()) |
| 650 | TM.getSubtarget<SPIRVSubtarget>(F: *F).getSPIRVGlobalRegistry()->addMutated( |
| 651 | Val: CB, Ty: CB->getType()); |
| 652 | CB->mutateFunctionType(FTy: NewFTy); |
| 653 | } |
| 654 | |
| 655 | return true; |
| 656 | } |
| 657 | |
| 658 | bool SPIRVPrepareFunctions::runOnModule(Module &M) { |
| 659 | bool Changed = false; |
| 660 | for (Function &F : M) { |
| 661 | Changed |= substituteIntrinsicCalls(F: &F); |
| 662 | Changed |= sortBlocks(F); |
| 663 | Changed |= removeAggregateTypesFromCalls(F: &F); |
| 664 | } |
| 665 | |
| 666 | std::vector<Function *> FuncsWorklist; |
| 667 | for (auto &F : M) |
| 668 | FuncsWorklist.push_back(x: &F); |
| 669 | |
| 670 | for (auto *F : FuncsWorklist) { |
| 671 | Function *NewF = removeAggregateTypesFromSignature(F); |
| 672 | |
| 673 | if (NewF != F) { |
| 674 | F->eraseFromParent(); |
| 675 | Changed = true; |
| 676 | } |
| 677 | } |
| 678 | return Changed; |
| 679 | } |
| 680 | |
| 681 | ModulePass * |
| 682 | llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) { |
| 683 | return new SPIRVPrepareFunctions(TM); |
| 684 | } |
| 685 | |