1//=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// Replaces calls to LLVM intrinsics with matching calls to functions from a
// vector library (e.g. libmvec, SVML) using the TargetLibraryInfo interface.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/CodeGen/ReplaceWithVeclib.h"
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/Statistic.h"
17#include "llvm/ADT/StringRef.h"
18#include "llvm/Analysis/DemandedBits.h"
19#include "llvm/Analysis/GlobalsModRef.h"
20#include "llvm/Analysis/OptimizationRemarkEmitter.h"
21#include "llvm/Analysis/TargetLibraryInfo.h"
22#include "llvm/Analysis/VectorUtils.h"
23#include "llvm/CodeGen/Passes.h"
24#include "llvm/IR/DerivedTypes.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstIterator.h"
27#include "llvm/IR/IntrinsicInst.h"
28#include "llvm/IR/VFABIDemangler.h"
29#include "llvm/InitializePasses.h"
30#include "llvm/Support/TypeSize.h"
31#include "llvm/Transforms/Utils/ModuleUtils.h"
32
33using namespace llvm;
34
// Tag used to prefix LLVM_DEBUG output in this file and to group the
// STATISTIC counters below.
#define DEBUG_TYPE "replace-with-veclib"

// Incremented once for every intrinsic call that gets rewritten into a call to
// a vector library function.
STATISTIC(NumCallsReplaced,
          "Number of calls to intrinsics that have been replaced.");

// Incremented whenever a new vector library function declaration has to be
// created in the module (existing declarations are reused and not counted).
STATISTIC(NumTLIFuncDeclAdded,
          "Number of vector library function declarations added.");

// Incremented each time a newly created declaration is appended to
// `llvm.compiler.used`; this happens together with NumTLIFuncDeclAdded.
STATISTIC(NumFuncUsedAdded,
          "Number of functions added to `llvm.compiler.used`");
45
46/// Returns a vector Function that it adds to the Module \p M. When an \p
47/// ScalarFunc is not null, it copies its attributes to the newly created
48/// Function.
49Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
50 const StringRef TLIName,
51 std::optional<CallingConv::ID> CC,
52 Function *ScalarFunc = nullptr) {
53 Function *TLIFunc = M->getFunction(Name: TLIName);
54 if (!TLIFunc) {
55 TLIFunc =
56 Function::Create(Ty: VectorFTy, Linkage: Function::ExternalLinkage, N: TLIName, M&: *M);
57 if (ScalarFunc)
58 TLIFunc->copyAttributesFrom(Src: ScalarFunc);
59 if (CC)
60 TLIFunc->setCallingConv(*CC);
61
62 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
63 << TLIName << "` of type `" << *(TLIFunc->getType())
64 << "` to module.\n");
65
66 ++NumTLIFuncDeclAdded;
67 // Add the freshly created function to llvm.compiler.used, similar to as it
68 // is done in InjectTLIMappings.
69 appendToCompilerUsed(M&: *M, Values: {TLIFunc});
70 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
71 << "` to `@llvm.compiler.used`.\n");
72 ++NumFuncUsedAdded;
73 }
74 return TLIFunc;
75}
76
77/// Replace the intrinsic call \p II to \p TLIVecFunc, which is the
78/// corresponding function from the vector library.
79static void replaceWithTLIFunction(IntrinsicInst *II, VFInfo &Info,
80 Function *TLIVecFunc) {
81 IRBuilder<> IRBuilder(II);
82 SmallVector<Value *> Args(II->args());
83 if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
84 auto *MaskTy =
85 VectorType::get(ElementType: Type::getInt1Ty(C&: II->getContext()), EC: Info.Shape.VF);
86 Args.insert(I: Args.begin() + OptMaskpos.value(),
87 Elt: Constant::getAllOnesValue(Ty: MaskTy));
88 }
89
90 // Preserve the operand bundles.
91 SmallVector<OperandBundleDef, 1> OpBundles;
92 II->getOperandBundlesAsDefs(Defs&: OpBundles);
93
94 auto *Replacement = IRBuilder.CreateCall(Callee: TLIVecFunc, Args, OpBundles);
95 II->replaceAllUsesWith(V: Replacement);
96 // Preserve fast math flags for FP math.
97 if (isa<FPMathOperator>(Val: Replacement))
98 Replacement->copyFastMathFlags(I: II);
99 Replacement->setCallingConv(TLIVecFunc->getCallingConv());
100}
101
102/// Returns true when successfully replaced \p II, which is a call to a
103/// vectorized intrinsic, with a suitable function taking vector arguments,
104/// based on available mappings in the \p TLI.
105static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
106 IntrinsicInst *II) {
107 assert(II != nullptr && "Intrinsic cannot be null");
108 Intrinsic::ID IID = II->getIntrinsicID();
109 Type *RetTy = II->getType();
110 Type *ScalarRetTy = RetTy->getScalarType();
111 // At the moment VFABI assumes the return type is always widened unless it is
112 // a void type.
113 auto *VTy = dyn_cast<VectorType>(Val: RetTy);
114 ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(MinVal: 0));
115
116 // OloadTys collects types used in scalar intrinsic overload name.
117 SmallVector<Type *, 3> OloadTys;
118 if (!RetTy->isVoidTy() &&
119 isVectorIntrinsicWithOverloadTypeAtArg(ID: IID, OpdIdx: -1, /*TTI=*/nullptr))
120 OloadTys.push_back(Elt: ScalarRetTy);
121
122 // Compute the argument types of the corresponding scalar call and check that
123 // all vector operands match the previously found EC.
124 SmallVector<Type *, 8> ScalarArgTypes;
125 for (auto Arg : enumerate(First: II->args())) {
126 auto *ArgTy = Arg.value()->getType();
127 bool IsOloadTy = isVectorIntrinsicWithOverloadTypeAtArg(ID: IID, OpdIdx: Arg.index(),
128 /*TTI=*/nullptr);
129 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: Arg.index(), /*TTI=*/nullptr)) {
130 ScalarArgTypes.push_back(Elt: ArgTy);
131 if (IsOloadTy)
132 OloadTys.push_back(Elt: ArgTy);
133 } else if (auto *VectorArgTy = dyn_cast<VectorType>(Val: ArgTy)) {
134 auto *ScalarArgTy = VectorArgTy->getElementType();
135 ScalarArgTypes.push_back(Elt: ScalarArgTy);
136 if (IsOloadTy)
137 OloadTys.push_back(Elt: ScalarArgTy);
138 // When return type is void, set EC to the first vector argument, and
139 // disallow vector arguments with different ECs.
140 if (EC.isZero())
141 EC = VectorArgTy->getElementCount();
142 else if (EC != VectorArgTy->getElementCount())
143 return false;
144 } else
145 // Exit when it is supposed to be a vector argument but it isn't.
146 return false;
147 }
148
149 // Try to reconstruct the name for the scalar version of the instruction,
150 // using scalar argument types.
151 std::string ScalarName =
152 Intrinsic::isOverloaded(id: IID)
153 ? Intrinsic::getName(Id: IID, Tys: OloadTys, M: II->getModule())
154 : Intrinsic::getName(id: IID).str();
155
156 // Try to find the mapping for the scalar version of this intrinsic and the
157 // exact vector width of the call operands in the TargetLibraryInfo. First,
158 // check with a non-masked variant, and if that fails try with a masked one.
159 const VecDesc *VD =
160 TLI.getVectorMappingInfo(F: ScalarName, VF: EC, /*Masked*/ false);
161 if (!VD && !(VD = TLI.getVectorMappingInfo(F: ScalarName, VF: EC, /*Masked*/ true)))
162 return false;
163
164 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
165 << "` and vector width " << EC << " to: `"
166 << VD->getVectorFnName() << "`.\n");
167
168 // Replace the call to the intrinsic with a call to the vector library
169 // function.
170 FunctionType *ScalarFTy =
171 FunctionType::get(Result: ScalarRetTy, Params: ScalarArgTypes, /*isVarArg*/ false);
172 const std::string MangledName = VD->getVectorFunctionABIVariantString();
173 auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, FTy: ScalarFTy);
174 if (!OptInfo)
175 return false;
176
177 // There is no guarantee that the vectorized instructions followed the VFABI
178 // specification when being created, this is why we need to add extra check to
179 // make sure that the operands of the vector function obtained via VFABI match
180 // the operands of the original vector instruction.
181 for (auto &VFParam : OptInfo->Shape.Parameters) {
182 if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
183 continue;
184
185 // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
186 // a bug in the VFABI parser.
187 assert(VFParam.ParamPos < II->arg_size() && "ParamPos has invalid range");
188 Type *OrigTy = II->getArgOperand(i: VFParam.ParamPos)->getType();
189 if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
190 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
191 << ". Wrong type at index " << VFParam.ParamPos << ": "
192 << *OrigTy << "\n");
193 return false;
194 }
195 }
196
197 FunctionType *VectorFTy = VFABI::createFunctionType(Info: *OptInfo, ScalarFTy);
198 if (!VectorFTy)
199 return false;
200
201 Function *TLIFunc =
202 getTLIFunction(M: II->getModule(), VectorFTy, TLIName: VD->getVectorFnName(),
203 CC: VD->getCallingConv(), ScalarFunc: II->getCalledFunction());
204 replaceWithTLIFunction(II, Info&: *OptInfo, TLIVecFunc: TLIFunc);
205 LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
206 << "` with call to `" << TLIFunc->getName() << "`.\n");
207 ++NumCallsReplaced;
208 return true;
209}
210
211static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
212 SmallVector<Instruction *> ReplacedCalls;
213 for (auto &I : instructions(F)) {
214 // Process only intrinsic calls that return void or a vector.
215 if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) {
216 if (II->getIntrinsicID() == Intrinsic::not_intrinsic)
217 continue;
218 if (!II->getType()->isVectorTy() && !II->getType()->isVoidTy())
219 continue;
220
221 if (replaceWithCallToVeclib(TLI, II))
222 ReplacedCalls.push_back(Elt: &I);
223 }
224 }
225 // Erase any intrinsic calls that were replaced with vector library calls.
226 for (auto *I : ReplacedCalls)
227 I->eraseFromParent();
228 return !ReplacedCalls.empty();
229}
230
231////////////////////////////////////////////////////////////////////////////////
232// New pass manager implementation.
233////////////////////////////////////////////////////////////////////////////////
234PreservedAnalyses ReplaceWithVeclib::run(Function &F,
235 FunctionAnalysisManager &AM) {
236 const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(IR&: F);
237 auto Changed = runImpl(TLI, F);
238 if (Changed) {
239 LLVM_DEBUG(dbgs() << "Intrinsic calls replaced with vector libraries: "
240 << NumCallsReplaced << "\n");
241
242 PreservedAnalyses PA;
243 PA.preserveSet<CFGAnalyses>();
244 PA.preserve<TargetLibraryAnalysis>();
245 PA.preserve<ScalarEvolutionAnalysis>();
246 PA.preserve<LoopAccessAnalysis>();
247 PA.preserve<DemandedBitsAnalysis>();
248 PA.preserve<OptimizationRemarkEmitterAnalysis>();
249 return PA;
250 }
251
252 // The pass did not replace any calls, hence it preserves all analyses.
253 return PreservedAnalyses::all();
254}
255
256////////////////////////////////////////////////////////////////////////////////
257// Legacy PM Implementation.
258////////////////////////////////////////////////////////////////////////////////
259bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
260 const TargetLibraryInfo &TLI =
261 getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
262 return runImpl(TLI, F);
263}
264
265void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
266 AU.setPreservesCFG();
267 AU.addRequired<TargetLibraryInfoWrapperPass>();
268 AU.addPreserved<TargetLibraryInfoWrapperPass>();
269 AU.addPreserved<ScalarEvolutionWrapperPass>();
270 AU.addPreserved<AAResultsWrapperPass>();
271 AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
272 AU.addPreserved<GlobalsAAWrapperPass>();
273}
274
275////////////////////////////////////////////////////////////////////////////////
276// Legacy Pass manager initialization
277////////////////////////////////////////////////////////////////////////////////
// Identifier of the legacy pass; the pass infrastructure identifies passes by
// the address of this variable, so its initial value does not matter.
char ReplaceWithVeclibLegacy::ID = 0;

// Register the legacy pass under DEBUG_TYPE ("replace-with-veclib") with its
// human-readable description. The two trailing `false` arguments state that it
// is not a CFG-only pass and not an analysis. TargetLibraryInfoWrapperPass is
// registered as a dependency because runOnFunction queries it.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                      "Replace intrinsics with calls to vector library", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                    "Replace intrinsics with calls to vector library", false,
                    false)
287
288FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
289 return new ReplaceWithVeclibLegacy();
290}
291