//===------- LoopBoundSplit.cpp - Split Loop Bound --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Scalar/LoopBoundSplit.h"
10#include "llvm/ADT/Sequence.h"
11#include "llvm/Analysis/LoopAnalysisManager.h"
12#include "llvm/Analysis/LoopInfo.h"
13#include "llvm/Analysis/ScalarEvolution.h"
14#include "llvm/Analysis/ScalarEvolutionExpressions.h"
15#include "llvm/IR/PatternMatch.h"
16#include "llvm/Transforms/Scalar/LoopPassManager.h"
17#include "llvm/Transforms/Utils/BasicBlockUtils.h"
18#include "llvm/Transforms/Utils/Cloning.h"
19#include "llvm/Transforms/Utils/LoopSimplify.h"
20#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
21
22#define DEBUG_TYPE "loop-bound-split"
23
24namespace llvm {
25
26using namespace PatternMatch;
27
28namespace {
29struct ConditionInfo {
30 /// Branch instruction with this condition
31 BranchInst *BI = nullptr;
32 /// ICmp instruction with this condition
33 ICmpInst *ICmp = nullptr;
34 /// Preciate info
35 CmpPredicate Pred = ICmpInst::BAD_ICMP_PREDICATE;
36 /// AddRec llvm value
37 Value *AddRecValue = nullptr;
38 /// Non PHI AddRec llvm value
39 Value *NonPHIAddRecValue;
40 /// Bound llvm value
41 Value *BoundValue = nullptr;
42 /// AddRec SCEV
43 const SCEVAddRecExpr *AddRecSCEV = nullptr;
44 /// Bound SCEV
45 const SCEV *BoundSCEV = nullptr;
46
47 ConditionInfo() = default;
48};
49} // namespace
50
51static void analyzeICmp(ScalarEvolution &SE, ICmpInst *ICmp,
52 ConditionInfo &Cond, const Loop &L) {
53 Cond.ICmp = ICmp;
54 if (match(V: ICmp, P: m_ICmp(Pred&: Cond.Pred, L: m_Value(V&: Cond.AddRecValue),
55 R: m_Value(V&: Cond.BoundValue)))) {
56 const SCEV *AddRecSCEV = SE.getSCEV(V: Cond.AddRecValue);
57 const SCEV *BoundSCEV = SE.getSCEV(V: Cond.BoundValue);
58 const SCEVAddRecExpr *LHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(Val: AddRecSCEV);
59 const SCEVAddRecExpr *RHSAddRecSCEV = dyn_cast<SCEVAddRecExpr>(Val: BoundSCEV);
60 // Locate AddRec in LHSSCEV and Bound in RHSSCEV.
61 if (!LHSAddRecSCEV && RHSAddRecSCEV) {
62 std::swap(a&: Cond.AddRecValue, b&: Cond.BoundValue);
63 std::swap(a&: AddRecSCEV, b&: BoundSCEV);
64 Cond.Pred = ICmpInst::getSwappedPredicate(pred: Cond.Pred);
65 }
66
67 Cond.AddRecSCEV = dyn_cast<SCEVAddRecExpr>(Val: AddRecSCEV);
68 Cond.BoundSCEV = BoundSCEV;
69 Cond.NonPHIAddRecValue = Cond.AddRecValue;
70
71 // If the Cond.AddRecValue is PHI node, update Cond.NonPHIAddRecValue with
72 // value from backedge.
73 if (Cond.AddRecSCEV && isa<PHINode>(Val: Cond.AddRecValue)) {
74 PHINode *PN = cast<PHINode>(Val: Cond.AddRecValue);
75 Cond.NonPHIAddRecValue = PN->getIncomingValueForBlock(BB: L.getLoopLatch());
76 }
77 }
78}
79
80static bool calculateUpperBound(const Loop &L, ScalarEvolution &SE,
81 ConditionInfo &Cond, bool IsExitCond) {
82 if (IsExitCond) {
83 const SCEV *ExitCount = SE.getExitCount(L: &L, ExitingBlock: Cond.ICmp->getParent());
84 if (isa<SCEVCouldNotCompute>(Val: ExitCount))
85 return false;
86
87 Cond.BoundSCEV = ExitCount;
88 return true;
89 }
90
91 // For non-exit condtion, if pred is LT, keep existing bound.
92 if (Cond.Pred == ICmpInst::ICMP_SLT || Cond.Pred == ICmpInst::ICMP_ULT)
93 return true;
94
95 // For non-exit condition, if pre is LE, try to convert it to LT.
96 // Range Range
97 // AddRec <= Bound --> AddRec < Bound + 1
98 if (Cond.Pred != ICmpInst::ICMP_ULE && Cond.Pred != ICmpInst::ICMP_SLE)
99 return false;
100
101 if (IntegerType *BoundSCEVIntType =
102 dyn_cast<IntegerType>(Val: Cond.BoundSCEV->getType())) {
103 unsigned BitWidth = BoundSCEVIntType->getBitWidth();
104 APInt Max = ICmpInst::isSigned(predicate: Cond.Pred)
105 ? APInt::getSignedMaxValue(numBits: BitWidth)
106 : APInt::getMaxValue(numBits: BitWidth);
107 const SCEV *MaxSCEV = SE.getConstant(Val: Max);
108 // Check Bound < INT_MAX
109 ICmpInst::Predicate Pred =
110 ICmpInst::isSigned(predicate: Cond.Pred) ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
111 if (SE.isKnownPredicate(Pred, LHS: Cond.BoundSCEV, RHS: MaxSCEV)) {
112 const SCEV *BoundPlusOneSCEV =
113 SE.getAddExpr(LHS: Cond.BoundSCEV, RHS: SE.getOne(Ty: BoundSCEVIntType));
114 Cond.BoundSCEV = BoundPlusOneSCEV;
115 Cond.Pred = Pred;
116 return true;
117 }
118 }
119
120 // ToDo: Support ICMP_NE/EQ.
121
122 return false;
123}
124
125static bool hasProcessableCondition(const Loop &L, ScalarEvolution &SE,
126 ICmpInst *ICmp, ConditionInfo &Cond,
127 bool IsExitCond) {
128 analyzeICmp(SE, ICmp, Cond, L);
129
130 // The BoundSCEV should be evaluated at loop entry.
131 if (!SE.isAvailableAtLoopEntry(S: Cond.BoundSCEV, L: &L))
132 return false;
133
134 // Allowed AddRec as induction variable.
135 if (!Cond.AddRecSCEV)
136 return false;
137
138 if (!Cond.AddRecSCEV->isAffine())
139 return false;
140
141 const SCEV *StepRecSCEV = Cond.AddRecSCEV->getStepRecurrence(SE);
142 // Allowed constant step.
143 if (!isa<SCEVConstant>(Val: StepRecSCEV))
144 return false;
145
146 ConstantInt *StepCI = cast<SCEVConstant>(Val: StepRecSCEV)->getValue();
147 // Allowed positive step for now.
148 // TODO: Support negative step.
149 if (StepCI->isNegative() || StepCI->isZero())
150 return false;
151
152 // Calculate upper bound.
153 if (!calculateUpperBound(L, SE, Cond, IsExitCond))
154 return false;
155
156 return true;
157}
158
159static bool isProcessableCondBI(const ScalarEvolution &SE,
160 const BranchInst *BI) {
161 BasicBlock *TrueSucc = nullptr;
162 BasicBlock *FalseSucc = nullptr;
163 Value *LHS, *RHS;
164 if (!match(V: BI, P: m_Br(C: m_ICmp(L: m_Value(V&: LHS), R: m_Value(V&: RHS)),
165 T: m_BasicBlock(V&: TrueSucc), F: m_BasicBlock(V&: FalseSucc))))
166 return false;
167
168 if (!SE.isSCEVable(Ty: LHS->getType()))
169 return false;
170 assert(SE.isSCEVable(RHS->getType()) && "Expected RHS's type is SCEVable");
171
172 if (TrueSucc == FalseSucc)
173 return false;
174
175 return true;
176}
177
178static bool canSplitLoopBound(const Loop &L, const DominatorTree &DT,
179 ScalarEvolution &SE, ConditionInfo &Cond) {
180 // Skip function with optsize.
181 if (L.getHeader()->getParent()->hasOptSize())
182 return false;
183
184 // Split only innermost loop.
185 if (!L.isInnermost())
186 return false;
187
188 // Check loop is in simplified form.
189 if (!L.isLoopSimplifyForm())
190 return false;
191
192 // Check loop is in LCSSA form.
193 if (!L.isLCSSAForm(DT))
194 return false;
195
196 // Skip loop that cannot be cloned.
197 if (!L.isSafeToClone())
198 return false;
199
200 BasicBlock *ExitingBB = L.getExitingBlock();
201 // Assumed only one exiting block.
202 if (!ExitingBB)
203 return false;
204
205 BranchInst *ExitingBI = dyn_cast<BranchInst>(Val: ExitingBB->getTerminator());
206 if (!ExitingBI)
207 return false;
208
209 // Allowed only conditional branch with ICmp.
210 if (!isProcessableCondBI(SE, BI: ExitingBI))
211 return false;
212
213 // Check the condition is processable.
214 ICmpInst *ICmp = cast<ICmpInst>(Val: ExitingBI->getCondition());
215 if (!hasProcessableCondition(L, SE, ICmp, Cond, /*IsExitCond*/ true))
216 return false;
217
218 Cond.BI = ExitingBI;
219 return true;
220}
221
222static bool isProfitableToTransform(const Loop &L, const BranchInst *BI) {
223 // If the conditional branch splits a loop into two halves, we could
224 // generally say it is profitable.
225 //
226 // ToDo: Add more profitable cases here.
227
228 // Check this branch causes diamond CFG.
229 BasicBlock *Succ0 = BI->getSuccessor(i: 0);
230 BasicBlock *Succ1 = BI->getSuccessor(i: 1);
231
232 BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
233 BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
234 if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
235 return false;
236
237 // ToDo: Calculate each successor's instruction cost.
238
239 return true;
240}
241
242static BranchInst *findSplitCandidate(const Loop &L, ScalarEvolution &SE,
243 ConditionInfo &ExitingCond,
244 ConditionInfo &SplitCandidateCond) {
245 for (auto *BB : L.blocks()) {
246 // Skip condition of backedge.
247 if (L.getLoopLatch() == BB)
248 continue;
249
250 auto *BI = dyn_cast<BranchInst>(Val: BB->getTerminator());
251 if (!BI)
252 continue;
253
254 // Check conditional branch with ICmp.
255 if (!isProcessableCondBI(SE, BI))
256 continue;
257
258 // Skip loop invariant condition.
259 if (L.isLoopInvariant(V: BI->getCondition()))
260 continue;
261
262 // Check the condition is processable.
263 ICmpInst *ICmp = cast<ICmpInst>(Val: BI->getCondition());
264 if (!hasProcessableCondition(L, SE, ICmp, Cond&: SplitCandidateCond,
265 /*IsExitCond*/ false))
266 continue;
267
268 if (ExitingCond.BoundSCEV->getType() !=
269 SplitCandidateCond.BoundSCEV->getType())
270 continue;
271
272 // After transformation, we assume the split condition of the pre-loop is
273 // always true. In order to guarantee it, we need to check the start value
274 // of the split cond AddRec satisfies the split condition.
275 if (!SE.isLoopEntryGuardedByCond(L: &L, Pred: SplitCandidateCond.Pred,
276 LHS: SplitCandidateCond.AddRecSCEV->getStart(),
277 RHS: SplitCandidateCond.BoundSCEV))
278 continue;
279
280 SplitCandidateCond.BI = BI;
281 return BI;
282 }
283
284 return nullptr;
285}
286
287static bool splitLoopBound(Loop &L, DominatorTree &DT, LoopInfo &LI,
288 ScalarEvolution &SE, LPMUpdater &U) {
289 ConditionInfo SplitCandidateCond;
290 ConditionInfo ExitingCond;
291
292 // Check we can split this loop's bound.
293 if (!canSplitLoopBound(L, DT, SE, Cond&: ExitingCond))
294 return false;
295
296 if (!findSplitCandidate(L, SE, ExitingCond, SplitCandidateCond))
297 return false;
298
299 if (!isProfitableToTransform(L, BI: SplitCandidateCond.BI))
300 return false;
301
302 // Now, we have a split candidate. Let's build a form as below.
303 // +--------------------+
304 // | preheader |
305 // | set up newbound |
306 // +--------------------+
307 // | /----------------\
308 // +--------v----v------+ |
309 // | header |---\ |
310 // | with true condition| | |
311 // +--------------------+ | |
312 // | | |
313 // +--------v-----------+ | |
314 // | if.then.BB | | |
315 // +--------------------+ | |
316 // | | |
317 // +--------v-----------<---/ |
318 // | latch >----------/
319 // | with newbound |
320 // +--------------------+
321 // |
322 // +--------v-----------+
323 // | preheader2 |--------------\
324 // | if (AddRec i != | |
325 // | org bound) | |
326 // +--------------------+ |
327 // | /----------------\ |
328 // +--------v----v------+ | |
329 // | header2 |---\ | |
330 // | conditional branch | | | |
331 // |with false condition| | | |
332 // +--------------------+ | | |
333 // | | | |
334 // +--------v-----------+ | | |
335 // | if.then.BB2 | | | |
336 // +--------------------+ | | |
337 // | | | |
338 // +--------v-----------<---/ | |
339 // | latch2 >----------/ |
340 // | with org bound | |
341 // +--------v-----------+ |
342 // | |
343 // | +---------------+ |
344 // +--> exit <-------/
345 // +---------------+
346
347 // Let's create post loop.
348 SmallVector<BasicBlock *, 8> PostLoopBlocks;
349 Loop *PostLoop;
350 ValueToValueMapTy VMap;
351 BasicBlock *PreHeader = L.getLoopPreheader();
352 BasicBlock *SplitLoopPH = SplitEdge(From: PreHeader, To: L.getHeader(), DT: &DT, LI: &LI);
353 PostLoop = cloneLoopWithPreheader(Before: L.getExitBlock(), LoopDomBB: SplitLoopPH, OrigLoop: &L, VMap,
354 NameSuffix: ".split", LI: &LI, DT: &DT, Blocks&: PostLoopBlocks);
355 remapInstructionsInBlocks(Blocks: PostLoopBlocks, VMap);
356
357 BasicBlock *PostLoopPreHeader = PostLoop->getLoopPreheader();
358 IRBuilder<> Builder(&PostLoopPreHeader->front());
359
360 // Update phi nodes in header of post-loop.
361 bool isExitingLatch =
362 (L.getExitingBlock() == L.getLoopLatch()) ? true : false;
363 Value *ExitingCondLCSSAPhi = nullptr;
364 for (PHINode &PN : L.getHeader()->phis()) {
365 // Create LCSSA phi node in preheader of post-loop.
366 PHINode *LCSSAPhi =
367 Builder.CreatePHI(Ty: PN.getType(), NumReservedValues: 1, Name: PN.getName() + ".lcssa");
368 LCSSAPhi->setDebugLoc(PN.getDebugLoc());
369 // If the exiting block is loop latch, the phi does not have the update at
370 // last iteration. In this case, update lcssa phi with value from backedge.
371 LCSSAPhi->addIncoming(
372 V: isExitingLatch ? PN.getIncomingValueForBlock(BB: L.getLoopLatch()) : &PN,
373 BB: L.getExitingBlock());
374
375 // Update the start value of phi node in post-loop with the LCSSA phi node.
376 PHINode *PostLoopPN = cast<PHINode>(Val&: VMap[&PN]);
377 PostLoopPN->setIncomingValueForBlock(BB: PostLoopPreHeader, V: LCSSAPhi);
378
379 // Find PHI with exiting condition from pre-loop. The PHI should be
380 // SCEVAddRecExpr and have same incoming value from backedge with
381 // ExitingCond.
382 if (!SE.isSCEVable(Ty: PN.getType()))
383 continue;
384
385 const SCEVAddRecExpr *PhiSCEV = dyn_cast<SCEVAddRecExpr>(Val: SE.getSCEV(V: &PN));
386 if (PhiSCEV && ExitingCond.NonPHIAddRecValue ==
387 PN.getIncomingValueForBlock(BB: L.getLoopLatch()))
388 ExitingCondLCSSAPhi = LCSSAPhi;
389 }
390
391 // Add conditional branch to check we can skip post-loop in its preheader.
392 Instruction *OrigBI = PostLoopPreHeader->getTerminator();
393 ICmpInst::Predicate Pred = ICmpInst::ICMP_NE;
394 Value *Cond =
395 Builder.CreateICmp(P: Pred, LHS: ExitingCondLCSSAPhi, RHS: ExitingCond.BoundValue);
396 Builder.CreateCondBr(Cond, True: PostLoop->getHeader(), False: PostLoop->getExitBlock());
397 OrigBI->eraseFromParent();
398
399 // Create new loop bound and add it into preheader of pre-loop.
400 const SCEV *NewBoundSCEV = ExitingCond.BoundSCEV;
401 const SCEV *SplitBoundSCEV = SplitCandidateCond.BoundSCEV;
402 NewBoundSCEV = ICmpInst::isSigned(predicate: ExitingCond.Pred)
403 ? SE.getSMinExpr(LHS: NewBoundSCEV, RHS: SplitBoundSCEV)
404 : SE.getUMinExpr(LHS: NewBoundSCEV, RHS: SplitBoundSCEV);
405
406 SCEVExpander Expander(
407 SE, L.getHeader()->getDataLayout(), "split");
408 Instruction *InsertPt = SplitLoopPH->getTerminator();
409 Value *NewBoundValue =
410 Expander.expandCodeFor(SH: NewBoundSCEV, Ty: NewBoundSCEV->getType(), I: InsertPt);
411 NewBoundValue->setName("new.bound");
412
413 // Replace exiting bound value of pre-loop NewBound.
414 ExitingCond.ICmp->setOperand(i_nocapture: 1, Val_nocapture: NewBoundValue);
415
416 // Replace SplitCandidateCond.BI's condition of pre-loop by True.
417 LLVMContext &Context = PreHeader->getContext();
418 SplitCandidateCond.BI->setCondition(ConstantInt::getTrue(Context));
419
420 // Replace cloned SplitCandidateCond.BI's condition in post-loop by False.
421 BranchInst *ClonedSplitCandidateBI =
422 cast<BranchInst>(Val&: VMap[SplitCandidateCond.BI]);
423 ClonedSplitCandidateBI->setCondition(ConstantInt::getFalse(Context));
424
425 // Replace exit branch target of pre-loop by post-loop's preheader.
426 if (L.getExitBlock() == ExitingCond.BI->getSuccessor(i: 0))
427 ExitingCond.BI->setSuccessor(idx: 0, NewSucc: PostLoopPreHeader);
428 else
429 ExitingCond.BI->setSuccessor(idx: 1, NewSucc: PostLoopPreHeader);
430
431 // Update phi node in exit block of post-loop.
432 Builder.SetInsertPoint(TheBB: PostLoopPreHeader, IP: PostLoopPreHeader->begin());
433 for (PHINode &PN : PostLoop->getExitBlock()->phis()) {
434 for (auto i : seq<int>(Begin: 0, End: PN.getNumOperands())) {
435 // Check incoming block is pre-loop's exiting block.
436 if (PN.getIncomingBlock(i) == L.getExitingBlock()) {
437 Value *IncomingValue = PN.getIncomingValue(i);
438
439 // Create LCSSA phi node for incoming value.
440 PHINode *LCSSAPhi =
441 Builder.CreatePHI(Ty: PN.getType(), NumReservedValues: 1, Name: PN.getName() + ".lcssa");
442 LCSSAPhi->setDebugLoc(PN.getDebugLoc());
443 LCSSAPhi->addIncoming(V: IncomingValue, BB: PN.getIncomingBlock(i));
444
445 // Replace pre-loop's exiting block by post-loop's preheader.
446 PN.setIncomingBlock(i, BB: PostLoopPreHeader);
447 // Replace incoming value by LCSSAPhi.
448 PN.setIncomingValue(i, V: LCSSAPhi);
449 // Add a new incoming value with post-loop's exiting block.
450 PN.addIncoming(V: VMap[IncomingValue], BB: PostLoop->getExitingBlock());
451 }
452 }
453 }
454
455 // Update dominator tree.
456 DT.changeImmediateDominator(BB: PostLoopPreHeader, NewBB: L.getExitingBlock());
457 DT.changeImmediateDominator(BB: PostLoop->getExitBlock(), NewBB: PostLoopPreHeader);
458
459 // Invalidate cached SE information.
460 SE.forgetLoop(L: &L);
461
462 // Canonicalize loops.
463 simplifyLoop(L: &L, DT: &DT, LI: &LI, SE: &SE, AC: nullptr, MSSAU: nullptr, PreserveLCSSA: true);
464 simplifyLoop(L: PostLoop, DT: &DT, LI: &LI, SE: &SE, AC: nullptr, MSSAU: nullptr, PreserveLCSSA: true);
465
466 // Add new post-loop to loop pass manager.
467 U.addSiblingLoops(NewSibLoops: PostLoop);
468
469 return true;
470}
471
472PreservedAnalyses LoopBoundSplitPass::run(Loop &L, LoopAnalysisManager &AM,
473 LoopStandardAnalysisResults &AR,
474 LPMUpdater &U) {
475 Function &F = *L.getHeader()->getParent();
476 (void)F;
477
478 LLVM_DEBUG(dbgs() << "Spliting bound of loop in " << F.getName() << ": " << L
479 << "\n");
480
481 if (!splitLoopBound(L, DT&: AR.DT, LI&: AR.LI, SE&: AR.SE, U))
482 return PreservedAnalyses::all();
483
484 assert(AR.DT.verify(DominatorTree::VerificationLevel::Fast));
485 AR.LI.verify(DomTree: AR.DT);
486
487 return getLoopPassPreservedAnalyses();
488}
489
490} // end namespace llvm
491