1//=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Replaces calls to LLVM Intrinsics with matching calls to functions from a
// vector library (e.g. libmvec, SVML) using the TargetLibraryInfo interface.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/CodeGen/ReplaceWithVeclib.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/Statistic.h"
17#include "llvm/ADT/StringRef.h"
18#include "llvm/Analysis/DemandedBits.h"
19#include "llvm/Analysis/GlobalsModRef.h"
20#include "llvm/Analysis/OptimizationRemarkEmitter.h"
21#include "llvm/Analysis/TargetLibraryInfo.h"
22#include "llvm/Analysis/VectorUtils.h"
23#include "llvm/CodeGen/Passes.h"
24#include "llvm/IR/DerivedTypes.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstIterator.h"
27#include "llvm/IR/IntrinsicInst.h"
28#include "llvm/IR/VFABIDemangler.h"
29#include "llvm/Support/TypeSize.h"
30#include "llvm/Transforms/Utils/ModuleUtils.h"
31
32using namespace llvm;
33
// Name of this pass as used in the prefix of its LLVM_DEBUG messages and as
// the pass argument given to INITIALIZE_PASS_BEGIN/INITIALIZE_PASS_END below.
#define DEBUG_TYPE "replace-with-veclib"

// Incremented once per intrinsic call rewritten into a vector library call.
STATISTIC(NumCallsReplaced,
          "Number of calls to intrinsics that have been replaced.");

// Incremented each time getTLIFunction creates a new declaration.
STATISTIC(NumTLIFuncDeclAdded,
          "Number of vector library function declarations added.");

// Incremented each time getTLIFunction appends a declaration to
// `llvm.compiler.used`.
STATISTIC(NumFuncUsedAdded,
          "Number of functions added to `llvm.compiler.used`");
44
45/// Returns a vector Function that it adds to the Module \p M. When an \p
46/// ScalarFunc is not null, it copies its attributes to the newly created
47/// Function.
48Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
49 const StringRef TLIName,
50 Function *ScalarFunc = nullptr) {
51 Function *TLIFunc = M->getFunction(Name: TLIName);
52 if (!TLIFunc) {
53 TLIFunc =
54 Function::Create(Ty: VectorFTy, Linkage: Function::ExternalLinkage, N: TLIName, M&: *M);
55 if (ScalarFunc)
56 TLIFunc->copyAttributesFrom(Src: ScalarFunc);
57
58 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
59 << TLIName << "` of type `" << *(TLIFunc->getType())
60 << "` to module.\n");
61
62 ++NumTLIFuncDeclAdded;
63 // Add the freshly created function to llvm.compiler.used, similar to as it
64 // is done in InjectTLIMappings.
65 appendToCompilerUsed(M&: *M, Values: {TLIFunc});
66 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
67 << "` to `@llvm.compiler.used`.\n");
68 ++NumFuncUsedAdded;
69 }
70 return TLIFunc;
71}
72
73/// Replace the intrinsic call \p II to \p TLIVecFunc, which is the
74/// corresponding function from the vector library.
75static void replaceWithTLIFunction(IntrinsicInst *II, VFInfo &Info,
76 Function *TLIVecFunc) {
77 IRBuilder<> IRBuilder(II);
78 SmallVector<Value *> Args(II->args());
79 if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
80 auto *MaskTy =
81 VectorType::get(ElementType: Type::getInt1Ty(C&: II->getContext()), EC: Info.Shape.VF);
82 Args.insert(I: Args.begin() + OptMaskpos.value(),
83 Elt: Constant::getAllOnesValue(Ty: MaskTy));
84 }
85
86 // Preserve the operand bundles.
87 SmallVector<OperandBundleDef, 1> OpBundles;
88 II->getOperandBundlesAsDefs(Defs&: OpBundles);
89
90 auto *Replacement = IRBuilder.CreateCall(Callee: TLIVecFunc, Args, OpBundles);
91 II->replaceAllUsesWith(V: Replacement);
92 // Preserve fast math flags for FP math.
93 if (isa<FPMathOperator>(Val: Replacement))
94 Replacement->copyFastMathFlags(I: II);
95}
96
97/// Returns true when successfully replaced \p II, which is a call to a
98/// vectorized intrinsic, with a suitable function taking vector arguments,
99/// based on available mappings in the \p TLI.
100static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
101 IntrinsicInst *II) {
102 assert(II != nullptr && "Intrinsic cannot be null");
103 // At the moment VFABI assumes the return type is always widened unless it is
104 // a void type.
105 auto *VTy = dyn_cast<VectorType>(Val: II->getType());
106 ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(MinVal: 0));
107 // Compute the argument types of the corresponding scalar call and check that
108 // all vector operands match the previously found EC.
109 SmallVector<Type *, 8> ScalarArgTypes;
110 Intrinsic::ID IID = II->getIntrinsicID();
111 for (auto Arg : enumerate(First: II->args())) {
112 auto *ArgTy = Arg.value()->getType();
113 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: Arg.index())) {
114 ScalarArgTypes.push_back(Elt: ArgTy);
115 } else if (auto *VectorArgTy = dyn_cast<VectorType>(Val: ArgTy)) {
116 ScalarArgTypes.push_back(Elt: VectorArgTy->getElementType());
117 // When return type is void, set EC to the first vector argument, and
118 // disallow vector arguments with different ECs.
119 if (EC.isZero())
120 EC = VectorArgTy->getElementCount();
121 else if (EC != VectorArgTy->getElementCount())
122 return false;
123 } else
124 // Exit when it is supposed to be a vector argument but it isn't.
125 return false;
126 }
127
128 // Try to reconstruct the name for the scalar version of the instruction,
129 // using scalar argument types.
130 std::string ScalarName =
131 Intrinsic::isOverloaded(id: IID)
132 ? Intrinsic::getName(Id: IID, Tys: ScalarArgTypes, M: II->getModule())
133 : Intrinsic::getName(id: IID).str();
134
135 // Try to find the mapping for the scalar version of this intrinsic and the
136 // exact vector width of the call operands in the TargetLibraryInfo. First,
137 // check with a non-masked variant, and if that fails try with a masked one.
138 const VecDesc *VD =
139 TLI.getVectorMappingInfo(F: ScalarName, VF: EC, /*Masked*/ false);
140 if (!VD && !(VD = TLI.getVectorMappingInfo(F: ScalarName, VF: EC, /*Masked*/ true)))
141 return false;
142
143 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
144 << "` and vector width " << EC << " to: `"
145 << VD->getVectorFnName() << "`.\n");
146
147 // Replace the call to the intrinsic with a call to the vector library
148 // function.
149 Type *ScalarRetTy = II->getType()->getScalarType();
150 FunctionType *ScalarFTy =
151 FunctionType::get(Result: ScalarRetTy, Params: ScalarArgTypes, /*isVarArg*/ false);
152 const std::string MangledName = VD->getVectorFunctionABIVariantString();
153 auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, FTy: ScalarFTy);
154 if (!OptInfo)
155 return false;
156
157 // There is no guarantee that the vectorized instructions followed the VFABI
158 // specification when being created, this is why we need to add extra check to
159 // make sure that the operands of the vector function obtained via VFABI match
160 // the operands of the original vector instruction.
161 for (auto &VFParam : OptInfo->Shape.Parameters) {
162 if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
163 continue;
164
165 // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
166 // a bug in the VFABI parser.
167 assert(VFParam.ParamPos < II->arg_size() && "ParamPos has invalid range");
168 Type *OrigTy = II->getArgOperand(i: VFParam.ParamPos)->getType();
169 if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
170 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
171 << ". Wrong type at index " << VFParam.ParamPos << ": "
172 << *OrigTy << "\n");
173 return false;
174 }
175 }
176
177 FunctionType *VectorFTy = VFABI::createFunctionType(Info: *OptInfo, ScalarFTy);
178 if (!VectorFTy)
179 return false;
180
181 Function *TLIFunc =
182 getTLIFunction(M: II->getModule(), VectorFTy, TLIName: VD->getVectorFnName(),
183 ScalarFunc: II->getCalledFunction());
184 replaceWithTLIFunction(II, Info&: *OptInfo, TLIVecFunc: TLIFunc);
185 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
186 << "` with call to `" << TLIFunc->getName() << "`.\n");
187 ++NumCallsReplaced;
188 return true;
189}
190
191static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
192 SmallVector<Instruction *> ReplacedCalls;
193 for (auto &I : instructions(F)) {
194 // Process only intrinsic calls that return void or a vector.
195 if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) {
196 if (!II->getType()->isVectorTy() && !II->getType()->isVoidTy())
197 continue;
198
199 if (replaceWithCallToVeclib(TLI, II))
200 ReplacedCalls.push_back(Elt: &I);
201 }
202 }
203 // Erase any intrinsic calls that were replaced with vector library calls.
204 for (auto *I : ReplacedCalls)
205 I->eraseFromParent();
206 return !ReplacedCalls.empty();
207}
208
209////////////////////////////////////////////////////////////////////////////////
210// New pass manager implementation.
211////////////////////////////////////////////////////////////////////////////////
212PreservedAnalyses ReplaceWithVeclib::run(Function &F,
213 FunctionAnalysisManager &AM) {
214 const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(IR&: F);
215 auto Changed = runImpl(TLI, F);
216 if (Changed) {
217 LLVM_DEBUG(dbgs() << "Intrinsic calls replaced with vector libraries: "
218 << NumCallsReplaced << "\n");
219
220 PreservedAnalyses PA;
221 PA.preserveSet<CFGAnalyses>();
222 PA.preserve<TargetLibraryAnalysis>();
223 PA.preserve<ScalarEvolutionAnalysis>();
224 PA.preserve<LoopAccessAnalysis>();
225 PA.preserve<DemandedBitsAnalysis>();
226 PA.preserve<OptimizationRemarkEmitterAnalysis>();
227 return PA;
228 }
229
230 // The pass did not replace any calls, hence it preserves all analyses.
231 return PreservedAnalyses::all();
232}
233
234////////////////////////////////////////////////////////////////////////////////
235// Legacy PM Implementation.
236////////////////////////////////////////////////////////////////////////////////
237bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
238 const TargetLibraryInfo &TLI =
239 getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
240 return runImpl(TLI, F);
241}
242
243void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
244 AU.setPreservesCFG();
245 AU.addRequired<TargetLibraryInfoWrapperPass>();
246 AU.addPreserved<TargetLibraryInfoWrapperPass>();
247 AU.addPreserved<ScalarEvolutionWrapperPass>();
248 AU.addPreserved<AAResultsWrapperPass>();
249 AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
250 AU.addPreserved<GlobalsAAWrapperPass>();
251}
252
253////////////////////////////////////////////////////////////////////////////////
254// Legacy Pass manager initialization
255////////////////////////////////////////////////////////////////////////////////
// Definition of the legacy pass's static identifier (the legacy pass
// infrastructure presumably keys the pass on this variable's address).
char ReplaceWithVeclibLegacy::ID = 0;

// Register the legacy pass under DEBUG_TYPE with a human-readable description,
// declaring its dependency on TargetLibraryInfoWrapperPass (the analysis that
// runOnFunction queries). The two trailing `false` arguments are the
// cfg-only / is-analysis flags of the registration macros.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                      "Replace intrinsics with calls to vector library", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                    "Replace intrinsics with calls to vector library", false,
                    false)
265
266FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
267 return new ReplaceWithVeclibLegacy();
268}
269