1#include "llvm/Transforms/Utils/LoopConstrainer.h"
2#include "llvm/Analysis/LoopInfo.h"
3#include "llvm/Analysis/ScalarEvolution.h"
4#include "llvm/Analysis/ScalarEvolutionExpressions.h"
5#include "llvm/IR/Dominators.h"
6#include "llvm/Transforms/Utils/Cloning.h"
7#include "llvm/Transforms/Utils/LoopSimplify.h"
8#include "llvm/Transforms/Utils/LoopUtils.h"
9#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
10
11using namespace llvm;
12
// Metadata kind attached to the latch terminator of every loop produced by
// LoopConstrainer::cloneLoop. parseLoopStructure refuses loops carrying it, so
// a clone is never constrained again. Both the pointer and the characters are
// const so the tag cannot be accidentally re-pointed at runtime.
static const char *const ClonedLoopTag = "loop_constrainer.loop.clone";

#define DEBUG_TYPE "loop-constrainer"
16
17/// Given a loop with an deccreasing induction variable, is it possible to
18/// safely calculate the bounds of a new loop using the given Predicate.
19static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
20 const SCEV *Step, ICmpInst::Predicate Pred,
21 unsigned LatchBrExitIdx, Loop *L,
22 ScalarEvolution &SE) {
23 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
24 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
25 return false;
26
27 if (!SE.isAvailableAtLoopEntry(S: BoundSCEV, L))
28 return false;
29
30 assert(SE.isKnownNegative(Step) && "expecting negative step");
31
32 LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
33 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
34 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
35 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
36 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
37 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
38
39 bool IsSigned = ICmpInst::isSigned(predicate: Pred);
40 // The predicate that we need to check that the induction variable lies
41 // within bounds.
42 ICmpInst::Predicate BoundPred =
43 IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
44
45 auto StartLG = SE.applyLoopGuards(Expr: Start, L);
46 auto BoundLG = SE.applyLoopGuards(Expr: BoundSCEV, L);
47
48 if (LatchBrExitIdx == 1)
49 return SE.isLoopEntryGuardedByCond(L, Pred: BoundPred, LHS: StartLG, RHS: BoundLG);
50
51 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
52
53 const SCEV *StepPlusOne = SE.getAddExpr(LHS: Step, RHS: SE.getOne(Ty: Step->getType()));
54 unsigned BitWidth = cast<IntegerType>(Val: BoundSCEV->getType())->getBitWidth();
55 APInt Min = IsSigned ? APInt::getSignedMinValue(numBits: BitWidth)
56 : APInt::getMinValue(numBits: BitWidth);
57 const SCEV *Limit = SE.getMinusSCEV(LHS: SE.getConstant(Val: Min), RHS: StepPlusOne);
58
59 const SCEV *MinusOne =
60 SE.getMinusSCEV(LHS: BoundLG, RHS: SE.getOne(Ty: BoundLG->getType()));
61
62 return SE.isLoopEntryGuardedByCond(L, Pred: BoundPred, LHS: StartLG, RHS: MinusOne) &&
63 SE.isLoopEntryGuardedByCond(L, Pred: BoundPred, LHS: BoundLG, RHS: Limit);
64}
65
66/// Given a loop with an increasing induction variable, is it possible to
67/// safely calculate the bounds of a new loop using the given Predicate.
68static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
69 const SCEV *Step, ICmpInst::Predicate Pred,
70 unsigned LatchBrExitIdx, Loop *L,
71 ScalarEvolution &SE) {
72 if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
73 Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
74 return false;
75
76 if (!SE.isAvailableAtLoopEntry(S: BoundSCEV, L))
77 return false;
78
79 LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
80 LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
81 LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
82 LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
83 LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
84 LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");
85
86 bool IsSigned = ICmpInst::isSigned(predicate: Pred);
87 // The predicate that we need to check that the induction variable lies
88 // within bounds.
89 ICmpInst::Predicate BoundPred =
90 IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
91
92 auto StartLG = SE.applyLoopGuards(Expr: Start, L);
93 auto BoundLG = SE.applyLoopGuards(Expr: BoundSCEV, L);
94
95 if (LatchBrExitIdx == 1)
96 return SE.isLoopEntryGuardedByCond(L, Pred: BoundPred, LHS: StartLG, RHS: BoundLG);
97
98 assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
99
100 const SCEV *StepMinusOne = SE.getMinusSCEV(LHS: Step, RHS: SE.getOne(Ty: Step->getType()));
101 unsigned BitWidth = cast<IntegerType>(Val: BoundSCEV->getType())->getBitWidth();
102 APInt Max = IsSigned ? APInt::getSignedMaxValue(numBits: BitWidth)
103 : APInt::getMaxValue(numBits: BitWidth);
104 const SCEV *Limit = SE.getMinusSCEV(LHS: SE.getConstant(Val: Max), RHS: StepMinusOne);
105
106 return (SE.isLoopEntryGuardedByCond(L, Pred: BoundPred, LHS: StartLG,
107 RHS: SE.getAddExpr(LHS: BoundLG, RHS: Step)) &&
108 SE.isLoopEntryGuardedByCond(L, Pred: BoundPred, LHS: BoundLG, RHS: Limit));
109}
110
111/// Returns estimate for max latch taken count of the loop of the narrowest
112/// available type. If the latch block has such estimate, it is returned.
113/// Otherwise, we use max exit count of whole loop (that is potentially of wider
114/// type than latch check itself), which is still better than no estimate.
115static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
116 const Loop &L) {
117 const SCEV *FromBlock =
118 SE.getExitCount(L: &L, ExitingBlock: L.getLoopLatch(), Kind: ScalarEvolution::SymbolicMaximum);
119 if (isa<SCEVCouldNotCompute>(Val: FromBlock))
120 return SE.getSymbolicMaxBackedgeTakenCount(L: &L);
121 return FromBlock;
122}
123
/// Try to recognize \p L as a loop whose latch is controlled by an affine,
/// constant-step induction variable compared against a loop-invariant bound,
/// and describe it as a LoopStructure.
///
/// \param SE                      Scalar evolution for the enclosing function.
/// \param L                       The candidate loop.
/// \param AllowUnsignedLatchCond  If false, loops whose (canonicalized) latch
///                                predicate is unsigned are rejected.
/// \param FailureReason           Out: on failure, set to a static string
///                                naming the first check that failed. On
///                                success, set to nullptr.
/// \returns The parsed structure, or std::nullopt if \p L is not supported.
///
/// Side effect on success only: the (possibly adjusted) bound and the initial
/// induction-variable value are expanded into instructions before the
/// preheader's terminator. Every failure path returns before any expansion.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  bool AllowUnsignedLatchCond,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  // LoopConstrainer::cloneLoop tags the latch terminator of every clone with
  // ClonedLoopTag. Never process such a clone again.
  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  // The latch's branch is the exit test analyzed below, so the latch has to
  // be an exiting block.
  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // Successor index of the latch branch that leaves the loop. If successor 0
  // is the header (the backedge), the exit is successor 1, and vice versa.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
  if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }
  assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  // Both icmp operands have the same type, so reading it before the possible
  // swap below is fine.
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // Returns true if AR is known not to wrap in the signed sense. That holds
  // if AR already carries the nsw flag, or if sign-extending AR to twice its
  // width yields an add recurrence whose start and step equal the
  // sign-extended start and step of AR.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  // eq/ne latch conditions are only rewritten into relational ones below;
  // that rewrite needs the IV not to wrap.
  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  // Direction is decided by the sign of the constant step.
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // The compared value is the *next* IV value (start of the recurrence). The
  // IV value before the first iteration is therefore Start - Step.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  // When non-null, the bound is re-materialized in the preheader from this
  // SCEV instead of reusing RightValue directly.
  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within loop (but still being loop invariant),
  // regenerate it as preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    // An increasing IV is accepted in two shapes: keep looping while
    // `next < bound` (exit on successor 1), or leave once `next > bound`
    // (exit on successor 0).
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    // In the `next > bound` shape the loop continues while next <= bound,
    // i.e. next < bound + 1, so the recorded exit value becomes bound + 1.
    // NOTE(review): presumably this is so LoopExitAt can always be treated
    // as an exclusive upper bound by consumers; confirm.
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    // A decreasing IV is accepted in two shapes: keep looping while
    // `next > bound` (exit on successor 1), or leave once `next < bound`
    // (exit on successor 0).
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    // Mirror of the increasing case: in the `next < bound` shape the loop
    // continues while next >= bound, i.e. next > bound - 1.
    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(!L.contains(LatchExit) && "expected an exit block!");
  // Materialize the bound (if adjusted or defined inside the loop) and the
  // initial IV value right before the preheader's terminator.
  const DataLayout &DL = Preheader->getDataLayout();
  SCEVExpander Expander(SE, DL, "loop-constrainer");
  Instruction *Ins = Preheader->getTerminator();

  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;
  Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());

  FailureReason = nullptr;

  return Result;
}
431
432// Add metadata to the loop L to disable loop optimizations. Callers need to
433// confirm that optimizing loop L is not beneficial.
434static void DisableAllLoopOptsOnLoop(Loop &L) {
435 // We do not care about any existing loopID related metadata for L, since we
436 // are setting all loop metadata to false.
437 LLVMContext &Context = L.getHeader()->getContext();
438 // Reserve first location for self reference to the LoopID metadata node.
439 MDNode *Dummy = MDNode::get(Context, MDs: {});
440 MDNode *DisableUnroll = MDNode::get(
441 Context, MDs: {MDString::get(Context, Str: "llvm.loop.unroll.disable")});
442 Metadata *FalseVal =
443 ConstantAsMetadata::get(C: ConstantInt::get(Ty: Type::getInt1Ty(C&: Context), V: 0));
444 MDNode *DisableVectorize = MDNode::get(
445 Context,
446 MDs: {MDString::get(Context, Str: "llvm.loop.vectorize.enable"), FalseVal});
447 MDNode *DisableLICMVersioning = MDNode::get(
448 Context, MDs: {MDString::get(Context, Str: "llvm.loop.licm_versioning.disable")});
449 MDNode *DisableDistribution = MDNode::get(
450 Context,
451 MDs: {MDString::get(Context, Str: "llvm.loop.distribute.enable"), FalseVal});
452 MDNode *NewLoopID =
453 MDNode::get(Context, MDs: {Dummy, DisableUnroll, DisableVectorize,
454 DisableLICMVersioning, DisableDistribution});
455 // Set operand 0 to refer to the loop id itself.
456 NewLoopID->replaceOperandWith(I: 0, New: NewLoopID);
457 L.setLoopID(NewLoopID);
458}
459
/// Capture everything needed to constrain loop \p L.
///
/// \param L             The loop to split (kept as OriginalLoop).
/// \param LI, SE, DT    Analyses that are queried and updated during run().
/// \param LPMAddNewLoop Callback invoked for every loop created by
///                      createClonedLoopStructure, with a flag saying whether
///                      it is a subloop.
/// \param LS            Parsed structure of \p L. It is copied, because run()
///                      rewrites the stored MainLoopStructure in place.
/// \param T             Integer type in which the new exit comparisons and
///                      limits are computed (stored as RangeTy).
/// \param SR            Optional low/high limits of the main loop's range.
LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
                                 function_ref<void(Loop *, bool)> LPMAddNewLoop,
                                 const LoopStructure &LS, ScalarEvolution &SE,
                                 DominatorTree &DT, Type *T, SubRanges SR)
    : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
      DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
      MainLoopStructure(LS), SR(SR) {}
467
/// Clone every block of OriginalLoop into F, giving each clone the suffix
/// ".<Tag>".
///
/// On return:
///  - Result.Blocks holds the clones in the same order as
///    OriginalLoop.getBlocks().
///  - Result.Map maps original blocks and instructions to their clones, and
///    the cloned instructions are remapped to use cloned values.
///  - The cloned latch terminator is tagged with ClonedLoopTag, so
///    parseLoopStructure will refuse to process the clone.
///  - Result.Structure is MainLoopStructure with every value mapped through
///    Result.Map (values defined outside the loop are kept as-is), and its Tag
///    set to \p Tag.
///  - PHIs in the original exit blocks gained an incoming entry for each new
///    cloned exiting edge, and ScalarEvolution forgot those PHIs.
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
                                const char *Tag) const {
  for (BasicBlock *BB : OriginalLoop.getBlocks()) {
    BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
    Result.Blocks.push_back(Clone);
    Result.Map[BB] = Clone;
  }

  // Map V to its clone if it was defined inside the loop, else return V.
  auto GetClonedValue = [&Result](Value *V) {
    assert(V && "null values not in domain!");
    auto It = Result.Map.find(V);
    if (It == Result.Map.end())
      return V;
    return static_cast<Value *>(It->second);
  };

  auto *ClonedLatch =
      cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
  ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
                                            MDNode::get(Ctx, {}));

  Result.Structure = MainLoopStructure.map(GetClonedValue);
  Result.Structure.Tag = Tag;

  // Blocks and getBlocks() are index-aligned (see the first loop above).
  for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
    BasicBlock *ClonedBB = Result.Blocks[i];
    BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];

    assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");

    for (Instruction &I : *ClonedBB)
      RemapInstruction(&I, Result.Map,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);

    // Exit blocks will now have one more predecessor and their PHI nodes need
    // to be edited to reflect that. No phi nodes need to be introduced because
    // the loop is in LCSSA.

    // successors() yields one entry per edge, so an exit reached by several
    // edges from OriginalBB gets one PHI entry per cloned edge, as the IR
    // requires.
    for (auto *SBB : successors(OriginalBB)) {
      if (OriginalLoop.contains(SBB))
        continue; // not an exit block

      for (PHINode &PN : SBB->phis()) {
        Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
        PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
        SE.forgetValue(&PN);
      }
    }
  }
}
518
/// Rewrite the loop described by \p LS so that it only runs while its
/// induction variable has not yet reached \p ExitSubloopAt.
///
/// Two new blocks are created after LS.Latch:
///  - `<Tag>.exit.selector`: the latch's exit edge now goes here. It branches
///    to `.pseudo.exit` if the IV still satisfies the original bound
///    LS.LoopExitAt (more iterations remain), and to LS.LatchExit otherwise.
///  - `<Tag>.pseudo.exit`: branches to \p ContinuationBlock. It holds PHIs
///    carrying the latest value of each header PHI and of the IV.
///
/// The branch at the end of \p Preheader (which must be a BranchInst) is
/// replaced so that the loop is skipped entirely, going straight to
/// `.pseudo.exit`, when the start value is already past \p ExitSubloopAt.
/// The IV, its start and LS.LoopExitAt are sign- or zero-extended to RangeTy
/// according to LS.IsSignedPredicate before comparing.
///
/// \returns The new blocks, one `.copy` PHI per header PHI (in header order),
///          and the `indvar.end` PHI.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  // Widen V to RangeTy (no-op if it already has that type). Instructions are
  // inserted at B's current insertion point.
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = nullptr;
  // "Still inside the range": IV < limit for increasing loops, IV > limit for
  // decreasing ones.
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // Redirect the latch's exit edge to the selector and make the latch test
  // the IV against the sub-loop limit instead of the original bound.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  // If the exit is successor 0, the branch condition must be true to exit,
  // so the "take backedge" condition is negated.
  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation->getIterator());

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  // IV value on reaching the pseudo exit: the start value if the loop was
  // skipped, or the (widened) IV value from the latch otherwise.
  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation->getIterator());
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}
672
673void LoopConstrainer::rewriteIncomingValuesForPHIs(
674 LoopStructure &LS, BasicBlock *ContinuationBlock,
675 const LoopConstrainer::RewrittenRangeInfo &RRI) const {
676 unsigned PHIIndex = 0;
677 for (PHINode &PN : LS.Header->phis())
678 PN.setIncomingValueForBlock(BB: ContinuationBlock,
679 V: RRI.PHIValuesAtPseudoExit[PHIIndex++]);
680
681 LS.IndVarStart = RRI.IndVarEnd;
682}
683
684BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
685 BasicBlock *OldPreheader,
686 const char *Tag) const {
687 BasicBlock *Preheader = BasicBlock::Create(Context&: Ctx, Name: Tag, Parent: &F, InsertBefore: LS.Header);
688 BranchInst::Create(IfTrue: LS.Header, InsertBefore: Preheader);
689
690 LS.Header->replacePhiUsesWith(Old: OldPreheader, New: Preheader);
691
692 return Preheader;
693}
694
695void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
696 Loop *ParentLoop = OriginalLoop.getParentLoop();
697 if (!ParentLoop)
698 return;
699
700 for (BasicBlock *BB : BBs)
701 ParentLoop->addBasicBlockToLoop(NewBB: BB, LI);
702}
703
704Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
705 ValueToValueMapTy &VM,
706 bool IsSubloop) {
707 Loop &New = *LI.AllocateLoop();
708 if (Parent)
709 Parent->addChildLoop(NewChild: &New);
710 else
711 LI.addTopLevelLoop(New: &New);
712 LPMAddNewLoop(&New, IsSubloop);
713
714 // Add all of the blocks in Original to the new loop.
715 for (auto *BB : Original->blocks())
716 if (LI.getLoopFor(BB) == Original)
717 New.addBasicBlockToLoop(NewBB: cast<BasicBlock>(Val&: VM[BB]), LI);
718
719 // Add all of the subloops to the new loop.
720 for (Loop *SubLoop : *Original)
721 createClonedLoopStructure(Original: SubLoop, Parent: &New, VM, /* IsSubloop */ true);
722
723 return &New;
724}
725
/// Split OriginalLoop into up to three consecutive loops: an optional
/// pre-loop, the main loop (the original loop itself), and an optional
/// post-loop.
///
/// For an increasing IV, SR.LowLimit requests a pre-loop and SR.HighLimit a
/// post-loop. For a decreasing IV the roles are swapped. Pre/post loops are
/// clones that get all loop optimizations disabled. Afterwards every loop is
/// put back into LCSSA and LoopSimplify form, and the DominatorTree is
/// recomputed for F.
///
/// \returns false if an exit limit cannot be computed without overflow or
///          cannot be safely expanded in the preheader; true once the
///          transformation is done.
///
/// NOTE(review): if the post-loop checks fail after the pre-loop limit has
/// already been expanded, the expanded instructions are left in the
/// preheader (presumably dead and removed later). Confirm this is acceptable.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
  assert(Preheader != nullptr && "precondition!");

  OriginalPreheader = Preheader;
  MainLoopPreheader = Preheader;
  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy = cast<IntegerType>(RangeTy);

  SCEVExpander Expander(SE, F.getDataLayout(), "loop-constrainer");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    // For a decreasing IV the limit is turned into an exclusive bound by
    // subtracting one, which is only valid if that cannot wrap.
    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "preloop exit limit.  HighLimit = "
                        << *(*SR.HighLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " preloop exit limit " << *ExitPreLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    // Same exclusive-bound adjustment as for the pre-loop above.
    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "mainloop exit limit.  LowLimit = "
                        << *(*SR.LowLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " main loop exit limit " << *ExitMainLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  // Chain: original preheader -> pre-loop -> new "mainloop" preheader -> main
  // loop. The pre-loop's range is cut at ExitPreLoopAt.
  if (NeedsPreLoop) {
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  // Chain: main loop -> new "postloop" preheader -> post-loop. The main
  // loop's range is cut at ExitMainLoopAt.
  if (NeedsPostLoop) {
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));

  DT.recalculate(F);

  // We need to first add all the pre and post loop blocks into the loop
  // structures (as part of createClonedLoopStructure), and then update the
  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
  // LI when LoopSimplifyForm is generated.
  Loop *PreL = nullptr, *PostL = nullptr;
  if (!PreLoop.Blocks.empty()) {
    PreL = createClonedLoopStructure(&OriginalLoop,
                                     OriginalLoop.getParentLoop(), PreLoop.Map,
                                     /* IsSubLoop */ false);
  }

  if (!PostLoop.Blocks.empty()) {
    PostL =
        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
                                  PostLoop.Map, /* IsSubLoop */ false);
  }

  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
  auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
    formLCSSARecursively(*L, DT, &LI, &SE);
    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
    // Pre/post loops are slow paths, we do not need to perform any loop
    // optimizations on them.
    if (!IsOriginalLoop)
      DisableAllLoopOptsOnLoop(*L);
  };
  if (PreL)
    CanonicalizeLoop(PreL, false);
  if (PostL)
    CanonicalizeLoop(PostL, false);
  CanonicalizeLoop(&OriginalLoop, true);

  /// At this point:
  /// - We've broken a "main loop" out of the loop in a way that the "main loop"
  /// runs with the induction variable in a subset of [Begin, End).
  /// - There is no overflow when computing "main loop" exit limit.
  /// - Max latch taken count of the loop is limited.
  /// It guarantees that induction variable will not overflow iterating in the
  /// "main loop".
  // NOTE(review): isa<OverflowingBinaryOperator> also matches constant
  // expressions, for which cast<BinaryOperator> would assert. This presumably
  // cannot happen because IndVarBase is an add recurrence of this loop;
  // confirm.
  if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
    if (IsSignedPredicate)
      cast<BinaryOperator>(MainLoopStructure.IndVarBase)
          ->setHasNoSignedWrap(true);
  /// TODO: support unsigned predicate.
  /// To add NUW flag we need to prove that both operands of BO are
  /// non-negative. E.g:
  /// ...
  /// %iv.next = add nsw i32 %iv, -1
  /// %cmp = icmp ult i32 %iv.next, %n
  /// br i1 %cmp, label %loopexit, label %loop
  ///
  /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
  /// overflow, therefore NUW flag is not legal here.

  return true;
}
911