//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
9
10#include "llvm/Transforms/Scalar/LoopTermFold.h"
11#include "llvm/ADT/Statistic.h"
12#include "llvm/Analysis/LoopAnalysisManager.h"
13#include "llvm/Analysis/LoopInfo.h"
14#include "llvm/Analysis/LoopPass.h"
15#include "llvm/Analysis/MemorySSA.h"
16#include "llvm/Analysis/MemorySSAUpdater.h"
17#include "llvm/Analysis/ScalarEvolution.h"
18#include "llvm/Analysis/ScalarEvolutionExpressions.h"
19#include "llvm/Analysis/TargetLibraryInfo.h"
20#include "llvm/Analysis/TargetTransformInfo.h"
21#include "llvm/Analysis/ValueTracking.h"
22#include "llvm/Config/llvm-config.h"
23#include "llvm/IR/BasicBlock.h"
24#include "llvm/IR/Dominators.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
28#include "llvm/IR/Instructions.h"
29#include "llvm/IR/Type.h"
30#include "llvm/IR/Value.h"
31#include "llvm/InitializePasses.h"
32#include "llvm/Pass.h"
33#include "llvm/Support/Debug.h"
34#include "llvm/Support/raw_ostream.h"
35#include "llvm/Transforms/Scalar.h"
36#include "llvm/Transforms/Utils.h"
37#include "llvm/Transforms/Utils/BasicBlockUtils.h"
38#include "llvm/Transforms/Utils/Local.h"
39#include "llvm/Transforms/Utils/LoopUtils.h"
40#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
41#include <cassert>
42#include <optional>
43
44using namespace llvm;
45
46#define DEBUG_TYPE "loop-term-fold"
47
48STATISTIC(NumTermFold,
49 "Number of terminating condition fold recognized and performed");
50
51static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
52canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
53 const LoopInfo &LI, const TargetTransformInfo &TTI) {
54 if (!L->isInnermost()) {
55 LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
56 return std::nullopt;
57 }
58 // Only inspect on simple loop structure
59 if (!L->isLoopSimplifyForm()) {
60 LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
61 return std::nullopt;
62 }
63
64 if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
65 LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
66 return std::nullopt;
67 }
68
69 BasicBlock *LoopLatch = L->getLoopLatch();
70 BranchInst *BI = dyn_cast<BranchInst>(Val: LoopLatch->getTerminator());
71 if (!BI || BI->isUnconditional())
72 return std::nullopt;
73 auto *TermCond = dyn_cast<ICmpInst>(Val: BI->getCondition());
74 if (!TermCond) {
75 LLVM_DEBUG(
76 dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
77 return std::nullopt;
78 }
79 if (!TermCond->hasOneUse()) {
80 LLVM_DEBUG(
81 dbgs()
82 << "Cannot replace terminating condition with more than one use\n");
83 return std::nullopt;
84 }
85
86 BinaryOperator *LHS = dyn_cast<BinaryOperator>(Val: TermCond->getOperand(i_nocapture: 0));
87 Value *RHS = TermCond->getOperand(i_nocapture: 1);
88 if (!LHS || !L->isLoopInvariant(V: RHS))
89 // We could pattern match the inverse form of the icmp, but that is
90 // non-canonical, and this pass is running *very* late in the pipeline.
91 return std::nullopt;
92
93 // Find the IV used by the current exit condition.
94 PHINode *ToFold;
95 Value *ToFoldStart, *ToFoldStep;
96 if (!matchSimpleRecurrence(I: LHS, P&: ToFold, Start&: ToFoldStart, Step&: ToFoldStep))
97 return std::nullopt;
98
99 // Ensure the simple recurrence is a part of the current loop.
100 if (ToFold->getParent() != L->getHeader())
101 return std::nullopt;
102
103 // If that IV isn't dead after we rewrite the exit condition in terms of
104 // another IV, there's no point in doing the transform.
105 if (!isAlmostDeadIV(IV: ToFold, LatchBlock: LoopLatch, Cond: TermCond))
106 return std::nullopt;
107
108 // Inserting instructions in the preheader has a runtime cost, scale
109 // the allowed cost with the loops trip count as best we can.
110 const unsigned ExpansionBudget = [&]() {
111 unsigned Budget = 2 * SCEVCheapExpansionBudget;
112 if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
113 return std::min(a: Budget, b: SmallTC);
114 if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
115 return std::min(a: Budget, b: *SmallTC);
116 // Unknown trip count, assume long running by default.
117 return Budget;
118 }();
119
120 const SCEV *BECount = SE.getBackedgeTakenCount(L);
121 const DataLayout &DL = L->getHeader()->getDataLayout();
122 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
123
124 PHINode *ToHelpFold = nullptr;
125 const SCEV *TermValueS = nullptr;
126 bool MustDropPoison = false;
127 auto InsertPt = L->getLoopPreheader()->getTerminator();
128 for (PHINode &PN : L->getHeader()->phis()) {
129 if (ToFold == &PN)
130 continue;
131
132 if (!SE.isSCEVable(Ty: PN.getType())) {
133 LLVM_DEBUG(dbgs() << "IV of phi '" << PN
134 << "' is not SCEV-able, not qualified for the "
135 "terminating condition folding.\n");
136 continue;
137 }
138 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Val: SE.getSCEV(V: &PN));
139 // Only speculate on affine AddRec
140 if (!AddRec || !AddRec->isAffine()) {
141 LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
142 << "' is not an affine add recursion, not qualified "
143 "for the terminating condition folding.\n");
144 continue;
145 }
146
147 // Check that we can compute the value of AddRec on the exiting iteration
148 // without soundness problems. evaluateAtIteration internally needs
149 // to multiply the stride of the iteration number - which may wrap around.
150 // The issue here is subtle because computing the result accounting for
151 // wrap is insufficient. In order to use the result in an exit test, we
152 // must also know that AddRec doesn't take the same value on any previous
153 // iteration. The simplest case to consider is a candidate IV which is
154 // narrower than the trip count (and thus original IV), but this can
155 // also happen due to non-unit strides on the candidate IVs.
156 if (!AddRec->hasNoSelfWrap() ||
157 !SE.isKnownNonZero(S: AddRec->getStepRecurrence(SE)))
158 continue;
159
160 const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
161 const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(It: BECount, SE);
162 if (!Expander.isSafeToExpand(S: TermValueSLocal)) {
163 LLVM_DEBUG(
164 dbgs() << "Is not safe to expand terminating value for phi node" << PN
165 << "\n");
166 continue;
167 }
168
169 if (Expander.isHighCostExpansion(Exprs: TermValueSLocal, L, Budget: ExpansionBudget, TTI: &TTI,
170 At: InsertPt)) {
171 LLVM_DEBUG(
172 dbgs() << "Is too expensive to expand terminating value for phi node"
173 << PN << "\n");
174 continue;
175 }
176
177 // The candidate IV may have been otherwise dead and poison from the
178 // very first iteration. If we can't disprove that, we can't use the IV.
179 if (!mustExecuteUBIfPoisonOnPathTo(Root: &PN, OnPathTo: LoopLatch->getTerminator(), DT: &DT)) {
180 LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
181 continue;
182 }
183
184 // The candidate IV may become poison on the last iteration. If this
185 // value is not branched on, this is a well defined program. We're
186 // about to add a new use to this IV, and we have to ensure we don't
187 // insert UB which didn't previously exist.
188 bool MustDropPoisonLocal = false;
189 Instruction *PostIncV =
190 cast<Instruction>(Val: PN.getIncomingValueForBlock(BB: LoopLatch));
191 if (!mustExecuteUBIfPoisonOnPathTo(Root: PostIncV, OnPathTo: LoopLatch->getTerminator(),
192 DT: &DT)) {
193 LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
194 << "\n");
195
196 // If this is a complex recurrance with multiple instructions computing
197 // the backedge value, we might need to strip poison flags from all of
198 // them.
199 if (PostIncV->getOperand(i: 0) != &PN)
200 continue;
201
202 // In order to perform the transform, we need to drop the poison
203 // generating flags on this instruction (if any).
204 MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
205 }
206
207 // We pick the last legal alternate IV. We could expore choosing an optimal
208 // alternate IV if we had a decent heuristic to do so.
209 ToHelpFold = &PN;
210 TermValueS = TermValueSLocal;
211 MustDropPoison = MustDropPoisonLocal;
212 }
213
214 LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
215 << "Cannot find other AddRec IV to help folding\n";);
216
217 LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
218 << "\nFound loop that can fold terminating condition\n"
219 << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
220 << " TermCond: " << *TermCond << "\n"
221 << " BrandInst: " << *BI << "\n"
222 << " ToFold: " << *ToFold << "\n"
223 << " ToHelpFold: " << *ToHelpFold << "\n");
224
225 if (!ToFold || !ToHelpFold)
226 return std::nullopt;
227 return std::make_tuple(args&: ToFold, args&: ToHelpFold, args&: TermValueS, args&: MustDropPoison);
228}
229
230static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
231 LoopInfo &LI, const TargetTransformInfo &TTI,
232 TargetLibraryInfo &TLI, MemorySSA *MSSA) {
233 std::unique_ptr<MemorySSAUpdater> MSSAU;
234 if (MSSA)
235 MSSAU = std::make_unique<MemorySSAUpdater>(args&: MSSA);
236
237 auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
238 if (!Opt)
239 return false;
240
241 auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
242
243 NumTermFold++;
244
245 BasicBlock *LoopPreheader = L->getLoopPreheader();
246 BasicBlock *LoopLatch = L->getLoopLatch();
247
248 (void)ToFold;
249 LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
250 << *ToFold << "\n"
251 << "New term-cond phi-node:\n"
252 << *ToHelpFold << "\n");
253
254 Value *StartValue = ToHelpFold->getIncomingValueForBlock(BB: LoopPreheader);
255 (void)StartValue;
256 Value *LoopValue = ToHelpFold->getIncomingValueForBlock(BB: LoopLatch);
257
258 // See comment in canFoldTermCondOfLoop on why this is sufficient.
259 if (MustDrop)
260 cast<Instruction>(Val: LoopValue)->dropPoisonGeneratingFlags();
261
262 // SCEVExpander for both use in preheader and latch
263 const DataLayout &DL = L->getHeader()->getDataLayout();
264 SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
265
266 assert(Expander.isSafeToExpand(TermValueS) &&
267 "Terminating value was checked safe in canFoldTerminatingCondition");
268
269 // Create new terminating value at loop preheader
270 Value *TermValue = Expander.expandCodeFor(SH: TermValueS, Ty: ToHelpFold->getType(),
271 I: LoopPreheader->getTerminator());
272
273 LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
274 << *StartValue << "\n"
275 << "Terminating value of new term-cond phi-node:\n"
276 << *TermValue << "\n");
277
278 // Create new terminating condition at loop latch
279 BranchInst *BI = cast<BranchInst>(Val: LoopLatch->getTerminator());
280 ICmpInst *OldTermCond = cast<ICmpInst>(Val: BI->getCondition());
281 IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
282 Value *NewTermCond =
283 LatchBuilder.CreateICmp(P: CmpInst::ICMP_EQ, LHS: LoopValue, RHS: TermValue,
284 Name: "lsr_fold_term_cond.replaced_term_cond");
285 // Swap successors to exit loop body if IV equals to new TermValue
286 if (BI->getSuccessor(i: 0) == L->getHeader())
287 BI->swapSuccessors();
288
289 LLVM_DEBUG(dbgs() << "Old term-cond:\n"
290 << *OldTermCond << "\n"
291 << "New term-cond:\n"
292 << *NewTermCond << "\n");
293
294 BI->setCondition(NewTermCond);
295
296 Expander.clear();
297 OldTermCond->eraseFromParent();
298 DeleteDeadPHIs(BB: L->getHeader(), TLI: &TLI, MSSAU: MSSAU.get());
299 return true;
300}
301
302namespace {
303
304class LoopTermFold : public LoopPass {
305public:
306 static char ID; // Pass ID, replacement for typeid
307
308 LoopTermFold();
309
310private:
311 bool runOnLoop(Loop *L, LPPassManager &LPM) override;
312 void getAnalysisUsage(AnalysisUsage &AU) const override;
313};
314
315} // end anonymous namespace
316
317LoopTermFold::LoopTermFold() : LoopPass(ID) {
318 initializeLoopTermFoldPass(*PassRegistry::getPassRegistry());
319}
320
321void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
322 AU.addRequired<LoopInfoWrapperPass>();
323 AU.addPreserved<LoopInfoWrapperPass>();
324 AU.addPreservedID(ID&: LoopSimplifyID);
325 AU.addRequiredID(ID&: LoopSimplifyID);
326 AU.addRequired<DominatorTreeWrapperPass>();
327 AU.addPreserved<DominatorTreeWrapperPass>();
328 AU.addRequired<ScalarEvolutionWrapperPass>();
329 AU.addPreserved<ScalarEvolutionWrapperPass>();
330 AU.addRequired<TargetLibraryInfoWrapperPass>();
331 AU.addRequired<TargetTransformInfoWrapperPass>();
332 AU.addPreserved<MemorySSAWrapperPass>();
333}
334
335bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
336 if (skipLoop(L))
337 return false;
338
339 auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
340 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
341 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
342 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
343 F: *L->getHeader()->getParent());
344 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
345 F: *L->getHeader()->getParent());
346 auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
347 MemorySSA *MSSA = nullptr;
348 if (MSSAAnalysis)
349 MSSA = &MSSAAnalysis->getMSSA();
350 return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
351}
352
353PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
354 LoopStandardAnalysisResults &AR,
355 LPMUpdater &) {
356 if (!RunTermFold(L: &L, SE&: AR.SE, DT&: AR.DT, LI&: AR.LI, TTI: AR.TTI, TLI&: AR.TLI, MSSA: AR.MSSA))
357 return PreservedAnalyses::all();
358
359 auto PA = getLoopPassPreservedAnalyses();
360 if (AR.MSSA)
361 PA.preserve<MemorySSAAnalysis>();
362 return PA;
363}
364
365char LoopTermFold::ID = 0;
366
367INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
368 false, false)
369INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
370INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
371INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
372INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
373INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
374INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
375 false, false)
376
377Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }
378