1//===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10#include "llvm/ADT/Sequence.h"
11#include "llvm/Analysis/LoopAnalysisManager.h"
12#include "llvm/Analysis/LoopInfo.h"
13#include "llvm/Analysis/ScalarEvolution.h"
14#include "llvm/Analysis/ScalarEvolutionExpressions.h"
15#include "llvm/IR/PatternMatch.h"
16#include "llvm/Transforms/Scalar/LoopPassManager.h"
17#include "llvm/Transforms/Utils/BasicBlockUtils.h"
18#include "llvm/Transforms/Utils/Cloning.h"
19#include "llvm/Transforms/Utils/LoopSimplify.h"
20#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
21
22#define DEBUG_TYPE "loop-bound-split"
23
24using namespace llvm;
25using namespace PatternMatch;
26
27namespace {
28struct ConditionInfo {
29 /// Branch instruction with this condition
30 BranchInst *BI = nullptr;
31 /// ICmp instruction with this condition
32 ICmpInst *ICmp = nullptr;
33 /// Preciate info
34 CmpPredicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
35 /// AddRec llvm value
36 Value *AddRecValue = nullptr;
37 /// Non PHI AddRec llvm value
38 Value *NonPHIAddRecValue;
39 /// Bound llvm value
40 Value *BoundValue = nullptr;
41 /// AddRec SCEV
42 const SCEVAddRecExpr *AddRecSCEV = nullptr;
43 /// Bound SCEV
44 const SCEV *BoundSCEV = nullptr;
45
46 ConditionInfo() = default;
47};
48} // namespace
49
50static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
51 ConditionInfo &Cond, const Loop &L) {
52 Cond.ICmp = ICmp;
53 if (match(V: ICmp, P: m_ICmp(Pred&: Cond.Pred, L: m_Value(V&: Cond.AddRecValue),
54 R: m_Value(V&: Cond.BoundValue)))) {
55 const SCEV *AddRecSCEV = SE.getSCEV(V: Cond.AddRecValue);
56 const SCEV *BoundSCEV = SE.getSCEV(V: Cond.BoundValue);
57 const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(Val: AddRecSCEV);
58 const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(Val: BoundSCEV);
59 // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
60 if (!LHSAddRecSCEV && RHSAddRecSCEV) {
61 std::swap(a&: Cond.AddRecValue, b&: Cond.BoundValue);
62 std::swap(a&: AddRecSCEV, b&: BoundSCEV);
63 Cond.Pred = ICmpInst::getSwappedPredicate(pred: Cond.Pred);
64 }
65
66 Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Val: AddRecSCEV);
67 Cond.BoundSCEV = BoundSCEV;
68 Cond.NonPHIAddRecValue = Cond.AddRecValue;
69
70 // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
71 // value from backedge.
72 if (Cond.AddRecSCEV && isa<PHINode>(Val: Cond.AddRecValue)) {
73 PHINode *PN = cast<PHINode>(Val: Cond.AddRecValue);
74 Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(BB: L.getLoopLatch());
75 }
76 }
77}
78
79static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
80 ConditionInfo &Cond, bool IsExitCond) {
81 if (IsExitCond) {
82 const SCEV *ExitCount = SE.getExitCount(L: &L, ExitingBlock: Cond.ICmp->getParent());
83 if (isa<SCEVCouldNotCompute>(Val: ExitCount))
84 return false;
85
86 Cond.BoundSCEV = ExitCount;
87 return true;
88 }
89
90 // For non-exit condtion, if pred is LT, keep existing bound.
91 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
92 return true;
93
94 // For non-exit condition, if pre is LE, try to convert it to LT.
95 // Range Range
96 // AddRec <= Bound --> AddRec < Bound + 1
97 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
98 return false;
99
100 if (IntegerType *BoundSCEVIntType =
101 dyn_cast<IntegerType>(Val: Cond.BoundSCEV->getType())) {
102 unsigned BitWidth = BoundSCEVIntType->getBitWidth();
103 APInt Max = ICmpInst::isSigned(predicate: Cond.Pred)
104 ? APInt::getSignedMaxValue(numBits: BitWidth)
105 : APInt::getMaxValue(numBits: BitWidth);
106 const SCEV *MaxSCEV = SE.getConstant(Val: Max);
107 // Check Bound < INT_MAX
108 ICmpInst::Predicate Pred =
109 ICmpInst::isSigned(predicate: Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
110 if (SE.isKnownPredicate(Pred, LHS: Cond.BoundSCEV, RHS: MaxSCEV)) {
111 const SCEV *BoundPlusOneSCEV =
112 SE.getAddExpr(LHS: Cond.BoundSCEV, RHS: SE.getOne(Ty: BoundSCEVIntType));
113 Cond.BoundSCEV = BoundPlusOneSCEV;
114 Cond.Pred = Pred;
115 return true;
116 }
117 }
118
119 // ToDo: Support ICMP_NE/EQ.
120
121 return false;
122}
123
124static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
125 ICmpInst *ICmp, ConditionInfo &Cond,
126 bool IsExitCond) {
127 analyzeICmp(SE, ICmp, Cond, L);
128
129 // The BoundSCEV should be evaluated at loop entry.
130 if (!SE.isAvailableAtLoopEntry(S: Cond.BoundSCEV, L: &L))
131 return false;
132
133 // Allowed AddRec as induction variable.
134 if (!Cond.AddRecSCEV)
135 return false;
136
137 if (!Cond.AddRecSCEV->isAffine())
138 return false;
139
140 const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
141 // Allowed constant step.
142 if (!isa<SCEVConstant>(Val: StepRecSCEV))
143 return false;
144
145 ConstantInt *StepCI = cast<SCEVConstant>(Val: StepRecSCEV)->getValue();
146 // Allowed positive step for now.
147 // TODO: Support negative step.
148 if (StepCI->isNegative() || StepCI->isZero())
149 return false;
150
151 // Calculate upper bound.
152 if (!calculateUpperBound(L, SE, Cond, IsExitCond))
153 return false;
154
155 return true;
156}
157
158static bool isProcessableCondBI(const ScalarEvolution &SE,
159 const BranchInst *BI) {
160 BasicBlock *TrueSucc = nullptr;
161 BasicBlock *FalseSucc = nullptr;
162 Value *LHS, *RHS;
163 if (!match(V: BI, P: m_Br(C: m_ICmp(L: m_Value(V&: LHS), R: m_Value(V&: RHS)),
164 T: m_BasicBlock(V&: TrueSucc), F: m_BasicBlock(V&: FalseSucc))))
165 return false;
166
167 if (!SE.isSCEVable(Ty: LHS->getType()))
168 return false;
169 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
170
171 if (TrueSucc == FalseSucc)
172 return false;
173
174 return true;
175}
176
177static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
178 ScalarEvolution &SE, ConditionInfo &Cond) {
179 // Skip function with optsize.
180 if (L.getHeader()->getParent()->hasOptSize())
181 return false;
182
183 // Split only innermost loop.
184 if (!L.isInnermost())
185 return false;
186
187 // Check loop is in simplified form.
188 if (!L.isLoopSimplifyForm())
189 return false;
190
191 // Check loop is in LCSSA form.
192 if (!L.isLCSSAForm(DT))
193 return false;
194
195 // Skip loop that cannot be cloned.
196 if (!L.isSafeToClone())
197 return false;
198
199 BasicBlock *ExitingBB = L.getExitingBlock();
200 // Assumed only one exiting block.
201 if (!ExitingBB)
202 return false;
203
204 BranchInst *ExitingBI = dyn_cast<BranchInst>(Val: ExitingBB->getTerminator());
205 if (!ExitingBI)
206 return false;
207
208 // Allowed only conditional branch with ICmp.
209 if (!isProcessableCondBI(SE, BI: ExitingBI))
210 return false;
211
212 // Check the condition is processable.
213 ICmpInst *ICmp = cast<ICmpInst>(Val: ExitingBI->getCondition());
214 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
215 return false;
216
217 Cond.BI = ExitingBI;
218 return true;
219}
220
221static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
222 // If the conditional branch splits a loop into two halves, we could
223 // generally say it is profitable.
224 //
225 // ToDo: Add more profitable cases here.
226
227 // Check this branch causes diamond CFG.
228 BasicBlock *Succ0 = BI->getSuccessor(i: 0);
229 BasicBlock *Succ1 = BI->getSuccessor(i: 1);
230
231 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
232 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
233 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
234 return false;
235
236 // ToDo: Calculate each successor's instruction cost.
237
238 return true;
239}
240
241static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
242 ConditionInfo &ExitingCond,
243 ConditionInfo &SplitCandidateCond) {
244 for (auto *BB : L.blocks()) {
245 // Skip condition of backedge.
246 if (L.getLoopLatch() == BB)
247 continue;
248
249 auto *BI = dyn_cast<BranchInst>(Val: BB->getTerminator());
250 if (!BI)
251 continue;
252
253 // Check conditional branch with ICmp.
254 if (!isProcessableCondBI(SE, BI))
255 continue;
256
257 // Skip loop invariant condition.
258 if (L.isLoopInvariant(V: BI->getCondition()))
259 continue;
260
261 // Check the condition is processable.
262 ICmpInst *ICmp = cast<ICmpInst>(Val: BI->getCondition());
263 if (!hasProcessableCondition(L, SE, ICmp, Cond&: SplitCandidateCond,
264 /*IsExitCond*/ false))
265 continue;
266
267 if (ExitingCond.BoundSCEV->getType() !=
268 SplitCandidateCond.BoundSCEV->getType())
269 continue;
270
271 // After transformation, we assume the split condition of the pre-loop is
272 // always true. In order to guarantee it, we need to check the start value
273 // of the split cond AddRec satisfies the split condition.
274 if (!SE.isLoopEntryGuardedByCond(L: &L, Pred: SplitCandidateCond.Pred,
275 LHS: SplitCandidateCond.AddRecSCEV->getStart(),
276 RHS: SplitCandidateCond.BoundSCEV))
277 continue;
278
279 SplitCandidateCond.BI = BI;
280 return BI;
281 }
282
283 return nullptr;
284}
285
286static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
287 ScalarEvolution &SE, LPMUpdater &U) {
288 ConditionInfo SplitCandidateCond;
289 ConditionInfo ExitingCond;
290
291 // Check we can split this loop's bound.
292 if (!canSplitLoopBound(L, DT, SE, Cond&: ExitingCond))
293 return false;
294
295 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
296 return false;
297
298 if (!isProfitableToTransform(L, BI: SplitCandidateCond.BI))
299 return false;
300
301 // Now, we have a split candidate. Let's build a form as below.
302 // +--------------------+
303 // | preheader |
304 // | set up newbound |
305 // +--------------------+
306 // | /----------------\
307 // +--------v----v------+ |
308 // | header |---\ |
309 // | with true condition| | |
310 // +--------------------+ | |
311 // | | |
312 // +--------v-----------+ | |
313 // | if.then.BB | | |
314 // +--------------------+ | |
315 // | | |
316 // +--------v-----------<---/ |
317 // | latch >----------/
318 // | with newbound |
319 // +--------------------+
320 // |
321 // +--------v-----------+
322 // | preheader2 |--------------\
323 // | if (AddRec i != | |
324 // | org bound) | |
325 // +--------------------+ |
326 // | /----------------\ |
327 // +--------v----v------+ | |
328 // | header2 |---\ | |
329 // | conditional branch | | | |
330 // |with false condition| | | |
331 // +--------------------+ | | |
332 // | | | |
333 // +--------v-----------+ | | |
334 // | if.then.BB2 | | | |
335 // +--------------------+ | | |
336 // | | | |
337 // +--------v-----------<---/ | |
338 // | latch2 >----------/ |
339 // | with org bound | |
340 // +--------v-----------+ |
341 // | |
342 // | +---------------+ |
343 // +--> exit <-------/
344 // +---------------+
345
346 // Let's create post loop.
347 SmallVector<BasicBlock *, 8> PostLoopBlocks;
348 Loop *PostLoop;
349 ValueToValueMapTy VMap;
350 BasicBlock *PreHeader = L.getLoopPreheader();
351 BasicBlock *SplitLoopPH = SplitEdge(From: PreHeader, To: L.getHeader(), DT: &DT, LI: &LI);
352 PostLoop = cloneLoopWithPreheader(Before: L.getExitBlock(), LoopDomBB: SplitLoopPH, OrigLoop: &L, VMap,
353 NameSuffix: ".split", LI: &LI, DT: &DT, Blocks&: PostLoopBlocks);
354 remapInstructionsInBlocks(Blocks: PostLoopBlocks, VMap);
355
356 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
357 IRBuilder<> Builder(&PostLoopPreHeader->front());
358
359 // Update phi nodes in header of post-loop.
360 bool isExitingLatch = L.getExitingBlock() == L.getLoopLatch();
361 Value *ExitingCondLCSSAPhi = nullptr;
362 for (PHINode &PN : L.getHeader()->phis()) {
363 // Create LCSSA phi node in preheader of post-loop.
364 PHINode *LCSSAPhi =
365 Builder.CreatePHI(Ty: PN.getType(), NumReservedValues: 1, Name: PN.getName() + ".lcssa");
366 LCSSAPhi->setDebugLoc(PN.getDebugLoc());
367 // If the exiting block is loop latch, the phi does not have the update at
368 // last iteration. In this case, update lcssa phi with value from backedge.
369 LCSSAPhi->addIncoming(
370 V: isExitingLatch ? PN.getIncomingValueForBlock(BB: L.getLoopLatch()) : &PN,
371 BB: L.getExitingBlock());
372
373 // Update the start value of phi node in post-loop with the LCSSA phi node.
374 PHINode *PostLoopPN = cast<PHINode>(Val&: VMap[&PN]);
375 PostLoopPN->setIncomingValueForBlock(BB: PostLoopPreHeader, V: LCSSAPhi);
376
377 // Find PHI with exiting condition from pre-loop. The PHI should be
378 // SCEVAddRecExpr and have same incoming value from backedge with
379 // ExitingCond.
380 if (!SE.isSCEVable(Ty: PN.getType()))
381 continue;
382
383 const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(Val: SE.getSCEV(V: &PN));
384 if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
385 PN.getIncomingValueForBlock(BB: L.getLoopLatch()))
386 ExitingCondLCSSAPhi = LCSSAPhi;
387 }
388
389 // Add conditional branch to check we can skip post-loop in its preheader.
390 Instruction *OrigBI = PostLoopPreHeader->getTerminator();
391 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
392 Value *Cond =
393 Builder.CreateICmp(P: Pred, LHS: ExitingCondLCSSAPhi, RHS: ExitingCond.BoundValue);
394 Builder.CreateCondBr(Cond, True: PostLoop->getHeader(), False: PostLoop->getExitBlock());
395 OrigBI->eraseFromParent();
396
397 // Create new loop bound and add it into preheader of pre-loop.
398 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
399 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
400 NewBoundSCEV = ICmpInst::isSigned(predicate: ExitingCond.Pred)
401 ? SE.getSMinExpr(LHS: NewBoundSCEV, RHS: SplitBoundSCEV)
402 : SE.getUMinExpr(LHS: NewBoundSCEV, RHS: SplitBoundSCEV);
403
404 SCEVExpander Expander(SE, "split");
405 Instruction *InsertPt = SplitLoopPH->getTerminator();
406 Value *NewBoundValue =
407 Expander.expandCodeFor(SH: NewBoundSCEV, Ty: NewBoundSCEV->getType(), I: InsertPt);
408 NewBoundValue->setName("new.bound");
409
410 // Replace exiting bound value of pre-loop NewBound.
411 ExitingCond.ICmp->setOperand(i_nocapture: 1, Val_nocapture: NewBoundValue);
412
413 // Replace SplitCandidateCond.BI's condition of pre-loop by True.
414 LLVMContext &Context = PreHeader->getContext();
415 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
416
417 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
418 BranchInst *ClonedSplitCandidateBI =
419 cast<BranchInst>(Val&: VMap[SplitCandidateCond.BI]);
420 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
421
422 // Replace exit branch target of pre-loop by post-loop's preheader.
423 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(i: 0))
424 ExitingCond.BI->setSuccessor(idx: 0, NewSucc: PostLoopPreHeader);
425 else
426 ExitingCond.BI->setSuccessor(idx: 1, NewSucc: PostLoopPreHeader);
427
428 // Update phi node in exit block of post-loop.
429 Builder.SetInsertPoint(TheBB: PostLoopPreHeader, IP: PostLoopPreHeader->begin());
430 for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
431 for (auto i : seq<int>(Begin: 0, End: PN.getNumOperands())) {
432 // Check incoming block is pre-loop's exiting block.
433 if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
434 Value *IncomingValue = PN.getIncomingValue(i);
435
436 // Create LCSSA phi node for incoming value.
437 PHINode *LCSSAPhi =
438 Builder.CreatePHI(Ty: PN.getType(), NumReservedValues: 1, Name: PN.getName() + ".lcssa");
439 LCSSAPhi->setDebugLoc(PN.getDebugLoc());
440 LCSSAPhi->addIncoming(V: IncomingValue, BB: PN.getIncomingBlock(i));
441
442 // Replace pre-loop's exiting block by post-loop's preheader.
443 PN.setIncomingBlock(i, BB: PostLoopPreHeader);
444 // Replace incoming value by LCSSAPhi.
445 PN.setIncomingValue(i, V: LCSSAPhi);
446 // Add a new incoming value with post-loop's exiting block.
447 PN.addIncoming(V: VMap[IncomingValue], BB: PostLoop->getExitingBlock());
448 }
449 }
450 }
451
452 // Update dominator tree.
453 DT.changeImmediateDominator(BB: PostLoopPreHeader, NewBB: L.getExitingBlock());
454 DT.changeImmediateDominator(BB: PostLoop->getExitBlock(), NewBB: PostLoopPreHeader);
455
456 // Invalidate cached SE information.
457 SE.forgetLoop(L: &L);
458
459 // Canonicalize loops.
460 simplifyLoop(L: &L, DT: &DT, LI: &LI, SE: &SE, AC: nullptr, MSSAU: nullptr, PreserveLCSSA: true);
461 simplifyLoop(L: PostLoop, DT: &DT, LI: &LI, SE: &SE, AC: nullptr, MSSAU: nullptr, PreserveLCSSA: true);
462
463 // Add new post-loop to loop pass manager.
464 U.addSiblingLoops(NewSibLoops: PostLoop);
465
466 return true;
467}
468
469PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
470 LoopStandardAnalysisResults &AR,
471 LPMUpdater &U) {
472 [[maybe_unused]] Function &F = *L.getHeader()->getParent();
473
474 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
475 << "\n");
476
477 if (!splitLoopBound(L, DT&: AR.DT, LI&: AR.LI, SE&: AR.SE, U))
478 return PreservedAnalyses::all();
479
480 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
481 AR.LI.verify(DomTree: AR.DT);
482
483 return getLoopPassPreservedAnalyses();
484}
485