1 | //===---- EVLIndVarSimplify.cpp - Optimize vectorized loops w/ EVL IV------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This pass rewrites a vectorized loop that is controlled by a canonical IV
// so that it is controlled by the EVL-based IV instead, provided the loop was
// tail-folded with predicated EVL.
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Transforms/Vectorize/EVLIndVarSimplify.h" |
15 | #include "llvm/ADT/Statistic.h" |
16 | #include "llvm/Analysis/IVDescriptors.h" |
17 | #include "llvm/Analysis/LoopInfo.h" |
18 | #include "llvm/Analysis/LoopPass.h" |
19 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
20 | #include "llvm/Analysis/ScalarEvolution.h" |
21 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
22 | #include "llvm/Analysis/ValueTracking.h" |
23 | #include "llvm/IR/IRBuilder.h" |
24 | #include "llvm/IR/PatternMatch.h" |
25 | #include "llvm/Support/CommandLine.h" |
26 | #include "llvm/Support/Debug.h" |
27 | #include "llvm/Support/MathExtras.h" |
28 | #include "llvm/Support/raw_ostream.h" |
29 | #include "llvm/Transforms/Scalar/LoopPassManager.h" |
30 | #include "llvm/Transforms/Utils/Local.h" |
31 | |
32 | #define DEBUG_TYPE "evl-iv-simplify" |
33 | |
34 | using namespace llvm; |
35 | |
36 | STATISTIC(NumEliminatedCanonicalIV, "Number of canonical IVs we eliminated" ); |
37 | |
38 | static cl::opt<bool> EnableEVLIndVarSimplify( |
39 | "enable-evl-indvar-simplify" , |
40 | cl::desc("Enable EVL-based induction variable simplify Pass" ), cl::Hidden, |
41 | cl::init(Val: true)); |
42 | |
43 | namespace { |
44 | struct EVLIndVarSimplifyImpl { |
45 | ScalarEvolution &SE; |
46 | OptimizationRemarkEmitter *ORE = nullptr; |
47 | |
48 | EVLIndVarSimplifyImpl(LoopStandardAnalysisResults &LAR, |
49 | OptimizationRemarkEmitter *ORE) |
50 | : SE(LAR.SE), ORE(ORE) {} |
51 | |
52 | /// Returns true if modify the loop. |
53 | bool run(Loop &L); |
54 | }; |
55 | } // anonymous namespace |
56 | |
57 | /// Returns the constant part of vectorization factor from the induction |
58 | /// variable's step value SCEV expression. |
59 | static uint32_t getVFFromIndVar(const SCEV *Step, const Function &F) { |
60 | if (!Step) |
61 | return 0U; |
62 | |
63 | // Looking for loops with IV step value in the form of `(<constant VF> x |
64 | // vscale)`. |
65 | if (const auto *Mul = dyn_cast<SCEVMulExpr>(Val: Step)) { |
66 | if (Mul->getNumOperands() == 2) { |
67 | const SCEV *LHS = Mul->getOperand(i: 0); |
68 | const SCEV *RHS = Mul->getOperand(i: 1); |
69 | if (const auto *Const = dyn_cast<SCEVConstant>(Val: LHS); |
70 | Const && isa<SCEVVScale>(Val: RHS)) { |
71 | uint64_t V = Const->getAPInt().getLimitedValue(); |
72 | if (llvm::isUInt<32>(x: V)) |
73 | return V; |
74 | } |
75 | } |
76 | } |
77 | |
78 | // If not, see if the vscale_range of the parent function is a fixed value, |
79 | // which makes the step value to be replaced by a constant. |
80 | if (F.hasFnAttribute(Kind: Attribute::VScaleRange)) |
81 | if (const auto *ConstStep = dyn_cast<SCEVConstant>(Val: Step)) { |
82 | APInt V = ConstStep->getAPInt().abs(); |
83 | ConstantRange CR = llvm::getVScaleRange(F: &F, BitWidth: 64); |
84 | if (const APInt *Fixed = CR.getSingleElement()) { |
85 | V = V.zextOrTrunc(width: Fixed->getBitWidth()); |
86 | uint64_t VF = V.udiv(RHS: *Fixed).getLimitedValue(); |
87 | if (VF && llvm::isUInt<32>(x: VF) && |
88 | // Make sure step is divisible by vscale. |
89 | V.urem(RHS: *Fixed).isZero()) |
90 | return VF; |
91 | } |
92 | } |
93 | |
94 | return 0U; |
95 | } |
96 | |
97 | bool EVLIndVarSimplifyImpl::run(Loop &L) { |
98 | if (!EnableEVLIndVarSimplify) |
99 | return false; |
100 | |
101 | if (!getBooleanLoopAttribute(TheLoop: &L, Name: "llvm.loop.isvectorized" )) |
102 | return false; |
103 | const MDOperand *EVLMD = |
104 | findStringMetadataForLoop(TheLoop: &L, Name: "llvm.loop.isvectorized.tailfoldingstyle" ) |
105 | .value_or(u: nullptr); |
106 | if (!EVLMD || !EVLMD->equalsStr(Str: "evl" )) |
107 | return false; |
108 | |
109 | BasicBlock *LatchBlock = L.getLoopLatch(); |
110 | ICmpInst *OrigLatchCmp = L.getLatchCmpInst(); |
111 | if (!LatchBlock || !OrigLatchCmp) |
112 | return false; |
113 | |
114 | InductionDescriptor IVD; |
115 | PHINode *IndVar = L.getInductionVariable(SE); |
116 | if (!IndVar || !L.getInductionDescriptor(SE, IndDesc&: IVD)) { |
117 | const char *Reason = (IndVar ? "induction descriptor is not available" |
118 | : "cannot recognize induction variable" ); |
119 | LLVM_DEBUG(dbgs() << "Cannot retrieve IV from loop " << L.getName() |
120 | << " because" << Reason << "\n" ); |
121 | if (ORE) { |
122 | ORE->emit(RemarkBuilder: [&]() { |
123 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar" , |
124 | L.getStartLoc(), L.getHeader()) |
125 | << "Cannot retrieve IV because " << ore::NV("Reason" , Reason); |
126 | }); |
127 | } |
128 | return false; |
129 | } |
130 | |
131 | BasicBlock *InitBlock, *BackEdgeBlock; |
132 | if (!L.getIncomingAndBackEdge(Incoming&: InitBlock, Backedge&: BackEdgeBlock)) { |
133 | LLVM_DEBUG(dbgs() << "Expect unique incoming and backedge in " |
134 | << L.getName() << "\n" ); |
135 | if (ORE) { |
136 | ORE->emit(RemarkBuilder: [&]() { |
137 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure" , |
138 | L.getStartLoc(), L.getHeader()) |
139 | << "Does not have a unique incoming and backedge" ; |
140 | }); |
141 | } |
142 | return false; |
143 | } |
144 | |
145 | // Retrieve the loop bounds. |
146 | std::optional<Loop::LoopBounds> Bounds = L.getBounds(SE); |
147 | if (!Bounds) { |
148 | LLVM_DEBUG(dbgs() << "Could not obtain the bounds for loop " << L.getName() |
149 | << "\n" ); |
150 | if (ORE) { |
151 | ORE->emit(RemarkBuilder: [&]() { |
152 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedLoopStructure" , |
153 | L.getStartLoc(), L.getHeader()) |
154 | << "Could not obtain the loop bounds" ; |
155 | }); |
156 | } |
157 | return false; |
158 | } |
159 | Value *CanonicalIVInit = &Bounds->getInitialIVValue(); |
160 | Value *CanonicalIVFinal = &Bounds->getFinalIVValue(); |
161 | |
162 | const SCEV *StepV = IVD.getStep(); |
163 | uint32_t VF = getVFFromIndVar(Step: StepV, F: *L.getHeader()->getParent()); |
164 | if (!VF) { |
165 | LLVM_DEBUG(dbgs() << "Could not infer VF from IndVar step '" << *StepV |
166 | << "'\n" ); |
167 | if (ORE) { |
168 | ORE->emit(RemarkBuilder: [&]() { |
169 | return OptimizationRemarkMissed(DEBUG_TYPE, "UnrecognizedIndVar" , |
170 | L.getStartLoc(), L.getHeader()) |
171 | << "Could not infer VF from IndVar step " |
172 | << ore::NV("Step" , StepV); |
173 | }); |
174 | } |
175 | return false; |
176 | } |
177 | LLVM_DEBUG(dbgs() << "Using VF=" << VF << " for loop " << L.getName() |
178 | << "\n" ); |
179 | |
180 | // Try to find the EVL-based induction variable. |
181 | using namespace PatternMatch; |
182 | BasicBlock *BB = IndVar->getParent(); |
183 | |
184 | Value *EVLIndVar = nullptr; |
185 | Value *RemTC = nullptr; |
186 | Value *TC = nullptr; |
187 | auto IntrinsicMatch = m_Intrinsic<Intrinsic::experimental_get_vector_length>( |
188 | Op0: m_Value(V&: RemTC), Op1: m_SpecificInt(V: VF), |
189 | /*Scalable=*/Op2: m_SpecificInt(V: 1)); |
190 | for (PHINode &PN : BB->phis()) { |
191 | if (&PN == IndVar) |
192 | continue; |
193 | |
194 | // Check 1: it has to contain both incoming (init) & backedge blocks |
195 | // from IndVar. |
196 | if (PN.getBasicBlockIndex(BB: InitBlock) < 0 || |
197 | PN.getBasicBlockIndex(BB: BackEdgeBlock) < 0) |
198 | continue; |
199 | // Check 2: EVL index is always increasing, thus its inital value has to be |
200 | // equal to either the initial IV value (when the canonical IV is also |
201 | // increasing) or the last IV value (when canonical IV is decreasing). |
202 | Value *Init = PN.getIncomingValueForBlock(BB: InitBlock); |
203 | using Direction = Loop::LoopBounds::Direction; |
204 | switch (Bounds->getDirection()) { |
205 | case Direction::Increasing: |
206 | if (Init != CanonicalIVInit) |
207 | continue; |
208 | break; |
209 | case Direction::Decreasing: |
210 | if (Init != CanonicalIVFinal) |
211 | continue; |
212 | break; |
213 | case Direction::Unknown: |
214 | // To be more permissive and see if either the initial or final IV value |
215 | // matches PN's init value. |
216 | if (Init != CanonicalIVInit && Init != CanonicalIVFinal) |
217 | continue; |
218 | break; |
219 | } |
220 | Value *RecValue = PN.getIncomingValueForBlock(BB: BackEdgeBlock); |
221 | assert(RecValue && "expect recurrent IndVar value" ); |
222 | |
223 | LLVM_DEBUG(dbgs() << "Found candidate PN of EVL-based IndVar: " << PN |
224 | << "\n" ); |
225 | |
226 | // Check 3: Pattern match to find the EVL-based index and total trip count |
227 | // (TC). |
228 | if (match(V: RecValue, |
229 | P: m_c_Add(L: m_ZExtOrSelf(Op: IntrinsicMatch), R: m_Specific(V: &PN))) && |
230 | match(V: RemTC, P: m_Sub(L: m_Value(V&: TC), R: m_Specific(V: &PN)))) { |
231 | EVLIndVar = RecValue; |
232 | break; |
233 | } |
234 | } |
235 | |
236 | if (!EVLIndVar || !TC) |
237 | return false; |
238 | |
239 | LLVM_DEBUG(dbgs() << "Using " << *EVLIndVar << " for EVL-based IndVar\n" ); |
240 | if (ORE) { |
241 | ORE->emit(RemarkBuilder: [&]() { |
242 | DebugLoc DL; |
243 | BasicBlock *Region = nullptr; |
244 | if (auto *I = dyn_cast<Instruction>(Val: EVLIndVar)) { |
245 | DL = I->getDebugLoc(); |
246 | Region = I->getParent(); |
247 | } else { |
248 | DL = L.getStartLoc(); |
249 | Region = L.getHeader(); |
250 | } |
251 | return OptimizationRemark(DEBUG_TYPE, "UseEVLIndVar" , DL, Region) |
252 | << "Using " << ore::NV("EVLIndVar" , EVLIndVar) |
253 | << " for EVL-based IndVar" ; |
254 | }); |
255 | } |
256 | |
257 | // Create an EVL-based comparison and replace the branch to use it as |
258 | // predicate. |
259 | |
260 | // Loop::getLatchCmpInst check at the beginning of this function has ensured |
261 | // that latch block ends in a conditional branch. |
262 | auto *LatchBranch = cast<BranchInst>(Val: LatchBlock->getTerminator()); |
263 | assert(LatchBranch->isConditional() && |
264 | "expect the loop latch to be ended with a conditional branch" ); |
265 | ICmpInst::Predicate Pred; |
266 | if (LatchBranch->getSuccessor(i: 0) == L.getHeader()) |
267 | Pred = ICmpInst::ICMP_NE; |
268 | else |
269 | Pred = ICmpInst::ICMP_EQ; |
270 | |
271 | IRBuilder<> Builder(OrigLatchCmp); |
272 | auto *NewLatchCmp = Builder.CreateICmp(P: Pred, LHS: EVLIndVar, RHS: TC); |
273 | OrigLatchCmp->replaceAllUsesWith(V: NewLatchCmp); |
274 | |
275 | // llvm::RecursivelyDeleteDeadPHINode only deletes cycles whose values are |
276 | // not used outside the cycles. However, in this case the now-RAUW-ed |
277 | // OrigLatchCmp will be considered a use outside the cycle while in reality |
278 | // it's practically dead. Thus we need to remove it before calling |
279 | // RecursivelyDeleteDeadPHINode. |
280 | (void)RecursivelyDeleteTriviallyDeadInstructions(V: OrigLatchCmp); |
281 | if (llvm::RecursivelyDeleteDeadPHINode(PN: IndVar)) |
282 | LLVM_DEBUG(dbgs() << "Removed original IndVar\n" ); |
283 | |
284 | ++NumEliminatedCanonicalIV; |
285 | |
286 | return true; |
287 | } |
288 | |
289 | PreservedAnalyses EVLIndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &LAM, |
290 | LoopStandardAnalysisResults &AR, |
291 | LPMUpdater &U) { |
292 | Function &F = *L.getHeader()->getParent(); |
293 | auto &FAMProxy = LAM.getResult<FunctionAnalysisManagerLoopProxy>(IR&: L, ExtraArgs&: AR); |
294 | OptimizationRemarkEmitter *ORE = |
295 | FAMProxy.getCachedResult<OptimizationRemarkEmitterAnalysis>(IR&: F); |
296 | |
297 | if (EVLIndVarSimplifyImpl(AR, ORE).run(L)) |
298 | return PreservedAnalyses::allInSet<CFGAnalyses>(); |
299 | return PreservedAnalyses::all(); |
300 | } |
301 | |