1//===- LoopTermFold.cpp - Eliminate last use of IV in exit branch----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//===----------------------------------------------------------------------===//
9
10#include "llvm/Transforms/Scalar/LoopTermFold.h"
11#include "llvm/ADT/Statistic.h"
12#include "llvm/Analysis/LoopAnalysisManager.h"
13#include "llvm/Analysis/LoopInfo.h"
14#include "llvm/Analysis/LoopPass.h"
15#include "llvm/Analysis/MemorySSA.h"
16#include "llvm/Analysis/MemorySSAUpdater.h"
17#include "llvm/Analysis/ScalarEvolution.h"
18#include "llvm/Analysis/ScalarEvolutionExpressions.h"
19#include "llvm/Analysis/TargetLibraryInfo.h"
20#include "llvm/Analysis/TargetTransformInfo.h"
21#include "llvm/Analysis/ValueTracking.h"
22#include "llvm/Config/llvm-config.h"
23#include "llvm/IR/BasicBlock.h"
24#include "llvm/IR/Dominators.h"
25#include "llvm/IR/IRBuilder.h"
26#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
28#include "llvm/IR/Instructions.h"
29#include "llvm/IR/Type.h"
30#include "llvm/IR/Value.h"
31#include "llvm/InitializePasses.h"
32#include "llvm/Pass.h"
33#include "llvm/Support/Debug.h"
34#include "llvm/Support/raw_ostream.h"
35#include "llvm/Transforms/Scalar.h"
36#include "llvm/Transforms/Utils.h"
37#include "llvm/Transforms/Utils/BasicBlockUtils.h"
38#include "llvm/Transforms/Utils/Local.h"
39#include "llvm/Transforms/Utils/LoopUtils.h"
40#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
41#include <cassert>
42#include <optional>
43
44using namespace llvm;
45
46#define DEBUG_TYPE "loop-term-fold"
47
// Statistic counter for this pass; RunTermFold bumps it once for every loop
// whose latch terminating condition it rewrites.
STATISTIC(NumTermFold,
          "Number of terminating condition fold recognized and performed");
50
51static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
52canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
53 const LoopInfo &LI, const TargetTransformInfo &TTI) {
54 if (!L->isInnermost()) {
55 LLVM_DEBUG(dbgs() << "Cannot fold on non-innermost loop\n");
56 return std::nullopt;
57 }
58 // Only inspect on simple loop structure
59 if (!L->isLoopSimplifyForm()) {
60 LLVM_DEBUG(dbgs() << "Cannot fold on non-simple loop\n");
61 return std::nullopt;
62 }
63
64 if (!SE.hasLoopInvariantBackedgeTakenCount(L)) {
65 LLVM_DEBUG(dbgs() << "Cannot fold on backedge that is loop variant\n");
66 return std::nullopt;
67 }
68
69 BasicBlock *LoopLatch = L->getLoopLatch();
70 BranchInst *BI = dyn_cast<BranchInst>(Val: LoopLatch->getTerminator());
71 if (!BI || BI->isUnconditional())
72 return std::nullopt;
73 auto *TermCond = dyn_cast<ICmpInst>(Val: BI->getCondition());
74 if (!TermCond) {
75 LLVM_DEBUG(
76 dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
77 return std::nullopt;
78 }
79 if (!TermCond->hasOneUse()) {
80 LLVM_DEBUG(
81 dbgs()
82 << "Cannot replace terminating condition with more than one use\n");
83 return std::nullopt;
84 }
85
86 BinaryOperator *LHS = dyn_cast<BinaryOperator>(Val: TermCond->getOperand(i_nocapture: 0));
87 Value *RHS = TermCond->getOperand(i_nocapture: 1);
88 if (!LHS || !L->isLoopInvariant(V: RHS))
89 // We could pattern match the inverse form of the icmp, but that is
90 // non-canonical, and this pass is running *very* late in the pipeline.
91 return std::nullopt;
92
93 // Find the IV used by the current exit condition.
94 PHINode *ToFold;
95 Value *ToFoldStart, *ToFoldStep;
96 if (!matchSimpleRecurrence(I: LHS, P&: ToFold, Start&: ToFoldStart, Step&: ToFoldStep))
97 return std::nullopt;
98
99 // Ensure the simple recurrence is a part of the current loop.
100 if (ToFold->getParent() != L->getHeader())
101 return std::nullopt;
102
103 // If that IV isn't dead after we rewrite the exit condition in terms of
104 // another IV, there's no point in doing the transform.
105 if (!isAlmostDeadIV(IV: ToFold, LatchBlock: LoopLatch, Cond: TermCond))
106 return std::nullopt;
107
108 // Inserting instructions in the preheader has a runtime cost, scale
109 // the allowed cost with the loops trip count as best we can.
110 const unsigned ExpansionBudget = [&]() {
111 unsigned Budget = 2 * SCEVCheapExpansionBudget;
112 if (unsigned SmallTC = SE.getSmallConstantMaxTripCount(L))
113 return std::min(a: Budget, b: SmallTC);
114 if (std::optional<unsigned> SmallTC = getLoopEstimatedTripCount(L))
115 return std::min(a: Budget, b: *SmallTC);
116 // Unknown trip count, assume long running by default.
117 return Budget;
118 }();
119
120 const SCEV *BECount = SE.getBackedgeTakenCount(L);
121 SCEVExpander Expander(SE, "lsr_fold_term_cond");
122
123 PHINode *ToHelpFold = nullptr;
124 const SCEV *TermValueS = nullptr;
125 bool MustDropPoison = false;
126 auto InsertPt = L->getLoopPreheader()->getTerminator();
127 for (PHINode &PN : L->getHeader()->phis()) {
128 if (ToFold == &PN)
129 continue;
130
131 if (!SE.isSCEVable(Ty: PN.getType())) {
132 LLVM_DEBUG(dbgs() << "IV of phi '" << PN
133 << "' is not SCEV-able, not qualified for the "
134 "terminating condition folding.\n");
135 continue;
136 }
137 const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Val: SE.getSCEV(V: &PN));
138 // Only speculate on affine AddRec
139 if (!AddRec || !AddRec->isAffine()) {
140 LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
141 << "' is not an affine add recursion, not qualified "
142 "for the terminating condition folding.\n");
143 continue;
144 }
145
146 // Check that we can compute the value of AddRec on the exiting iteration
147 // without soundness problems. evaluateAtIteration internally needs
148 // to multiply the stride of the iteration number - which may wrap around.
149 // The issue here is subtle because computing the result accounting for
150 // wrap is insufficient. In order to use the result in an exit test, we
151 // must also know that AddRec doesn't take the same value on any previous
152 // iteration. The simplest case to consider is a candidate IV which is
153 // narrower than the trip count (and thus original IV), but this can
154 // also happen due to non-unit strides on the candidate IVs.
155 if (!AddRec->hasNoSelfWrap() ||
156 !SE.isKnownNonZero(S: AddRec->getStepRecurrence(SE)))
157 continue;
158
159 const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
160 const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(It: BECount, SE);
161 if (!Expander.isSafeToExpand(S: TermValueSLocal)) {
162 LLVM_DEBUG(
163 dbgs() << "Is not safe to expand terminating value for phi node" << PN
164 << "\n");
165 continue;
166 }
167
168 if (Expander.isHighCostExpansion(Exprs: TermValueSLocal, L, Budget: ExpansionBudget, TTI: &TTI,
169 At: InsertPt)) {
170 LLVM_DEBUG(
171 dbgs() << "Is too expensive to expand terminating value for phi node"
172 << PN << "\n");
173 continue;
174 }
175
176 // The candidate IV may have been otherwise dead and poison from the
177 // very first iteration. If we can't disprove that, we can't use the IV.
178 if (!mustExecuteUBIfPoisonOnPathTo(Root: &PN, OnPathTo: LoopLatch->getTerminator(), DT: &DT)) {
179 LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " << PN << "\n");
180 continue;
181 }
182
183 // The candidate IV may become poison on the last iteration. If this
184 // value is not branched on, this is a well defined program. We're
185 // about to add a new use to this IV, and we have to ensure we don't
186 // insert UB which didn't previously exist.
187 bool MustDropPoisonLocal = false;
188 Instruction *PostIncV =
189 cast<Instruction>(Val: PN.getIncomingValueForBlock(BB: LoopLatch));
190 if (!mustExecuteUBIfPoisonOnPathTo(Root: PostIncV, OnPathTo: LoopLatch->getTerminator(),
191 DT: &DT)) {
192 LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" << PN
193 << "\n");
194
195 // If this is a complex recurrance with multiple instructions computing
196 // the backedge value, we might need to strip poison flags from all of
197 // them.
198 if (PostIncV->getOperand(i: 0) != &PN)
199 continue;
200
201 // In order to perform the transform, we need to drop the poison
202 // generating flags on this instruction (if any).
203 MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
204 }
205
206 // We pick the last legal alternate IV. We could expore choosing an optimal
207 // alternate IV if we had a decent heuristic to do so.
208 ToHelpFold = &PN;
209 TermValueS = TermValueSLocal;
210 MustDropPoison = MustDropPoisonLocal;
211 }
212
213 LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
214 << "Cannot find other AddRec IV to help folding\n";);
215
216 LLVM_DEBUG(if (ToFold && ToHelpFold) dbgs()
217 << "\nFound loop that can fold terminating condition\n"
218 << " BECount (SCEV): " << *SE.getBackedgeTakenCount(L) << "\n"
219 << " TermCond: " << *TermCond << "\n"
220 << " BrandInst: " << *BI << "\n"
221 << " ToFold: " << *ToFold << "\n"
222 << " ToHelpFold: " << *ToHelpFold << "\n");
223
224 if (!ToFold || !ToHelpFold)
225 return std::nullopt;
226 return std::make_tuple(args&: ToFold, args&: ToHelpFold, args&: TermValueS, args&: MustDropPoison);
227}
228
229static bool RunTermFold(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
230 LoopInfo &LI, const TargetTransformInfo &TTI,
231 TargetLibraryInfo &TLI, MemorySSA *MSSA) {
232 std::unique_ptr<MemorySSAUpdater> MSSAU;
233 if (MSSA)
234 MSSAU = std::make_unique<MemorySSAUpdater>(args&: MSSA);
235
236 auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI, TTI);
237 if (!Opt)
238 return false;
239
240 auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
241
242 NumTermFold++;
243
244 BasicBlock *LoopPreheader = L->getLoopPreheader();
245 BasicBlock *LoopLatch = L->getLoopLatch();
246
247 (void)ToFold;
248 LLVM_DEBUG(dbgs() << "To fold phi-node:\n"
249 << *ToFold << "\n"
250 << "New term-cond phi-node:\n"
251 << *ToHelpFold << "\n");
252
253 Value *StartValue = ToHelpFold->getIncomingValueForBlock(BB: LoopPreheader);
254 (void)StartValue;
255 Value *LoopValue = ToHelpFold->getIncomingValueForBlock(BB: LoopLatch);
256
257 // See comment in canFoldTermCondOfLoop on why this is sufficient.
258 if (MustDrop)
259 cast<Instruction>(Val: LoopValue)->dropPoisonGeneratingFlags();
260
261 // SCEVExpander for both use in preheader and latch
262 SCEVExpander Expander(SE, "lsr_fold_term_cond");
263
264 assert(Expander.isSafeToExpand(TermValueS) &&
265 "Terminating value was checked safe in canFoldTerminatingCondition");
266
267 // Create new terminating value at loop preheader
268 Value *TermValue = Expander.expandCodeFor(SH: TermValueS, Ty: ToHelpFold->getType(),
269 I: LoopPreheader->getTerminator());
270
271 LLVM_DEBUG(dbgs() << "Start value of new term-cond phi-node:\n"
272 << *StartValue << "\n"
273 << "Terminating value of new term-cond phi-node:\n"
274 << *TermValue << "\n");
275
276 // Create new terminating condition at loop latch
277 BranchInst *BI = cast<BranchInst>(Val: LoopLatch->getTerminator());
278 ICmpInst *OldTermCond = cast<ICmpInst>(Val: BI->getCondition());
279 IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
280 Value *NewTermCond =
281 LatchBuilder.CreateICmp(P: CmpInst::ICMP_EQ, LHS: LoopValue, RHS: TermValue,
282 Name: "lsr_fold_term_cond.replaced_term_cond");
283 // Swap successors to exit loop body if IV equals to new TermValue
284 if (BI->getSuccessor(i: 0) == L->getHeader())
285 BI->swapSuccessors();
286
287 LLVM_DEBUG(dbgs() << "Old term-cond:\n"
288 << *OldTermCond << "\n"
289 << "New term-cond:\n"
290 << *NewTermCond << "\n");
291
292 BI->setCondition(NewTermCond);
293
294 Expander.clear();
295 OldTermCond->eraseFromParent();
296 DeleteDeadPHIs(BB: L->getHeader(), TLI: &TLI, MSSAU: MSSAU.get());
297 return true;
298}
299
300namespace {
301
302class LoopTermFold : public LoopPass {
303public:
304 static char ID; // Pass ID, replacement for typeid
305
306 LoopTermFold();
307
308private:
309 bool runOnLoop(Loop *L, LPPassManager &LPM) override;
310 void getAnalysisUsage(AnalysisUsage &AU) const override;
311};
312
313} // end anonymous namespace
314
315LoopTermFold::LoopTermFold() : LoopPass(ID) {
316 initializeLoopTermFoldPass(*PassRegistry::getPassRegistry());
317}
318
319void LoopTermFold::getAnalysisUsage(AnalysisUsage &AU) const {
320 AU.addRequired<LoopInfoWrapperPass>();
321 AU.addPreserved<LoopInfoWrapperPass>();
322 AU.addPreservedID(ID&: LoopSimplifyID);
323 AU.addRequiredID(ID&: LoopSimplifyID);
324 AU.addRequired<DominatorTreeWrapperPass>();
325 AU.addPreserved<DominatorTreeWrapperPass>();
326 AU.addRequired<ScalarEvolutionWrapperPass>();
327 AU.addPreserved<ScalarEvolutionWrapperPass>();
328 AU.addRequired<TargetLibraryInfoWrapperPass>();
329 AU.addRequired<TargetTransformInfoWrapperPass>();
330 AU.addPreserved<MemorySSAWrapperPass>();
331}
332
333bool LoopTermFold::runOnLoop(Loop *L, LPPassManager & /*LPM*/) {
334 if (skipLoop(L))
335 return false;
336
337 auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
338 auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
339 auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
340 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
341 F: *L->getHeader()->getParent());
342 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
343 F: *L->getHeader()->getParent());
344 auto *MSSAAnalysis = getAnalysisIfAvailable<MemorySSAWrapperPass>();
345 MemorySSA *MSSA = nullptr;
346 if (MSSAAnalysis)
347 MSSA = &MSSAAnalysis->getMSSA();
348 return RunTermFold(L, SE, DT, LI, TTI, TLI, MSSA);
349}
350
351PreservedAnalyses LoopTermFoldPass::run(Loop &L, LoopAnalysisManager &AM,
352 LoopStandardAnalysisResults &AR,
353 LPMUpdater &) {
354 if (!RunTermFold(L: &L, SE&: AR.SE, DT&: AR.DT, LI&: AR.LI, TTI: AR.TTI, TLI&: AR.TLI, MSSA: AR.MSSA))
355 return PreservedAnalyses::all();
356
357 auto PA = getLoopPassPreservedAnalyses();
358 if (AR.MSSA)
359 PA.preserve<MemorySSAAnalysis>();
360 return PA;
361}
362
363char LoopTermFold::ID = 0;
364
365INITIALIZE_PASS_BEGIN(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
366 false, false)
367INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
368INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
369INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
370INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
371INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
372INITIALIZE_PASS_END(LoopTermFold, "loop-term-fold", "Loop Terminator Folding",
373 false, false)
374
375Pass *llvm::createLoopTermFoldPass() { return new LoopTermFold(); }
376