1 | //===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | //===----------------------------------------------------------------------===// |
9 | |
10 | #include "llvm/Transforms/Scalar/LoopTermFold.h" |
11 | #include "llvm/ADT/Statistic.h" |
12 | #include "llvm/Analysis/LoopAnalysisManager.h" |
13 | #include "llvm/Analysis/LoopInfo.h" |
14 | #include "llvm/Analysis/LoopPass.h" |
15 | #include "llvm/Analysis/MemorySSA.h" |
16 | #include "llvm/Analysis/MemorySSAUpdater.h" |
17 | #include "llvm/Analysis/ScalarEvolution.h" |
18 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
19 | #include "llvm/Analysis/TargetLibraryInfo.h" |
20 | #include "llvm/Analysis/TargetTransformInfo.h" |
21 | #include "llvm/Analysis/ValueTracking.h" |
22 | #include "llvm/Config/llvm-config.h" |
23 | #include "llvm/IR/BasicBlock.h" |
24 | #include "llvm/IR/Dominators.h" |
25 | #include "llvm/IR/IRBuilder.h" |
26 | #include "llvm/IR/InstrTypes.h" |
27 | #include "llvm/IR/Instruction.h" |
28 | #include "llvm/IR/Instructions.h" |
29 | #include "llvm/IR/Type.h" |
30 | #include "llvm/IR/Value.h" |
31 | #include "llvm/InitializePasses.h" |
32 | #include "llvm/Pass.h" |
33 | #include "llvm/Support/Debug.h" |
34 | #include "llvm/Support/raw_ostream.h" |
35 | #include "llvm/Transforms/Scalar.h" |
36 | #include "llvm/Transforms/Utils.h" |
37 | #include "llvm/Transforms/Utils/BasicBlockUtils.h" |
38 | #include "llvm/Transforms/Utils/Local.h" |
39 | #include "llvm/Transforms/Utils/LoopUtils.h" |
40 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
41 | #include <cassert> |
42 | #include <optional> |
43 | |
44 | using namespace llvm; |
45 | |
46 | #define DEBUG_TYPE "loop-term-fold" |
47 | |
48 | STATISTIC(NumTermFold, |
49 | "Number of terminating condition fold recognized and performed" ); |
50 | |
51 | static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>> |
52 | canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, |
53 | const LoopInfo &LI, const TargetTransformInfo &TTI) { |
54 | if (!L->isInnermost()) { |
55 | LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n" ); |
56 | return std::nullopt; |
57 | } |
58 | // Only inspect on simple loop structure |
59 | if (!L->isLoopSimplifyForm()) { |
60 | LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n" ); |
61 | return std::nullopt; |
62 | } |
63 | |
64 | if (!SE.hasLoopInvariantBackedgeTakenCount(L)) { |
65 | LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n" ); |
66 | return std::nullopt; |
67 | } |
68 | |
69 | BasicBlock *LoopLatch = L->getLoopLatch(); |
70 | BranchInst *BI = dyn_cast<BranchInst>(Val: LoopLatch->getTerminator()); |
71 | if (!BI || BI->isUnconditional()) |
72 | return std::nullopt; |
73 | auto *TermCond = dyn_cast<ICmpInst>(Val: BI->getCondition()); |
74 | if (!TermCond) { |
75 | LLVM_DEBUG( |
76 | dbgs() << "Cannot fold on branching condition that is not an ICmpInst" ); |
77 | return std::nullopt; |
78 | } |
79 | if (!TermCond->hasOneUse()) { |
80 | LLVM_DEBUG( |
81 | dbgs() |
82 | << "Cannot replace terminating condition with more than one use\n" ); |
83 | return std::nullopt; |
84 | } |
85 | |
86 | BinaryOperator *LHS = dyn_cast<BinaryOperator>(Val: TermCond->getOperand(i_nocapture: 0)); |
87 | Value *RHS = TermCond->getOperand(i_nocapture: 1); |
88 | if (!LHS || !L->isLoopInvariant(V: RHS)) |
89 | // We could pattern match the inverse form of the icmp, but that is |
90 | // non-canonical, and this pass is running *very* late in the pipeline. |
91 | return std::nullopt; |
92 | |
93 | // Find the IV used by the current exit condition. |
94 | PHINode *ToFold; |
95 | Value *ToFoldStart, *ToFoldStep; |
96 | if (!matchSimpleRecurrence(I: LHS, P&: ToFold, Start&: ToFoldStart, Step&: ToFoldStep)) |
97 | return std::nullopt; |
98 | |
99 | // Ensure the simple recurrence is a part of the current loop. |
100 | if (ToFold->getParent() != L->getHeader()) |
101 | return std::nullopt; |
102 | |
103 | // If that IV isn't dead after we rewrite the exit condition in terms of |
104 | // another IV, there's no point in doing the transform. |
105 | if (!isAlmostDeadIV(IV: ToFold, LatchBlock: LoopLatch, Cond: TermCond)) |
106 | return std::nullopt; |
107 | |
108 | // Inserting instructions in the preheader has a runtime cost, scale |
109 | // the allowed cost with the loops trip count as best we can. |
110 | const unsigned ExpansionBudget = [&]() { |
111 | unsigned Budget = 2 * SCEVCheapExpansionBudget; |
112 | if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L)) |
113 | return std::min(a: Budget, b: SmallTC); |
114 | if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L)) |
115 | return std::min(a: Budget, b: *SmallTC); |
116 | // Unknown trip count, assume long running by default. |
117 | return Budget; |
118 | }(); |
119 | |
120 | const SCEV *BECount = SE.getBackedgeTakenCount(L); |
121 | const DataLayout &DL = L->getHeader()->getDataLayout(); |
122 | SCEVExpander Expander(SE, DL, "lsr_fold_term_cond" ); |
123 | |
124 | PHINode *ToHelpFold = nullptr; |
125 | const SCEV *TermValueS = nullptr; |
126 | bool MustDropPoison = false; |
127 | auto InsertPt = L->getLoopPreheader()->getTerminator(); |
128 | for (PHINode &PN : L->getHeader()->phis()) { |
129 | if (ToFold == &PN) |
130 | continue; |
131 | |
132 | if (!SE.isSCEVable(Ty: PN.getType())) { |
133 | LLVM_DEBUG(dbgs() << "IV of phi '" << PN |
134 | << "' is not SCEV-able, not qualified for the " |
135 | "terminating condition folding.\n" ); |
136 | continue; |
137 | } |
138 | const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Val: SE.getSCEV(V: &PN)); |
139 | // Only speculate on affine AddRec |
140 | if (!AddRec || !AddRec->isAffine()) { |
141 | LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN |
142 | << "' is not an affine add recursion, not qualified " |
143 | "for the terminating condition folding.\n" ); |
144 | continue; |
145 | } |
146 | |
147 | // Check that we can compute the value of AddRec on the exiting iteration |
148 | // without soundness problems. evaluateAtIteration internally needs |
149 | // to multiply the stride of the iteration number - which may wrap around. |
150 | // The issue here is subtle because computing the result accounting for |
151 | // wrap is insufficient. In order to use the result in an exit test, we |
152 | // must also know that AddRec doesn't take the same value on any previous |
153 | // iteration. The simplest case to consider is a candidate IV which is |
154 | // narrower than the trip count (and thus original IV), but this can |
155 | // also happen due to non-unit strides on the candidate IVs. |
156 | if (!AddRec->hasNoSelfWrap() || |
157 | !SE.isKnownNonZero(S: AddRec->getStepRecurrence(SE))) |
158 | continue; |
159 | |
160 | const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); |
161 | const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(It: BECount, SE); |
162 | if (!Expander.isSafeToExpand(S: TermValueSLocal)) { |
163 | LLVM_DEBUG( |
164 | dbgs() << "Is not safe to expand terminating value for phi node" << PN |
165 | << "\n" ); |
166 | continue; |
167 | } |
168 | |
169 | if (Expander.isHighCostExpansion(Exprs: TermValueSLocal, L, Budget: ExpansionBudget, TTI: &TTI, |
170 | At: InsertPt)) { |
171 | LLVM_DEBUG( |
172 | dbgs() << "Is too expensive to expand terminating value for phi node" |
173 | << PN << "\n" ); |
174 | continue; |
175 | } |
176 | |
177 | // The candidate IV may have been otherwise dead and poison from the |
178 | // very first iteration. If we can't disprove that, we can't use the IV. |
179 | if (!mustExecuteUBIfPoisonOnPathTo(Root: &PN, OnPathTo: LoopLatch->getTerminator(), DT: &DT)) { |
180 | LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n" ); |
181 | continue; |
182 | } |
183 | |
184 | // The candidate IV may become poison on the last iteration. If this |
185 | // value is not branched on, this is a well defined program. We're |
186 | // about to add a new use to this IV, and we have to ensure we don't |
187 | // insert UB which didn't previously exist. |
188 | bool MustDropPoisonLocal = false; |
189 | Instruction *PostIncV = |
190 | cast<Instruction>(Val: PN.getIncomingValueForBlock(BB: LoopLatch)); |
191 | if (!mustExecuteUBIfPoisonOnPathTo(Root: PostIncV, OnPathTo: LoopLatch->getTerminator(), |
192 | DT: &DT)) { |
193 | LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN |
194 | << "\n" ); |
195 | |
196 | // If this is a complex recurrance with multiple instructions computing |
197 | // the backedge value, we might need to strip poison flags from all of |
198 | // them. |
199 | if (PostIncV->getOperand(i: 0) != &PN) |
200 | continue; |
201 | |
202 | // In order to perform the transform, we need to drop the poison |
203 | // generating flags on this instruction (if any). |
204 | MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); |
205 | } |
206 | |
207 | // We pick the last legal alternate IV. We could expore choosing an optimal |
208 | // alternate IV if we had a decent heuristic to do so. |
209 | ToHelpFold = &PN; |
210 | TermValueS = TermValueSLocal; |
211 | MustDropPoison = MustDropPoisonLocal; |
212 | } |
213 | |
214 | LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() |
215 | << "Cannot find other AddRec IV to help folding\n" ;); |
216 | |
217 | LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs() |
218 | << "\nFound loop that can fold terminating condition\n" |
219 | << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n" |
220 | << " TermCond: " << *TermCond << "\n" |
221 | << " BrandInst: " << *BI << "\n" |
222 | << " ToFold: " << *ToFold << "\n" |
223 | << " ToHelpFold: " << *ToHelpFold << "\n" ); |
224 | |
225 | if (!ToFold || !ToHelpFold) |
226 | return std::nullopt; |
227 | return std::make_tuple(args&: ToFold, args&: ToHelpFold, args&: TermValueS, args&: MustDropPoison); |
228 | } |
229 | |
230 | static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT, |
231 | LoopInfo &LI, const TargetTransformInfo &TTI, |
232 | TargetLibraryInfo &TLI, MemorySSA *MSSA) { |
233 | std::unique_ptr<MemorySSAUpdater> MSSAU; |
234 | if (MSSA) |
235 | MSSAU = std::make_unique<MemorySSAUpdater>(args&: MSSA); |
236 | |
237 | auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI); |
238 | if (!Opt) |
239 | return false; |
240 | |
241 | auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; |
242 | |
243 | NumTermFold++; |
244 | |
245 | BasicBlock * = L->getLoopPreheader(); |
246 | BasicBlock *LoopLatch = L->getLoopLatch(); |
247 | |
248 | (void)ToFold; |
249 | LLVM_DEBUG(dbgs() << "To fold phi-node:\n" |
250 | << *ToFold << "\n" |
251 | << "New term-cond phi-node:\n" |
252 | << *ToHelpFold << "\n" ); |
253 | |
254 | Value *StartValue = ToHelpFold->getIncomingValueForBlock(BB: LoopPreheader); |
255 | (void)StartValue; |
256 | Value *LoopValue = ToHelpFold->getIncomingValueForBlock(BB: LoopLatch); |
257 | |
258 | // See comment in canFoldTermCondOfLoop on why this is sufficient. |
259 | if (MustDrop) |
260 | cast<Instruction>(Val: LoopValue)->dropPoisonGeneratingFlags(); |
261 | |
262 | // SCEVExpander for both use in preheader and latch |
263 | const DataLayout &DL = L->getHeader()->getDataLayout(); |
264 | SCEVExpander Expander(SE, DL, "lsr_fold_term_cond" ); |
265 | |
266 | assert(Expander.isSafeToExpand(TermValueS) && |
267 | "Terminating value was checked safe in canFoldTerminatingCondition" ); |
268 | |
269 | // Create new terminating value at loop preheader |
270 | Value *TermValue = Expander.expandCodeFor(SH: TermValueS, Ty: ToHelpFold->getType(), |
271 | I: LoopPreheader->getTerminator()); |
272 | |
273 | LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n" |
274 | << *StartValue << "\n" |
275 | << "Terminating value of new term-cond phi-node:\n" |
276 | << *TermValue << "\n" ); |
277 | |
278 | // Create new terminating condition at loop latch |
279 | BranchInst *BI = cast<BranchInst>(Val: LoopLatch->getTerminator()); |
280 | ICmpInst *OldTermCond = cast<ICmpInst>(Val: BI->getCondition()); |
281 | IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); |
282 | Value *NewTermCond = |
283 | LatchBuilder.CreateICmp(P: CmpInst::ICMP_EQ, LHS: LoopValue, RHS: TermValue, |
284 | Name: "lsr_fold_term_cond.replaced_term_cond" ); |
285 | // Swap successors to exit loop body if IV equals to new TermValue |
286 | if (BI->getSuccessor(i: 0) == L->getHeader()) |
287 | BI->swapSuccessors(); |
288 | |
289 | LLVM_DEBUG(dbgs() << "Old term-cond:\n" |
290 | << *OldTermCond << "\n" |
291 | << "New term-cond:\n" |
292 | << *NewTermCond << "\n" ); |
293 | |
294 | BI->setCondition(NewTermCond); |
295 | |
296 | Expander.clear(); |
297 | OldTermCond->eraseFromParent(); |
298 | DeleteDeadPHIs(BB: L->getHeader(), TLI: &TLI, MSSAU: MSSAU.get()); |
299 | return true; |
300 | } |
301 | |
302 | namespace { |
303 | |
304 | class LoopTermFold : public LoopPass { |
305 | public: |
306 | static char ID; // Pass ID, replacement for typeid |
307 | |
308 | LoopTermFold(); |
309 | |
310 | private: |
311 | bool runOnLoop(Loop *L, LPPassManager &LPM) override; |
312 | void getAnalysisUsage(AnalysisUsage &AU) const override; |
313 | }; |
314 | |
315 | } // end anonymous namespace |
316 | |
317 | LoopTermFold::LoopTermFold() : LoopPass(ID) { |
318 | initializeLoopTermFoldPass(*PassRegistry::getPassRegistry()); |
319 | } |
320 | |
321 | void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const { |
322 | AU.addRequired<LoopInfoWrapperPass>(); |
323 | AU.addPreserved<LoopInfoWrapperPass>(); |
324 | AU.addPreservedID(ID&: LoopSimplifyID); |
325 | AU.addRequiredID(ID&: LoopSimplifyID); |
326 | AU.addRequired<DominatorTreeWrapperPass>(); |
327 | AU.addPreserved<DominatorTreeWrapperPass>(); |
328 | AU.addRequired<ScalarEvolutionWrapperPass>(); |
329 | AU.addPreserved<ScalarEvolutionWrapperPass>(); |
330 | AU.addRequired<TargetLibraryInfoWrapperPass>(); |
331 | AU.addRequired<TargetTransformInfoWrapperPass>(); |
332 | AU.addPreserved<MemorySSAWrapperPass>(); |
333 | } |
334 | |
335 | bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { |
336 | if (skipLoop(L)) |
337 | return false; |
338 | |
339 | auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); |
340 | auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
341 | auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
342 | const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( |
343 | F: *L->getHeader()->getParent()); |
344 | auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI( |
345 | F: *L->getHeader()->getParent()); |
346 | auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>(); |
347 | MemorySSA *MSSA = nullptr; |
348 | if (MSSAAnalysis) |
349 | MSSA = &MSSAAnalysis->getMSSA(); |
350 | return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA); |
351 | } |
352 | |
353 | PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM, |
354 | LoopStandardAnalysisResults &AR, |
355 | LPMUpdater &) { |
356 | if (!RunTermFold(L: &L, SE&: AR.SE, DT&: AR.DT, LI&: AR.LI, TTI: AR.TTI, TLI&: AR.TLI, MSSA: AR.MSSA)) |
357 | return PreservedAnalyses::all(); |
358 | |
359 | auto PA = getLoopPassPreservedAnalyses(); |
360 | if (AR.MSSA) |
361 | PA.preserve<MemorySSAAnalysis>(); |
362 | return PA; |
363 | } |
364 | |
365 | char LoopTermFold::ID = 0; |
366 | |
367 | INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold" , "Loop Terminator Folding" , |
368 | false, false) |
369 | INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) |
370 | INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) |
371 | INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) |
372 | INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) |
373 | INITIALIZE_PASS_DEPENDENCY(LoopSimplify) |
374 | INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold" , "Loop Terminator Folding" , |
375 | false, false) |
376 | |
377 | Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); } |
378 | |