1//===-- UnrollLoop.cpp - Loop unrolling utilities -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements some loop unrolling utilities. It does not define any
10// actual pass or policy, but provides a single function to perform loop
11// unrolling.
12//
13// The process of unrolling can produce extraneous basic blocks linked with
14// unconditional branches. This will be corrected in the future.
15//
16//===----------------------------------------------------------------------===//
17
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/ADT/DenseMap.h"
20#include "llvm/ADT/MapVector.h"
21#include "llvm/ADT/STLExtras.h"
22#include "llvm/ADT/ScopedHashTable.h"
23#include "llvm/ADT/SetVector.h"
24#include "llvm/ADT/SmallVector.h"
25#include "llvm/ADT/Statistic.h"
26#include "llvm/ADT/StringRef.h"
27#include "llvm/ADT/Twine.h"
28#include "llvm/Analysis/AliasAnalysis.h"
29#include "llvm/Analysis/AssumptionCache.h"
30#include "llvm/Analysis/DomTreeUpdater.h"
31#include "llvm/Analysis/InstructionSimplify.h"
32#include "llvm/Analysis/LoopInfo.h"
33#include "llvm/Analysis/LoopIterator.h"
34#include "llvm/Analysis/MemorySSA.h"
35#include "llvm/Analysis/OptimizationRemarkEmitter.h"
36#include "llvm/Analysis/ScalarEvolution.h"
37#include "llvm/IR/BasicBlock.h"
38#include "llvm/IR/CFG.h"
39#include "llvm/IR/Constants.h"
40#include "llvm/IR/DebugInfoMetadata.h"
41#include "llvm/IR/DebugLoc.h"
42#include "llvm/IR/DiagnosticInfo.h"
43#include "llvm/IR/Dominators.h"
44#include "llvm/IR/Function.h"
45#include "llvm/IR/IRBuilder.h"
46#include "llvm/IR/Instruction.h"
47#include "llvm/IR/Instructions.h"
48#include "llvm/IR/IntrinsicInst.h"
49#include "llvm/IR/Metadata.h"
50#include "llvm/IR/PatternMatch.h"
51#include "llvm/IR/Use.h"
52#include "llvm/IR/User.h"
53#include "llvm/IR/ValueHandle.h"
54#include "llvm/IR/ValueMap.h"
55#include "llvm/Support/Casting.h"
56#include "llvm/Support/CommandLine.h"
57#include "llvm/Support/Debug.h"
58#include "llvm/Support/GenericDomTree.h"
59#include "llvm/Support/raw_ostream.h"
60#include "llvm/Transforms/Utils/BasicBlockUtils.h"
61#include "llvm/Transforms/Utils/Cloning.h"
62#include "llvm/Transforms/Utils/Local.h"
63#include "llvm/Transforms/Utils/LoopSimplify.h"
64#include "llvm/Transforms/Utils/LoopUtils.h"
65#include "llvm/Transforms/Utils/SimplifyIndVar.h"
66#include "llvm/Transforms/Utils/UnrollLoop.h"
67#include "llvm/Transforms/Utils/ValueMapper.h"
68#include <assert.h>
69#include <cmath>
70#include <numeric>
71#include <vector>
72
73namespace llvm {
74class DataLayout;
75class Value;
76} // namespace llvm
77
78using namespace llvm;
79
80#define DEBUG_TYPE "loop-unroll"
81
82// TODO: Should these be here or in LoopUnroll?
83STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled");
84STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)");
85STATISTIC(NumUnrolledNotLatch, "Number of loops unrolled without a conditional "
86 "latch (completely or otherwise)");
87
88static cl::opt<bool>
89UnrollRuntimeEpilog("unroll-runtime-epilog", cl::init(Val: false), cl::Hidden,
90 cl::desc("Allow runtime unrolled loops to be unrolled "
91 "with epilog instead of prolog."));
92
93static cl::opt<bool>
94UnrollVerifyDomtree("unroll-verify-domtree", cl::Hidden,
95 cl::desc("Verify domtree after unrolling"),
96#ifdef EXPENSIVE_CHECKS
97 cl::init(true)
98#else
99 cl::init(Val: false)
100#endif
101 );
102
103static cl::opt<bool>
104UnrollVerifyLoopInfo("unroll-verify-loopinfo", cl::Hidden,
105 cl::desc("Verify loopinfo after unrolling"),
106#ifdef EXPENSIVE_CHECKS
107 cl::init(true)
108#else
109 cl::init(Val: false)
110#endif
111 );
112
113static cl::opt<bool> UnrollAddParallelReductions(
114 "unroll-add-parallel-reductions", cl::init(Val: false), cl::Hidden,
115 cl::desc("Allow unrolling to add parallel reduction phis."));
116
117/// Check if unrolling created a situation where we need to insert phi nodes to
118/// preserve LCSSA form.
119/// \param Blocks is a vector of basic blocks representing unrolled loop.
120/// \param L is the outer loop.
121/// It's possible that some of the blocks are in L, and some are not. In this
122/// case, if there is a use is outside L, and definition is inside L, we need to
123/// insert a phi-node, otherwise LCSSA will be broken.
124/// The function is just a helper function for llvm::UnrollLoop that returns
125/// true if this situation occurs, indicating that LCSSA needs to be fixed.
126static bool needToInsertPhisForLCSSA(Loop *L,
127 const std::vector<BasicBlock *> &Blocks,
128 LoopInfo *LI) {
129 for (BasicBlock *BB : Blocks) {
130 if (LI->getLoopFor(BB) == L)
131 continue;
132 for (Instruction &I : *BB) {
133 for (Use &U : I.operands()) {
134 if (const auto *Def = dyn_cast<Instruction>(Val&: U)) {
135 Loop *DefLoop = LI->getLoopFor(BB: Def->getParent());
136 if (!DefLoop)
137 continue;
138 if (DefLoop->contains(L))
139 return true;
140 }
141 }
142 }
143 }
144 return false;
145}
146
147/// Adds ClonedBB to LoopInfo, creates a new loop for ClonedBB if necessary
148/// and adds a mapping from the original loop to the new loop to NewLoops.
149/// Returns nullptr if no new loop was created and a pointer to the
150/// original loop OriginalBB was part of otherwise.
151const Loop* llvm::addClonedBlockToLoopInfo(BasicBlock *OriginalBB,
152 BasicBlock *ClonedBB, LoopInfo *LI,
153 NewLoopsMap &NewLoops) {
154 // Figure out which loop New is in.
155 const Loop *OldLoop = LI->getLoopFor(BB: OriginalBB);
156 assert(OldLoop && "Should (at least) be in the loop being unrolled!");
157
158 Loop *&NewLoop = NewLoops[OldLoop];
159 if (!NewLoop) {
160 // Found a new sub-loop.
161 assert(OriginalBB == OldLoop->getHeader() &&
162 "Header should be first in RPO");
163
164 NewLoop = LI->AllocateLoop();
165 Loop *NewLoopParent = NewLoops.lookup(Val: OldLoop->getParentLoop());
166
167 if (NewLoopParent)
168 NewLoopParent->addChildLoop(NewChild: NewLoop);
169 else
170 LI->addTopLevelLoop(New: NewLoop);
171
172 NewLoop->addBasicBlockToLoop(NewBB: ClonedBB, LI&: *LI);
173 return OldLoop;
174 } else {
175 NewLoop->addBasicBlockToLoop(NewBB: ClonedBB, LI&: *LI);
176 return nullptr;
177 }
178}
179
180/// The function chooses which type of unroll (epilog or prolog) is more
181/// profitabale.
182/// Epilog unroll is more profitable when there is PHI that starts from
183/// constant. In this case epilog will leave PHI start from constant,
184/// but prolog will convert it to non-constant.
185///
186/// loop:
187/// PN = PHI [I, Latch], [CI, PreHeader]
188/// I = foo(PN)
189/// ...
190///
191/// Epilog unroll case.
192/// loop:
193/// PN = PHI [I2, Latch], [CI, PreHeader]
194/// I1 = foo(PN)
195/// I2 = foo(I1)
196/// ...
197/// Prolog unroll case.
198/// NewPN = PHI [PrologI, Prolog], [CI, PreHeader]
199/// loop:
200/// PN = PHI [I2, Latch], [NewPN, PreHeader]
201/// I1 = foo(PN)
202/// I2 = foo(I1)
203/// ...
204///
205static bool isEpilogProfitable(Loop *L) {
206 BasicBlock *PreHeader = L->getLoopPreheader();
207 BasicBlock *Header = L->getHeader();
208 assert(PreHeader && Header);
209 for (const PHINode &PN : Header->phis()) {
210 if (isa<ConstantInt>(Val: PN.getIncomingValueForBlock(BB: PreHeader)))
211 return true;
212 }
213 return false;
214}
215
216struct LoadValue {
217 Instruction *DefI = nullptr;
218 unsigned Generation = 0;
219 LoadValue() = default;
220 LoadValue(Instruction *Inst, unsigned Generation)
221 : DefI(Inst), Generation(Generation) {}
222};
223
224class StackNode {
225 ScopedHashTable<const SCEV *, LoadValue>::ScopeTy LoadScope;
226 unsigned CurrentGeneration;
227 unsigned ChildGeneration;
228 DomTreeNode *Node;
229 DomTreeNode::const_iterator ChildIter;
230 DomTreeNode::const_iterator EndIter;
231 bool Processed = false;
232
233public:
234 StackNode(ScopedHashTable<const SCEV *, LoadValue> &AvailableLoads,
235 unsigned cg, DomTreeNode *N, DomTreeNode::const_iterator Child,
236 DomTreeNode::const_iterator End)
237 : LoadScope(AvailableLoads), CurrentGeneration(cg), ChildGeneration(cg),
238 Node(N), ChildIter(Child), EndIter(End) {}
239 // Accessors.
240 unsigned currentGeneration() const { return CurrentGeneration; }
241 unsigned childGeneration() const { return ChildGeneration; }
242 void childGeneration(unsigned generation) { ChildGeneration = generation; }
243 DomTreeNode *node() { return Node; }
244 DomTreeNode::const_iterator childIter() const { return ChildIter; }
245
246 DomTreeNode *nextChild() {
247 DomTreeNode *Child = *ChildIter;
248 ++ChildIter;
249 return Child;
250 }
251
252 DomTreeNode::const_iterator end() const { return EndIter; }
253 bool isProcessed() const { return Processed; }
254 void process() { Processed = true; }
255};
256
257Value *getMatchingValue(LoadValue LV, LoadInst *LI, unsigned CurrentGeneration,
258 BatchAAResults &BAA,
259 function_ref<MemorySSA *()> GetMSSA) {
260 if (!LV.DefI)
261 return nullptr;
262 if (LV.DefI->getType() != LI->getType())
263 return nullptr;
264 if (LV.Generation != CurrentGeneration) {
265 MemorySSA *MSSA = GetMSSA();
266 if (!MSSA)
267 return nullptr;
268 auto *EarlierMA = MSSA->getMemoryAccess(I: LV.DefI);
269 MemoryAccess *LaterDef =
270 MSSA->getWalker()->getClobberingMemoryAccess(I: LI, AA&: BAA);
271 if (!MSSA->dominates(A: LaterDef, B: EarlierMA))
272 return nullptr;
273 }
274 return LV.DefI;
275}
276
277void loadCSE(Loop *L, DominatorTree &DT, ScalarEvolution &SE, LoopInfo &LI,
278 BatchAAResults &BAA, function_ref<MemorySSA *()> GetMSSA) {
279 ScopedHashTable<const SCEV *, LoadValue> AvailableLoads;
280 SmallVector<std::unique_ptr<StackNode>> NodesToProcess;
281 DomTreeNode *HeaderD = DT.getNode(BB: L->getHeader());
282 NodesToProcess.emplace_back(Args: new StackNode(AvailableLoads, 0, HeaderD,
283 HeaderD->begin(), HeaderD->end()));
284
285 unsigned CurrentGeneration = 0;
286 while (!NodesToProcess.empty()) {
287 StackNode *NodeToProcess = &*NodesToProcess.back();
288
289 CurrentGeneration = NodeToProcess->currentGeneration();
290
291 if (!NodeToProcess->isProcessed()) {
292 // Process the node.
293
294 // If this block has a single predecessor, then the predecessor is the
295 // parent
296 // of the domtree node and all of the live out memory values are still
297 // current in this block. If this block has multiple predecessors, then
298 // they could have invalidated the live-out memory values of our parent
299 // value. For now, just be conservative and invalidate memory if this
300 // block has multiple predecessors.
301 if (!NodeToProcess->node()->getBlock()->getSinglePredecessor())
302 ++CurrentGeneration;
303 for (auto &I : make_early_inc_range(Range&: *NodeToProcess->node()->getBlock())) {
304
305 auto *Load = dyn_cast<LoadInst>(Val: &I);
306 if (!Load || !Load->isSimple()) {
307 if (I.mayWriteToMemory())
308 CurrentGeneration++;
309 continue;
310 }
311
312 const SCEV *PtrSCEV = SE.getSCEV(V: Load->getPointerOperand());
313 LoadValue LV = AvailableLoads.lookup(Key: PtrSCEV);
314 if (Value *M =
315 getMatchingValue(LV, LI: Load, CurrentGeneration, BAA, GetMSSA)) {
316 if (LI.replacementPreservesLCSSAForm(From: Load, To: M)) {
317 Load->replaceAllUsesWith(V: M);
318 Load->eraseFromParent();
319 }
320 } else {
321 AvailableLoads.insert(Key: PtrSCEV, Val: LoadValue(Load, CurrentGeneration));
322 }
323 }
324 NodeToProcess->childGeneration(generation: CurrentGeneration);
325 NodeToProcess->process();
326 } else if (NodeToProcess->childIter() != NodeToProcess->end()) {
327 // Push the next child onto the stack.
328 DomTreeNode *Child = NodeToProcess->nextChild();
329 if (!L->contains(BB: Child->getBlock()))
330 continue;
331 NodesToProcess.emplace_back(
332 Args: new StackNode(AvailableLoads, NodeToProcess->childGeneration(), Child,
333 Child->begin(), Child->end()));
334 } else {
335 // It has been processed, and there are no more children to process,
336 // so delete it and pop it off the stack.
337 NodesToProcess.pop_back();
338 }
339 }
340}
341
342/// Perform some cleanup and simplifications on loops after unrolling. It is
343/// useful to simplify the IV's in the new loop, as well as do a quick
344/// simplify/dce pass of the instructions.
345void llvm::simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI,
346 ScalarEvolution *SE, DominatorTree *DT,
347 AssumptionCache *AC,
348 const TargetTransformInfo *TTI,
349 ArrayRef<BasicBlock *> Blocks,
350 AAResults *AA) {
351 using namespace llvm::PatternMatch;
352
353 // Simplify any new induction variables in the partially unrolled loop.
354 if (SE && SimplifyIVs) {
355 SmallVector<WeakTrackingVH, 16> DeadInsts;
356 simplifyLoopIVs(L, SE, DT, LI, TTI, Dead&: DeadInsts);
357
358 // Aggressively clean up dead instructions that simplifyLoopIVs already
359 // identified. Any remaining should be cleaned up below.
360 while (!DeadInsts.empty()) {
361 Value *V = DeadInsts.pop_back_val();
362 if (Instruction *Inst = dyn_cast_or_null<Instruction>(Val: V))
363 RecursivelyDeleteTriviallyDeadInstructions(V: Inst);
364 }
365
366 if (AA) {
367 std::unique_ptr<MemorySSA> MSSA = nullptr;
368 BatchAAResults BAA(*AA);
369 loadCSE(L, DT&: *DT, SE&: *SE, LI&: *LI, BAA, GetMSSA: [L, AA, DT, &MSSA]() -> MemorySSA * {
370 if (!MSSA)
371 MSSA.reset(p: new MemorySSA(*L, AA, DT));
372 return &*MSSA;
373 });
374 }
375 }
376
377 // At this point, the code is well formed. Perform constprop, instsimplify,
378 // and dce.
379 SmallVector<WeakTrackingVH, 16> DeadInsts;
380 for (BasicBlock *BB : Blocks) {
381 // Remove repeated debug instructions after loop unrolling.
382 if (BB->getParent()->getSubprogram())
383 RemoveRedundantDbgInstrs(BB);
384
385 for (Instruction &Inst : llvm::make_early_inc_range(Range&: *BB)) {
386 if (Value *V = simplifyInstruction(
387 I: &Inst, Q: {BB->getDataLayout(), nullptr, DT, AC}))
388 if (LI->replacementPreservesLCSSAForm(From: &Inst, To: V))
389 Inst.replaceAllUsesWith(V);
390 if (isInstructionTriviallyDead(I: &Inst))
391 DeadInsts.emplace_back(Args: &Inst);
392
393 // Fold ((add X, C1), C2) to (add X, C1+C2). This is very common in
394 // unrolled loops, and handling this early allows following code to
395 // identify the IV as a "simple recurrence" without first folding away
396 // a long chain of adds.
397 {
398 Value *X;
399 const APInt *C1, *C2;
400 if (match(V: &Inst, P: m_Add(L: m_Add(L: m_Value(V&: X), R: m_APInt(Res&: C1)), R: m_APInt(Res&: C2)))) {
401 auto *InnerI = dyn_cast<Instruction>(Val: Inst.getOperand(i: 0));
402 auto *InnerOBO = cast<OverflowingBinaryOperator>(Val: Inst.getOperand(i: 0));
403 bool SignedOverflow;
404 APInt NewC = C1->sadd_ov(RHS: *C2, Overflow&: SignedOverflow);
405 Inst.setOperand(i: 0, Val: X);
406 Inst.setOperand(i: 1, Val: ConstantInt::get(Ty: Inst.getType(), V: NewC));
407 Inst.setHasNoUnsignedWrap(Inst.hasNoUnsignedWrap() &&
408 InnerOBO->hasNoUnsignedWrap());
409 Inst.setHasNoSignedWrap(Inst.hasNoSignedWrap() &&
410 InnerOBO->hasNoSignedWrap() &&
411 !SignedOverflow);
412 if (InnerI && isInstructionTriviallyDead(I: InnerI))
413 DeadInsts.emplace_back(Args&: InnerI);
414 }
415 }
416 }
417 // We can't do recursive deletion until we're done iterating, as we might
418 // have a phi which (potentially indirectly) uses instructions later in
419 // the block we're iterating through.
420 RecursivelyDeleteTriviallyDeadInstructions(DeadInsts);
421 }
422}
423
424// Loops containing convergent instructions that are uncontrolled or controlled
425// from outside the loop must have a count that divides their TripMultiple.
426LLVM_ATTRIBUTE_USED
427static bool canHaveUnrollRemainder(const Loop *L) {
428 if (getLoopConvergenceHeart(TheLoop: L))
429 return false;
430
431 // Check for uncontrolled convergent operations.
432 for (auto &BB : L->blocks()) {
433 for (auto &I : *BB) {
434 if (isa<ConvergenceControlInst>(Val: I))
435 return true;
436 if (auto *CB = dyn_cast<CallBase>(Val: &I))
437 if (CB->isConvergent())
438 return CB->getConvergenceControlToken();
439 }
440 }
441 return true;
442}
443
444// If LoopUnroll has proven OriginalLoopProb is incorrect for some iterations
445// of the original loop, adjust latch probabilities in the unrolled loop to
446// maintain the original total frequency of the original loop body.
447//
448// OriginalLoopProb is practical but imprecise
449// -------------------------------------------
450//
451// The latch branch weights that LLVM originally adds to a loop encode one latch
452// probability, OriginalLoopProb, applied uniformly across the loop's infinite
453// set of theoretically possible iterations. While this uniform latch
454// probability serves as a practical statistic summarizing the trip counts
455// observed during profiling, it is imprecise. Specifically, unless it is zero,
456// it is impossible for it to be the actual probability observed at every
457// individual iteration. To see why, consider that the only way to actually
458// observe at run time that the latch probability remains non-zero is to profile
459// at least one loop execution that has an infinite number of iterations. I do
460// not know how to profile an infinite number of loop iterations, and most loops
461// I work with are always finite.
462//
463// LoopUnroll proves OriginalLoopProb is incorrect
464// ------------------------------------------------
465//
466// LoopUnroll reorganizes the original loop so that loop iterations are no
467// longer all implemented by the same code, and then it analyzes some of those
468// loop iteration implementations independently of others. In particular, it
469// converts some of their conditional latches to unconditional. That is, by
470// examining code structure without any profile data, LoopUnroll proves that the
471// actual latch probability at the end of such an iteration is either 1 or 0.
472// When an individual iteration's actual latch probability is 1 or 0, that means
473// it always behaves the same, so it is impossible to observe it as having any
474// other probability. The original uniform latch probability is rarely 1 or 0
475// because, when applied to all possible iterations, that would yield an
476// estimated trip count of infinity or 1, respectively.
477//
478// Thus, the new probabilities of 1 or 0 are proven corrections to
479// OriginalLoopProb for individual iterations in the original loop. However,
480// LoopUnroll often is able to perform these corrections for only some
481// iterations, leaving other iterations with OriginalLoopProb, and thus
482// corrupting the aggregate effect on the total frequency of the original loop
483// body.
484//
485// Adjusting latch probabilities
486// -----------------------------
487//
488// This function ensures that the total frequency of the original loop body,
489// summed across all its occurrences in the unrolled loop after the
490// aforementioned latch conversions, is the same as in the original loop. To do
491// so, it adjusts probabilities on the remaining conditional latches. However,
492// it cannot derive the new probabilities directly from the original uniform
493// latch probability because the latter has been proven incorrect for some
494// original loop iterations.
495//
496// There are often many sets of latch probabilities that can produce the
497// original total loop body frequency. If there are many remaining conditional
498// latches, this function just quickly hacks a few of their probabilities to
499// restore the original total loop body frequency. Otherwise, it determines
500// less arbitrary probabilities.
501static void fixProbContradiction(Loop *L, UnrollLoopOptions ULO,
502 OptimizationRemarkEmitter *ORE,
503 BranchProbability OriginalLoopProb,
504 bool CompletelyUnroll,
505 std::vector<unsigned> &IterCounts,
506 const std::vector<BasicBlock *> &CondLatches,
507 std::vector<BasicBlock *> &CondLatchNexts) {
508 // Runtime unrolling is handled later in LoopUnroll not here.
509 //
510 // There are two scenarios in which LoopUnroll sets ProbUpdateRequired to true
511 // because it needs to update probabilities that were originally
512 // OriginalLoopProb, but only in one scenario has LoopUnroll proven
513 // OriginalLoopProb incorrect for iterations within the original loop:
514 // - If ULO.Runtime, LoopUnroll adds new guards that enforce new reaching
515 // conditions for new loop iteration implementations (e.g., one unrolled
516 // loop iteration executes only if at least ULO.Count original loop
517 // iterations remain). Those reaching conditions dictate how conditional
518 // latches can be converted to unconditional (e.g., within an unrolled loop
519 // iteration, there is no need to recheck the number of remaining original
520 // loop iterations). None of this reorganization alters the set of possible
521 // original loop iteration counts or proves OriginalLoopProb incorrect for
522 // any of the original loop iterations. Thus, LoopUnroll derives
523 // probabilities for the new guards and latches directly from
524 // OriginalLoopProb based on the probabilities that their reaching
525 // conditions would occur in the original loop. Doing so maintains the
526 // total frequency of the original loop body.
527 // - If !ULO.Runtime, LoopUnroll initially adds new loop iteration
528 // implementations, which have the same latch probabilities as in the
529 // original loop because there are no new guards that change their reaching
530 // conditions. Sometimes, LoopUnroll is then done, and so does not set
531 // ProbUpdateRequired to true. Other times, LoopUnroll then proves that
532 // some latches are unconditional, directly contradicting OriginalLoopProb
533 // for the corresponding original loop iterations. That reduces the set of
534 // possible original loop iteration counts, possibly producing a finite set
535 // if it manages to eliminate the backedge. LoopUnroll has to choose a new
536 // set of latch probabilities that produce the same total loop body
537 // frequency.
538 //
539 // This function addresses the second scenario only.
540 if (ULO.Runtime)
541 return;
542
543 // If CondLatches.empty(), there are no latch branches with probabilities we
544 // can adjust. That should mean that the actual trip count is always exactly
545 // the number of remaining unrolled iterations, and so OriginalLoopProb should
546 // have yielded that trip count as the original loop body frequency. Of
547 // course, OriginalLoopProb could be based on inaccurate profile data, but
548 // there is nothing we can do about that here.
549 if (CondLatches.empty())
550 return;
551
552 // If the original latch probability is 1, the original frequency is infinity.
553 // Leaving all remaining probabilities set to 1 might or might not get us
554 // there (e.g., a completely unrolled loop cannot be infinite), but it is the
555 // closest we can come.
556 assert(!OriginalLoopProb.isUnknown() &&
557 "Expected to have loop probability to fix");
558 if (OriginalLoopProb.isOne())
559 return;
560
561 // FreqDesired is the frequency implied by the original loop probability.
562 double FreqDesired = 1 / (1 - OriginalLoopProb.toDouble());
563
564 // Get the probability at CondLatches[I].
565 auto GetProb = [&](unsigned I) {
566 CondBrInst *B = cast<CondBrInst>(Val: CondLatches[I]->getTerminator());
567 bool FirstTargetIsNext = B->getSuccessor(i: 0) == CondLatchNexts[I];
568 return getBranchProbability(B, ForFirstTarget: FirstTargetIsNext).toDouble();
569 };
570
571 // Set the probability at CondLatches[I] to Prob.
572 auto SetProb = [&](unsigned I, double Prob) {
573 CondBrInst *B = cast<CondBrInst>(Val: CondLatches[I]->getTerminator());
574 bool FirstTargetIsNext = B->getSuccessor(i: 0) == CondLatchNexts[I];
575 setBranchProbability(B, P: BranchProbability::getBranchProbability(Prob),
576 ForFirstTarget: FirstTargetIsNext);
577 };
578
579 // Set all probabilities in CondLatches to Prob.
580 auto SetAllProbs = [&](double Prob) {
581 for (unsigned I = 0, E = CondLatches.size(); I < E; ++I)
582 SetProb(I, Prob);
583 };
584
585 // If n <= 2, we choose the simplest probability model we can think of: every
586 // remaining conditional branch instruction has the same probability, Prob,
587 // of continuing to the next iteration. This model has several helpful
588 // properties:
589 // - We have no reason to think one latch branch's probability should be
590 // higher or lower than another, and so this model makes them all the same.
591 // In the worst cases, we thus avoid setting just some probabilities to 0 or
592 // 1, which can unrealistically make some code appear unreachable. There
593 // are cases where they *all* must become 0 or 1 to achieve the total
594 // frequency of original loop body, and our model does permit that.
595 // - The frequency, FreqOne, of the original loop body in a single iteration
596 // of the unrolled loop is computed by a simple polynomial, where p=Prob,
597 // n=CondLatches.size(), and c_i=IterCounts[i]:
598 //
599 // FreqOne = Sum(i=0..n)(c_i * p^i)
600 //
601 // - If the backedge has been eliminated, FreqOne is the total frequency of
602 // the original loop body in the unrolled loop.
603 // - If the backedge remains, Sum(i=0..inf)(FreqOne * p^(n*i)) =
604 // FreqOne / (1 - p^n) is the total frequency of the original loop body in
605 // the unrolled loop, regardless of whether the backedge is conditional or
606 // unconditional.
607 // - For n <= 2, we can use simple formulas to solve the above polynomial
608 // equations exactly for p without performing a search.
609
610 // When iterating for a solution, we stop early if we find probabilities
611 // that produce a Freq whose difference from FreqDesired is small
612 // (FreqPrec). Otherwise, we expect to compute a solution at least that
613 // accurate (but surely far more accurate).
614 const double FreqPrec = 1e-6;
615
616 // Compute the probability that, used at CondLaches[0] where
617 // CondLatches.size() == 1, gets as close as possible to FreqDesired.
618 auto ComputeProbForLinear = [&]() {
619 // The polynomial is linear (0 = A*p + B), so just solve it.
620 double A = IterCounts[1] + (CompletelyUnroll ? 0 : FreqDesired);
621 double B = IterCounts[0] - FreqDesired;
622 assert(A > 0 && "Expected iterations after last conditional latch");
623 double Prob = -B / A;
624 Prob = std::max(a: Prob, b: 0.);
625 Prob = std::min(a: Prob, b: 1.);
626 return Prob;
627 };
628
629 // Compute the probability that, used throughout CondLatches where
630 // CondLatches.size() == 2, gets as close as possible to FreqDesired.
631 auto ComputeProbForQuadratic = [&]() {
632 // The polynomial is quadratic (0 = A*p^2 + B*p + C), so just solve it.
633 double A = IterCounts[2] + (CompletelyUnroll ? 0 : FreqDesired);
634 double B = IterCounts[1];
635 double C = IterCounts[0] - FreqDesired;
636 assert(A > 0 && "Expected iterations after last conditional latch");
637 double Prob = (-B + sqrt(x: B * B - 4 * A * C)) / (2 * A);
638 Prob = std::max(a: Prob, b: 0.);
639 Prob = std::min(a: Prob, b: 1.);
640 return Prob;
641 };
642
643 // Adjust the probability at CondLatches[ComputeIdx] to get as close as
644 // possible to FreqDesired without replacing probabilities elsewhere in
645 // CondLatches. Return the new total frequency.
646 //
647 // Given a CondLatches index I, then for a single unrolled loop iteration:
648 // - ProbBefore or ProbAfter is the probability that control flow can pass
649 // through every CondLatches[J] for J < I or J > I, respectively.
650 // - FreqBefore or FreqAfter is the total frequency accumulated before or
651 // after CondLatches[I], respectively, while the probability at
652 // CondLatches[I] is treated as 1.
653 //
654 // If ComputeIdx == 0, then ComputeProb will set those values for I == 0 and
655 // ignore the current values. If ComputeIdx > 0, then it expects those values
656 // to already be set for I == ComputeIdx - 1, and it will set them for I ==
657 // ComputeIdx.
658 auto AdjustProb = [&](unsigned ComputeIdx, double &ProbBefore,
659 double &ProbAfter, double &FreqBefore,
660 double &FreqAfter) {
661 assert(ComputeIdx < CondLatches.size() &&
662 "Expected valid CondLatches index");
663
664 // Compute or update ProbBefore, ProbAfter, FreqBefore, and FreqAfter.
665 auto ComputeAfter = [&]() {
666 ProbAfter = 1;
667 FreqAfter = IterCounts[ComputeIdx + 1];
668 for (unsigned I = ComputeIdx + 1, E = CondLatches.size(); I < E; ++I) {
669 double Prob = GetProb(I);
670 ProbAfter *= Prob;
671 // After Prob == 0, ProbAfter and FreqAfter won't change, so save time.
672 if (Prob == 0)
673 break;
674 FreqAfter += IterCounts[I + 1] * ProbAfter;
675 }
676 };
677 if (ComputeIdx == 0) {
678 ProbBefore = 1;
679 FreqBefore = IterCounts[0];
680 ComputeAfter();
681 } else {
682 // Rather than iterating all of CondLatches again, we fix up the
683 // previously computed values.
684 double ProbOld = GetProb(ComputeIdx);
685 if (ProbOld > 0) {
686 FreqAfter -= IterCounts[ComputeIdx] * ProbBefore;
687 ProbAfter /= ProbOld;
688 FreqAfter /= ProbOld;
689 } else {
690 // We cannot divide out the old zero probability. We short-circuited
691 // the iteration at that zero in the previous ComputeAfter call, so now
692 // we pick up where we left off.
693 ComputeAfter();
694 }
695 ProbBefore *= GetProb(ComputeIdx - 1);
696 FreqBefore += IterCounts[ComputeIdx] * ProbBefore;
697 }
698
699 // Compute the required probability, and limit it to a valid probability (0
700 // <= p <= 1). See the FreqCompute formula below for how to derive the
701 // ProbCompute formula.
702 double ProbReachingBackedge = CompletelyUnroll ? 0 : ProbBefore * ProbAfter;
703 double ProbComputeNumerator = FreqDesired - FreqBefore;
704 double ProbComputeDenominator =
705 FreqAfter + FreqDesired * ProbReachingBackedge;
706 double ProbCompute = -1; // Init expected to be unused.
707 if (ProbComputeNumerator <= 0) {
708 // FreqBefore has already reached or surpassed FreqDesired, so add no more
709 // frequency. It is possible that ProbComputeDenominator == 0 here
710 // because some latch probability (maybe the original) was set to zero, so
711 // this check avoids setting ProbCompute=1 (in the else if below) and
712 // division by zero where the numerator <= 0 (in the else below).
713 ProbCompute = 0;
714 } else if (ProbComputeDenominator == 0) {
715 // Analytically, this case seems impossible. It would occur if either:
716 // - Both FreqAfter and FreqDesired are zero. But the latter would cause
717 // ProbComputeNumerator < 0, which we catch above, and FreqDesired
718 // should always be >= 1 anyway.
719 // - There are no iterations after CondLatches[ComputeIdx], not even via
720 // a backedge, so that both FreqAfter and ProbReachingBackedge are zero.
721 // But iterations should exist after even the last conditional latch.
722 // - Some latch probability (maybe the original) was set to zero so that
723 // both FreqAfter and ProbReachingBackedge are zero. But that should
724 // not have happened because, according to the above
725 // ProbComputeNumerator check, we have not yet reached FreqDesired
726 // (which, if the original latch probability is zero, is just 1 and thus
727 // always reached or surpassed).
728 //
729 // Numerically, perhaps this case is possible. We interpret it to mean we
730 // need more frequency (ProbComputeNumerator > 0) but have no way to get
731 // any (ProbComputeDenominator is analytically too small to distinguish it
732 // from 0 in floating point), suggesting infinite probability is needed,
733 // but 1 is the maximum valid probability and thus the best we can do.
734 //
735 // TODO: Cover this case in the test suite if you can.
736 ProbCompute = 1;
737 } else {
738 ProbCompute = ProbComputeNumerator / ProbComputeDenominator;
739 ProbCompute = std::max(a: ProbCompute, b: 0.);
740 ProbCompute = std::min(a: ProbCompute, b: 1.);
741 }
742 SetProb(ComputeIdx, ProbCompute);
743
744 // Compute the resulting total frequency.
745 double FreqCompute = -1; // Init expected to be unused.
746 if (ProbReachingBackedge * ProbCompute == 1) {
747 // Analytically, this case seems impossible. It requires that there is a
748 // backedge and that FreqDesired == infinity so that every conditional
749 // latch's probability had to be set to 1. But FreqDesired == infinity
750 // means OriginalLoopProb.isOne(), which we guarded against earlier.
751 //
752 // Numerically, perhaps this case is possible. We interpret it to mean
753 // that analytically the probability has to be so near 1 that, in floating
754 // point, the frequency is computed as infinite.
755 //
756 // TODO: Cover this case in the test suite if you can.
757 FreqCompute = std::numeric_limits<double>::infinity();
758 if (ORE) {
759 ORE->emit(RemarkBuilder: [&]() {
760 return OptimizationRemark(DEBUG_TYPE, "InfiniteFrequency",
761 L->getStartLoc(), L->getHeader());
762 });
763 }
764 } else {
765 assert(FreqBefore > 0 &&
766 "Expected at least one iteration before first latch");
767 // In this equation, if we replace the left-hand side with FreqDesired and
768 // then solve for ProbCompute, we get the ProbCompute formula above.
769 FreqCompute = (FreqBefore + FreqAfter * ProbCompute) /
770 (1 - ProbReachingBackedge * ProbCompute);
771 }
772 assert(FreqCompute > 0 && "Expected valid frequency");
773 return FreqCompute;
774 };
775
776 // Determine and set branch weights.
777 if (CondLatches.size() == 1) {
778 SetAllProbs(ComputeProbForLinear());
779 } else if (CondLatches.size() == 2) {
780 SetAllProbs(ComputeProbForQuadratic());
781 } else {
782 // The polynomial is too complex for a simple formula, so the quick and
783 // dirty fix has been selected. Adjust probabilities starting from the
784 // first latch, which has the most influence on the total frequency, so
785 // starting there should minimize the number of latches that have to be
786 // visited. We do have to iterate because the first latch alone might not
787 // be enough. For example, we might need to set all probabilities to 1 if
788 // the frequency is the unroll factor.
789 double ProbBefore = -1, ProbAfter = -1; // Inits expected to be unused.
790 double FreqBefore = -1, FreqAfter = -1; // Inits expected to be unused.
791 for (unsigned I = 0; I != CondLatches.size(); ++I) {
792 double Freq = AdjustProb(I, ProbBefore, ProbAfter, FreqBefore, FreqAfter);
793 if (fabs(x: Freq - FreqDesired) < FreqPrec)
794 break;
795 }
796 }
797
798 // FIXME: We have not considered non-latch loop exits:
799 // - Their original probabilities are not considered in our calculation of
800 // FreqDesired.
801 // - Their probabilities are not considered in our probability model used to
802 // determine new probabilities for remaining conditional branches.
803 // - If they are conditional and LoopUnroll converts them to unconditional,
804 // LoopUnroll has proven their original probabilities are incorrect for some
805 // original loop iterations, but that does not cause ProbUpdateRequired to
806 // be set to true.
807 //
808 // To adjust FreqDesired and our probability model correctly for a non-latch
809 // loop exit, we would need to compute the original probability that the exit
810 // is reached from the loop header (in contrast, we currently assume that
811 // probability is 1 in the case of a latch exit) and the probability that the
812 // exit is taken if it is conditional (use the branch's old or new weights for
813 // FreqDesired or the probability model, respectively). Does computing the
814 // reaching probability require a CFG traversal, or is there some existing
815 // library that can do it? Prior discussions suggest some such libraries are
816 // difficult to use within LoopUnroll:
817 // <https://github.com/llvm/llvm-project/pull/164799#issuecomment-3438681519>.
818 // For now, we just let our corrected probabilities be less accurate in that
819 // scenario. Alternatively, we could refuse to correct probabilities at all
820 // in that scenario, but that seems worse.
821}
822
/// Unroll the given loop by ULO.Count. The loop must be in LCSSA form.
/// Unrolling fails (returning LoopUnrollResult::Unmodified) for several
/// reasons, among them:
/// - the loop has no preheader or no unique latch;
/// - the loop body cannot be cloned;
/// - the header's address is taken;
/// - the latch is terminated by neither an unconditional branch nor a
///   loop-exiting conditional branch.
/// However, if the trip count (and multiple) are not known, loop unrolling
/// will mostly produce more code that is no faster.
///
/// If ULO.Runtime is true then UnrollLoop will try to insert a prologue or
/// epilogue that ensures the latch has a trip multiple of ULO.Count.
/// UnrollLoop will not runtime-unroll the loop if computing the run-time trip
/// count will be expensive and ULO.AllowExpensiveTripCount is false. If the
/// remainder cannot be generated, unrolling is abandoned, unless ULO.Force is
/// set. In that case the loop is unrolled without runtime remainder handling.
///
/// This utility keeps LoopInfo and the DominatorTree consistent, and
/// invalidates ScalarEvolution's cached information for the affected loops.
/// Both DT and SE must be non-null.
///
/// If RemainderLoop is non-null, it will receive the remainder loop (if
/// required and not fully unrolled).
840LoopUnrollResult
841llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
842 ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC,
843 const TargetTransformInfo *TTI, OptimizationRemarkEmitter *ORE,
844 bool PreserveLCSSA, Loop **RemainderLoop, AAResults *AA) {
845 assert(DT && "DomTree is required");
846
847 if (!L->getLoopPreheader()) {
848 LLVM_DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n");
849 return LoopUnrollResult::Unmodified;
850 }
851
852 if (!L->getLoopLatch()) {
853 LLVM_DEBUG(dbgs() << " Can't unroll; loop exit-block-insertion failed.\n");
854 return LoopUnrollResult::Unmodified;
855 }
856
857 // Loops with indirectbr cannot be cloned.
858 if (!L->isSafeToClone()) {
859 LLVM_DEBUG(dbgs() << " Can't unroll; Loop body cannot be cloned.\n");
860 return LoopUnrollResult::Unmodified;
861 }
862
863 if (L->getHeader()->hasAddressTaken()) {
864 // The loop-rotate pass can be helpful to avoid this in many cases.
865 LLVM_DEBUG(
866 dbgs() << " Won't unroll loop: address of header block is taken.\n");
867 return LoopUnrollResult::Unmodified;
868 }
869
870 assert(ULO.Count > 0);
871
872 // All these values should be taken only after peeling because they might have
873 // changed.
874 BasicBlock *Preheader = L->getLoopPreheader();
875 BasicBlock *Header = L->getHeader();
876 BasicBlock *LatchBlock = L->getLoopLatch();
877 SmallVector<BasicBlock *, 4> ExitBlocks;
878 L->getExitBlocks(ExitBlocks);
879 std::vector<BasicBlock *> OriginalLoopBlocks = L->getBlocks();
880
881 const unsigned MaxTripCount = SE->getSmallConstantMaxTripCount(L);
882 const bool MaxOrZero = SE->isBackedgeTakenCountMaxOrZero(L);
883 std::optional<unsigned> OriginalTripCount =
884 llvm::getLoopEstimatedTripCount(L);
885 BranchProbability OriginalLoopProb = llvm::getLoopProbability(L);
886
887 // Effectively "DCE" unrolled iterations that are beyond the max tripcount
888 // and will never be executed.
889 if (MaxTripCount && ULO.Count > MaxTripCount)
890 ULO.Count = MaxTripCount;
891
892 struct ExitInfo {
893 unsigned TripCount;
894 unsigned TripMultiple;
895 unsigned BreakoutTrip;
896 bool ExitOnTrue;
897 BasicBlock *FirstExitingBlock = nullptr;
898 SmallVector<BasicBlock *> ExitingBlocks;
899 };
900 MapVector<BasicBlock *, ExitInfo> ExitInfos;
901 SmallVector<BasicBlock *, 4> ExitingBlocks;
902 L->getExitingBlocks(ExitingBlocks);
903 for (auto *ExitingBlock : ExitingBlocks) {
904 // The folding code is not prepared to deal with non-branch instructions
905 // right now.
906 auto *BI = dyn_cast<CondBrInst>(Val: ExitingBlock->getTerminator());
907 if (!BI)
908 continue;
909
910 ExitInfo &Info = ExitInfos[ExitingBlock];
911 Info.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
912 Info.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
913 if (Info.TripCount != 0) {
914 Info.BreakoutTrip = Info.TripCount % ULO.Count;
915 Info.TripMultiple = 0;
916 } else {
917 Info.BreakoutTrip = Info.TripMultiple =
918 (unsigned)std::gcd(m: ULO.Count, n: Info.TripMultiple);
919 }
920 Info.ExitOnTrue = !L->contains(BB: BI->getSuccessor(i: 0));
921 Info.ExitingBlocks.push_back(Elt: ExitingBlock);
922 LLVM_DEBUG(dbgs() << " Exiting block %" << ExitingBlock->getName()
923 << ": TripCount=" << Info.TripCount
924 << ", TripMultiple=" << Info.TripMultiple
925 << ", BreakoutTrip=" << Info.BreakoutTrip << "\n");
926 }
927
928 // Are we eliminating the loop control altogether? Note that we can know
929 // we're eliminating the backedge without knowing exactly which iteration
930 // of the unrolled body exits.
931 const bool CompletelyUnroll = ULO.Count == MaxTripCount;
932
933 const bool PreserveOnlyFirst = CompletelyUnroll && MaxOrZero;
934
935 // There's no point in performing runtime unrolling if this unroll count
936 // results in a full unroll.
937 if (CompletelyUnroll)
938 ULO.Runtime = false;
939
940 // Go through all exits of L and see if there are any phi-nodes there. We just
941 // conservatively assume that they're inserted to preserve LCSSA form, which
942 // means that complete unrolling might break this form. We need to either fix
943 // it in-place after the transformation, or entirely rebuild LCSSA. TODO: For
944 // now we just recompute LCSSA for the outer loop, but it should be possible
945 // to fix it in-place.
946 bool NeedToFixLCSSA =
947 PreserveLCSSA && CompletelyUnroll &&
948 any_of(Range&: ExitBlocks,
949 P: [](const BasicBlock *BB) { return isa<PHINode>(Val: BB->begin()); });
950
951 // The current loop unroll pass can unroll loops that have
952 // (1) single latch; and
953 // (2a) latch is unconditional; or
954 // (2b) latch is conditional and is an exiting block
955 // FIXME: The implementation can be extended to work with more complicated
956 // cases, e.g. loops with multiple latches.
957 Instruction *LatchTerm = LatchBlock->getTerminator();
958
959 // A conditional branch which exits the loop, which can be optimized to an
960 // unconditional branch in the unrolled loop in some cases.
961 bool LatchIsExiting = L->isLoopExiting(BB: LatchBlock);
962 if (!isa<UncondBrInst>(Val: LatchTerm) &&
963 !(isa<CondBrInst>(Val: LatchTerm) && LatchIsExiting)) {
964 LLVM_DEBUG(
965 dbgs() << "Can't unroll; a conditional latch must exit the loop");
966 return LoopUnrollResult::Unmodified;
967 }
968
969 bool EpilogProfitability =
970 UnrollRuntimeEpilog.getNumOccurrences() ? UnrollRuntimeEpilog
971 : isEpilogProfitable(L);
972
973 if (ULO.Runtime &&
974 !UnrollRuntimeLoopRemainder(
975 L, Count: ULO.Count, AllowExpensiveTripCount: ULO.AllowExpensiveTripCount, UseEpilogRemainder: EpilogProfitability,
976 UnrollRemainder: ULO.UnrollRemainder, ForgetAllSCEV: ULO.ForgetAllSCEV, LI, SE, DT, AC, TTI,
977 PreserveLCSSA, SCEVExpansionBudget: ULO.SCEVExpansionBudget, RuntimeUnrollMultiExit: ULO.RuntimeUnrollMultiExit,
978 ResultLoop: RemainderLoop, OriginalTripCount, OriginalLoopProb)) {
979 if (ULO.Force)
980 ULO.Runtime = false;
981 else {
982 LLVM_DEBUG(dbgs() << "Won't unroll; remainder loop could not be "
983 "generated when assuming runtime trip count\n");
984 return LoopUnrollResult::Unmodified;
985 }
986 }
987
988 using namespace ore;
989
990 // Determine whether this loop originated from the vectorizer so we can
991 // produce more informative remarks.
992 StringRef LoopKind = getLoopVectorizeKindPrefix(L);
993
994 // Report the unrolling decision.
995 if (CompletelyUnroll) {
996 LLVM_DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName()
997 << " with trip count " << ULO.Count << "!\n");
998 if (ORE)
999 ORE->emit(RemarkBuilder: [&]() {
1000 return OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(),
1001 L->getHeader())
1002 << "completely unrolled " + LoopKind.str() + "loop with "
1003 << NV("UnrollCount", ULO.Count) << " iterations";
1004 });
1005 } else {
1006 LLVM_DEBUG({
1007 dbgs() << "UNROLLING loop %" << Header->getName() << " by " << ULO.Count;
1008 if (ULO.Runtime) {
1009 dbgs() << " with run-time trip count";
1010 if (ULO.UnrollRemainder)
1011 dbgs() << " (remainder unrolled)";
1012 }
1013 dbgs() << "!\n";
1014 });
1015
1016 if (ORE)
1017 ORE->emit(RemarkBuilder: [&]() {
1018 OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(),
1019 L->getHeader());
1020 Diag << "unrolled " + LoopKind.str() + "loop by a factor of "
1021 << NV("UnrollCount", ULO.Count);
1022 if (ULO.Runtime)
1023 Diag << " with run-time trip count"
1024 << (ULO.UnrollRemainder ? " (remainder unrolled)" : "");
1025 return Diag;
1026 });
1027 }
1028
1029 // We are going to make changes to this loop. SCEV may be keeping cached info
1030 // about it, in particular about backedge taken count. The changes we make
1031 // are guaranteed to invalidate this information for our loop. It is tempting
1032 // to only invalidate the loop being unrolled, but it is incorrect as long as
1033 // all exiting branches from all inner loops have impact on the outer loops,
1034 // and if something changes inside them then any of outer loops may also
1035 // change. When we forget outermost loop, we also forget all contained loops
1036 // and this is what we need here.
1037 if (SE) {
1038 if (ULO.ForgetAllSCEV)
1039 SE->forgetAllLoops();
1040 else {
1041 SE->forgetTopmostLoop(L);
1042 SE->forgetBlockAndLoopDispositions();
1043 }
1044 }
1045
1046 if (!LatchIsExiting)
1047 ++NumUnrolledNotLatch;
1048
1049 // For the first iteration of the loop, we should use the precloned values for
1050 // PHI nodes. Insert associations now.
1051 ValueToValueMapTy LastValueMap;
1052 std::vector<PHINode*> OrigPHINode;
1053 for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(Val: I); ++I) {
1054 OrigPHINode.push_back(x: cast<PHINode>(Val&: I));
1055 }
1056
1057 // Collect phi nodes for reductions for which we can introduce multiple
1058 // parallel reduction phis and compute the final reduction result after the
1059 // loop. This requires a single exit block after unrolling. This is ensured by
1060 // restricting to single-block loops where the unrolled iterations are known
1061 // to not exit.
1062 DenseMap<PHINode *, RecurrenceDescriptor> Reductions;
1063 bool CanAddAdditionalAccumulators =
1064 (UnrollAddParallelReductions.getNumOccurrences() > 0
1065 ? UnrollAddParallelReductions
1066 : ULO.AddAdditionalAccumulators) &&
1067 !CompletelyUnroll && L->getNumBlocks() == 1 &&
1068 (ULO.Runtime ||
1069 (ExitInfos.contains(Key: Header) && ((ExitInfos[Header].TripCount != 0 &&
1070 ExitInfos[Header].BreakoutTrip == 0))));
1071
1072 // Limit parallelizing reductions to unroll counts of 4 or less for now.
1073 // TODO: The number of parallel reductions should depend on the number of
1074 // execution units. We also don't have to add a parallel reduction phi per
1075 // unrolled iteration, but could for example add a parallel phi for every 2
1076 // unrolled iterations.
1077 if (CanAddAdditionalAccumulators && ULO.Count <= 4) {
1078 for (PHINode &Phi : Header->phis()) {
1079 auto RdxDesc = canParallelizeReductionWhenUnrolling(Phi, L, SE);
1080 if (!RdxDesc)
1081 continue;
1082
1083 // Only handle duplicate phis for a single reduction for now.
1084 // TODO: Handle any number of reductions
1085 if (!Reductions.empty())
1086 continue;
1087
1088 Reductions[&Phi] = *RdxDesc;
1089 }
1090 }
1091
1092 std::vector<BasicBlock *> Headers;
1093 std::vector<BasicBlock *> Latches;
1094 Headers.push_back(x: Header);
1095 Latches.push_back(x: LatchBlock);
1096
1097 // The current on-the-fly SSA update requires blocks to be processed in
1098 // reverse postorder so that LastValueMap contains the correct value at each
1099 // exit.
1100 LoopBlocksDFS DFS(L);
1101 DFS.perform(LI);
1102
1103 // Stash the DFS iterators before adding blocks to the loop.
1104 LoopBlocksDFS::RPOIterator BlockBegin = DFS.beginRPO();
1105 LoopBlocksDFS::RPOIterator BlockEnd = DFS.endRPO();
1106
1107 std::vector<BasicBlock*> UnrolledLoopBlocks = L->getBlocks();
1108
1109 // Loop Unrolling might create new loops. While we do preserve LoopInfo, we
1110 // might break loop-simplified form for these loops (as they, e.g., would
1111 // share the same exit blocks). We'll keep track of loops for which we can
1112 // break this so that later we can re-simplify them.
1113 SmallSetVector<Loop *, 4> LoopsToSimplify;
1114 LoopsToSimplify.insert_range(R&: *L);
1115
1116 // When a FSDiscriminator is enabled, we don't need to add the multiply
1117 // factors to the discriminators.
1118 if (Header->getParent()->shouldEmitDebugInfoForProfiling() &&
1119 !EnableFSDiscriminator)
1120 for (BasicBlock *BB : L->getBlocks())
1121 for (Instruction &I : *BB)
1122 if (!I.isDebugOrPseudoInst())
1123 if (const DILocation *DIL = I.getDebugLoc()) {
1124 auto NewDIL = DIL->cloneByMultiplyingDuplicationFactor(DF: ULO.Count);
1125 if (NewDIL)
1126 I.setDebugLoc(*NewDIL);
1127 else
1128 LLVM_DEBUG(dbgs()
1129 << "Failed to create new discriminator: "
1130 << DIL->getFilename() << " Line: " << DIL->getLine());
1131 }
1132
1133 // Identify what noalias metadata is inside the loop: if it is inside the
1134 // loop, the associated metadata must be cloned for each iteration.
1135 SmallVector<MDNode *, 6> LoopLocalNoAliasDeclScopes;
1136 identifyNoAliasScopesToClone(BBs: L->getBlocks(), NoAliasDeclScopes&: LoopLocalNoAliasDeclScopes);
1137
1138 // We place the unrolled iterations immediately after the original loop
1139 // latch. This is a reasonable default placement if we don't have block
1140 // frequencies, and if we do, well the layout will be adjusted later.
1141 auto BlockInsertPt = std::next(x: LatchBlock->getIterator());
1142 SmallVector<Instruction *> PartialReductions;
1143 for (unsigned It = 1; It != ULO.Count; ++It) {
1144 SmallVector<BasicBlock *, 8> NewBlocks;
1145 SmallDenseMap<const Loop *, Loop *, 4> NewLoops;
1146 NewLoops[L] = L;
1147
1148 for (LoopBlocksDFS::RPOIterator BB = BlockBegin; BB != BlockEnd; ++BB) {
1149 ValueToValueMapTy VMap;
1150 BasicBlock *New = CloneBasicBlock(BB: *BB, VMap, NameSuffix: "." + Twine(It));
1151 Header->getParent()->insert(Position: BlockInsertPt, BB: New);
1152
1153 assert((*BB != Header || LI->getLoopFor(*BB) == L) &&
1154 "Header should not be in a sub-loop");
1155 // Tell LI about New.
1156 const Loop *OldLoop = addClonedBlockToLoopInfo(OriginalBB: *BB, ClonedBB: New, LI, NewLoops);
1157 if (OldLoop)
1158 LoopsToSimplify.insert(X: NewLoops[OldLoop]);
1159
1160 if (*BB == Header) {
1161 // Loop over all of the PHI nodes in the block, changing them to use
1162 // the incoming values from the previous block.
1163 for (PHINode *OrigPHI : OrigPHINode) {
1164 PHINode *NewPHI = cast<PHINode>(Val&: VMap[OrigPHI]);
1165 Value *InVal = NewPHI->getIncomingValueForBlock(BB: LatchBlock);
1166
1167 // Use cloned phis as parallel phis for partial reductions, which will
1168 // get combined to the final reduction result after the loop.
1169 if (Reductions.contains(Val: OrigPHI)) {
1170 // Collect partial reduction results.
1171 if (PartialReductions.empty())
1172 PartialReductions.push_back(Elt: cast<Instruction>(Val: InVal));
1173 PartialReductions.push_back(Elt: cast<Instruction>(Val&: VMap[InVal]));
1174
1175 // Update the start value for the cloned phis to use the identity
1176 // value for the reduction.
1177 const RecurrenceDescriptor &RdxDesc = Reductions[OrigPHI];
1178 NewPHI->setIncomingValueForBlock(
1179 BB: L->getLoopPreheader(),
1180 V: getRecurrenceIdentity(K: RdxDesc.getRecurrenceKind(),
1181 Tp: OrigPHI->getType(),
1182 FMF: RdxDesc.getFastMathFlags()));
1183
1184 // Update NewPHI to use the cloned value for the iteration and move
1185 // to header.
1186 NewPHI->replaceUsesOfWith(From: InVal, To: VMap[InVal]);
1187 NewPHI->moveBefore(InsertPos: OrigPHI->getIterator());
1188 continue;
1189 }
1190
1191 if (Instruction *InValI = dyn_cast<Instruction>(Val: InVal))
1192 if (It > 1 && L->contains(Inst: InValI))
1193 InVal = LastValueMap[InValI];
1194 VMap[OrigPHI] = InVal;
1195 NewPHI->eraseFromParent();
1196 }
1197
1198 // Eliminate copies of the loop heart intrinsic, if any.
1199 if (ULO.Heart) {
1200 auto it = VMap.find(Val: ULO.Heart);
1201 assert(it != VMap.end());
1202 Instruction *heartCopy = cast<Instruction>(Val&: it->second);
1203 heartCopy->eraseFromParent();
1204 VMap.erase(I: it);
1205 }
1206 }
1207
1208 // Remap source location atom instance. Do this now, rather than
1209 // when we remap instructions, because remap is called once we've
1210 // cloned all blocks (all the clones would get the same atom
1211 // number).
1212 if (!VMap.AtomMap.empty())
1213 for (Instruction &I : *New)
1214 RemapSourceAtom(I: &I, VM&: VMap);
1215
1216 // Update our running map of newest clones
1217 LastValueMap[*BB] = New;
1218 for (ValueToValueMapTy::iterator VI = VMap.begin(), VE = VMap.end();
1219 VI != VE; ++VI)
1220 LastValueMap[VI->first] = VI->second;
1221
1222 // Add phi entries for newly created values to all exit blocks.
1223 for (BasicBlock *Succ : successors(BB: *BB)) {
1224 if (L->contains(BB: Succ))
1225 continue;
1226 for (PHINode &PHI : Succ->phis()) {
1227 Value *Incoming = PHI.getIncomingValueForBlock(BB: *BB);
1228 ValueToValueMapTy::iterator It = LastValueMap.find(Val: Incoming);
1229 if (It != LastValueMap.end())
1230 Incoming = It->second;
1231 PHI.addIncoming(V: Incoming, BB: New);
1232 SE->forgetLcssaPhiWithNewPredecessor(L, V: &PHI);
1233 }
1234 }
1235 // Keep track of new headers and latches as we create them, so that
1236 // we can insert the proper branches later.
1237 if (*BB == Header)
1238 Headers.push_back(x: New);
1239 if (*BB == LatchBlock)
1240 Latches.push_back(x: New);
1241
1242 // Keep track of the exiting block and its successor block contained in
1243 // the loop for the current iteration.
1244 auto ExitInfoIt = ExitInfos.find(Key: *BB);
1245 if (ExitInfoIt != ExitInfos.end())
1246 ExitInfoIt->second.ExitingBlocks.push_back(Elt: New);
1247
1248 NewBlocks.push_back(Elt: New);
1249 UnrolledLoopBlocks.push_back(x: New);
1250
1251 // Update DomTree: since we just copy the loop body, and each copy has a
1252 // dedicated entry block (copy of the header block), this header's copy
1253 // dominates all copied blocks. That means, dominance relations in the
1254 // copied body are the same as in the original body.
1255 if (*BB == Header)
1256 DT->addNewBlock(BB: New, DomBB: Latches[It - 1]);
1257 else {
1258 auto BBDomNode = DT->getNode(BB: *BB);
1259 auto BBIDom = BBDomNode->getIDom();
1260 BasicBlock *OriginalBBIDom = BBIDom->getBlock();
1261 DT->addNewBlock(
1262 BB: New, DomBB: cast<BasicBlock>(Val&: LastValueMap[cast<Value>(Val: OriginalBBIDom)]));
1263 }
1264 }
1265
1266 // Remap all instructions in the most recent iteration.
1267 // Key Instructions: Nothing to do - we've already remapped the atoms.
1268 remapInstructionsInBlocks(Blocks: NewBlocks, VMap&: LastValueMap);
1269 for (BasicBlock *NewBlock : NewBlocks)
1270 for (Instruction &I : *NewBlock)
1271 if (auto *II = dyn_cast<AssumeInst>(Val: &I))
1272 AC->registerAssumption(CI: II);
1273
1274 {
1275 // Identify what other metadata depends on the cloned version. After
1276 // cloning, replace the metadata with the corrected version for both
1277 // memory instructions and noalias intrinsics.
1278 std::string ext = (Twine("It") + Twine(It)).str();
1279 cloneAndAdaptNoAliasScopes(NoAliasDeclScopes: LoopLocalNoAliasDeclScopes, NewBlocks,
1280 Context&: Header->getContext(), Ext: ext);
1281 }
1282 }
1283
1284 // Loop over the PHI nodes in the original block, setting incoming values.
1285 for (PHINode *PN : OrigPHINode) {
1286 if (CompletelyUnroll) {
1287 PN->replaceAllUsesWith(V: PN->getIncomingValueForBlock(BB: Preheader));
1288 PN->eraseFromParent();
1289 } else if (ULO.Count > 1) {
1290 if (Reductions.contains(Val: PN))
1291 continue;
1292
1293 Value *InVal = PN->removeIncomingValue(BB: LatchBlock, DeletePHIIfEmpty: false);
1294 // If this value was defined in the loop, take the value defined by the
1295 // last iteration of the loop.
1296 if (Instruction *InValI = dyn_cast<Instruction>(Val: InVal)) {
1297 if (L->contains(Inst: InValI))
1298 InVal = LastValueMap[InVal];
1299 }
1300 assert(Latches.back() == LastValueMap[LatchBlock] && "bad last latch");
1301 PN->addIncoming(V: InVal, BB: Latches.back());
1302 }
1303 }
1304
1305 // Connect latches of the unrolled iterations to the headers of the next
1306 // iteration. Currently they point to the header of the same iteration.
1307 for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
1308 unsigned j = (i + 1) % e;
1309 Latches[i]->getTerminator()->replaceSuccessorWith(OldBB: Headers[i], NewBB: Headers[j]);
1310 }
1311
1312 // Remove loop metadata copied from the original loop latch to branches that
1313 // are no longer latches.
1314 for (unsigned I = 0, E = Latches.size() - (CompletelyUnroll ? 0 : 1); I < E;
1315 ++I)
1316 Latches[I]->getTerminator()->setMetadata(KindID: LLVMContext::MD_loop, Node: nullptr);
1317
1318 // Update dominators of blocks we might reach through exits.
1319 // Immediate dominator of such block might change, because we add more
1320 // routes which can lead to the exit: we can now reach it from the copied
1321 // iterations too.
1322 if (ULO.Count > 1) {
1323 for (auto *BB : OriginalLoopBlocks) {
1324 auto *BBDomNode = DT->getNode(BB);
1325 SmallVector<BasicBlock *, 16> ChildrenToUpdate;
1326 for (auto *ChildDomNode : BBDomNode->children()) {
1327 auto *ChildBB = ChildDomNode->getBlock();
1328 if (!L->contains(BB: ChildBB))
1329 ChildrenToUpdate.push_back(Elt: ChildBB);
1330 }
1331 // The new idom of the block will be the nearest common dominator
1332 // of all copies of the previous idom. This is equivalent to the
1333 // nearest common dominator of the previous idom and the first latch,
1334 // which dominates all copies of the previous idom.
1335 BasicBlock *NewIDom = DT->findNearestCommonDominator(A: BB, B: LatchBlock);
1336 for (auto *ChildBB : ChildrenToUpdate)
1337 DT->changeImmediateDominator(BB: ChildBB, NewBB: NewIDom);
1338 }
1339 }
1340
1341 assert(!UnrollVerifyDomtree ||
1342 DT->verify(DominatorTree::VerificationLevel::Fast));
1343
1344 SmallVector<DominatorTree::UpdateType> DTUpdates;
1345 auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
1346 auto *Term = cast<CondBrInst>(Val: Src->getTerminator());
1347 const unsigned Idx = ExitOnTrue ^ WillExit;
1348 BasicBlock *Dest = Term->getSuccessor(i: Idx);
1349 BasicBlock *DeadSucc = Term->getSuccessor(i: 1-Idx);
1350
1351 // Remove predecessors from all non-Dest successors.
1352 DeadSucc->removePredecessor(Pred: Src, /* KeepOneInputPHIs */ true);
1353
1354 // Replace the conditional branch with an unconditional one.
1355 auto *BI = UncondBrInst::Create(Target: Dest, InsertBefore: Term->getIterator());
1356 BI->setDebugLoc(Term->getDebugLoc());
1357 Term->eraseFromParent();
1358
1359 DTUpdates.emplace_back(Args: DominatorTree::Delete, Args&: Src, Args&: DeadSucc);
1360 };
1361
1362 auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j,
1363 bool IsLatch) -> std::optional<bool> {
1364 if (CompletelyUnroll) {
1365 if (PreserveOnlyFirst) {
1366 if (i == 0)
1367 return std::nullopt;
1368 return j == 0;
1369 }
1370 // Complete (but possibly inexact) unrolling
1371 if (j == 0)
1372 return true;
1373 if (Info.TripCount && j != Info.TripCount)
1374 return false;
1375 return std::nullopt;
1376 }
1377
1378 if (ULO.Runtime) {
1379 // If runtime unrolling inserts a prologue, information about non-latch
1380 // exits may be stale.
1381 if (IsLatch && j != 0)
1382 return false;
1383 return std::nullopt;
1384 }
1385
1386 if (j != Info.BreakoutTrip &&
1387 (Info.TripMultiple == 0 || j % Info.TripMultiple != 0)) {
1388 // If we know the trip count or a multiple of it, we can safely use an
1389 // unconditional branch for some iterations.
1390 return false;
1391 }
1392 return std::nullopt;
1393 };
1394
1395 // Fold branches for iterations where we know that they will exit or not
1396 // exit. In the case of an iteration's latch, if we thus find
1397 // *OriginalLoopProb is incorrect, set ProbUpdateRequired to true.
1398 bool ProbUpdateRequired = false;
1399 for (auto &Pair : ExitInfos) {
1400 ExitInfo &Info = Pair.second;
1401 for (unsigned i = 0, e = Info.ExitingBlocks.size(); i != e; ++i) {
1402 // The branch destination.
1403 unsigned j = (i + 1) % e;
1404 bool IsLatch = Pair.first == LatchBlock;
1405 std::optional<bool> KnownWillExit = WillExit(Info, i, j, IsLatch);
1406 if (!KnownWillExit) {
1407 if (!Info.FirstExitingBlock)
1408 Info.FirstExitingBlock = Info.ExitingBlocks[i];
1409 continue;
1410 }
1411
1412 // We don't fold known-exiting branches for non-latch exits here,
1413 // because this ensures that both all loop blocks and all exit blocks
1414 // remain reachable in the CFG.
1415 // TODO: We could fold these branches, but it would require much more
1416 // sophisticated updates to LoopInfo.
1417 if (*KnownWillExit && !IsLatch) {
1418 if (!Info.FirstExitingBlock)
1419 Info.FirstExitingBlock = Info.ExitingBlocks[i];
1420 continue;
1421 }
1422
1423 // For a latch, record any OriginalLoopProb contradiction.
1424 if (!OriginalLoopProb.isUnknown() && IsLatch) {
1425 BranchProbability ActualProb = *KnownWillExit
1426 ? BranchProbability::getZero()
1427 : BranchProbability::getOne();
1428 ProbUpdateRequired |= OriginalLoopProb != ActualProb;
1429 }
1430
1431 SetDest(Info.ExitingBlocks[i], *KnownWillExit, Info.ExitOnTrue);
1432 }
1433 }
1434
1435 DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
1436 DomTreeUpdater *DTUToUse = &DTU;
1437 if (ExitingBlocks.size() == 1 && ExitInfos.size() == 1) {
1438 // Manually update the DT if there's a single exiting node. In that case
1439 // there's a single exit node and it is sufficient to update the nodes
1440 // immediately dominated by the original exiting block. They will become
1441 // dominated by the first exiting block that leaves the loop after
1442 // unrolling. Note that the CFG inside the loop does not change, so there's
1443 // no need to update the DT inside the unrolled loop.
1444 DTUToUse = nullptr;
1445 auto &[OriginalExit, Info] = *ExitInfos.begin();
1446 if (!Info.FirstExitingBlock)
1447 Info.FirstExitingBlock = Info.ExitingBlocks.back();
1448 for (auto *C : to_vector(Range: DT->getNode(BB: OriginalExit)->children())) {
1449 if (L->contains(BB: C->getBlock()))
1450 continue;
1451 C->setIDom(DT->getNode(BB: Info.FirstExitingBlock));
1452 }
1453 } else {
1454 DTU.applyUpdates(Updates: DTUpdates);
1455 }
1456
1457 // When completely unrolling, the last latch becomes unreachable.
1458 if (!LatchIsExiting && CompletelyUnroll) {
1459 // There is no need to update the DT here, because there must be a unique
1460 // latch. Hence if the latch is not exiting it must directly branch back to
1461 // the original loop header and does not dominate any nodes.
1462 assert(LatchBlock->getSingleSuccessor() && "Loop with multiple latches?");
1463 changeToUnreachable(I: Latches.back()->getTerminator(), PreserveLCSSA);
1464 }
1465
1466 // After merging adjacent blocks in Latches below:
1467 // - CondLatches will list the blocks from Latches that are still terminated
1468 // with conditional branches.
1469 // - For 1 <= I < CondLatches.size(), IterCounts[I] will store the number of
1470 // the original loop iterations through which control flows from
1471 // CondLatches[I-1] to CondLatches[I].
1472 // - For I == 0 or I == CondLatches.size(), IterCounts[I] will store the
1473 // number of the original loop iterations through which control can flow
1474 // before CondLatches.front() or after CondLatches.back(), respectively,
1475 // without taking the unrolled loop's backedge, if any.
1476 // - CondLatchNexts[I] will store the CondLatches[I] branch target for the
1477 // next of the original loop's iterations (as opposed to the exit target).
1478 assert(ULO.Count == Latches.size() &&
1479 "Expected one latch block per unrolled iteration");
1480 std::vector<unsigned> IterCounts(1, 0);
1481 std::vector<BasicBlock *> CondLatches;
1482 std::vector<BasicBlock *> CondLatchNexts;
1483 IterCounts.reserve(n: Latches.size() + 1);
1484 CondLatches.reserve(n: Latches.size());
1485 CondLatchNexts.reserve(n: Latches.size());
1486
1487 // Merge adjacent basic blocks, if possible.
1488 for (auto [I, Latch] : enumerate(First&: Latches)) {
1489 ++IterCounts.back();
1490 assert((isa<UncondBrInst, CondBrInst>(Latch->getTerminator()) ||
1491 (CompletelyUnroll && !LatchIsExiting && Latch == Latches.back())) &&
1492 "Need a branch as terminator, except when fully unrolling with "
1493 "unconditional latch");
1494 if (auto *Term = dyn_cast<UncondBrInst>(Val: Latch->getTerminator())) {
1495 BasicBlock *Dest = Term->getSuccessor();
1496 BasicBlock *Fold = Dest->getUniquePredecessor();
1497 if (MergeBlockIntoPredecessor(BB: Dest, /*DTU=*/DTUToUse, LI,
1498 /*MSSAU=*/nullptr, /*MemDep=*/nullptr,
1499 /*PredecessorWithTwoSuccessors=*/false,
1500 DT: DTUToUse ? nullptr : DT)) {
1501 // Dest has been folded into Fold. Update our worklists accordingly.
1502 llvm::replace(Range&: Latches, OldValue: Dest, NewValue: Fold);
1503 llvm::erase(C&: UnrolledLoopBlocks, V: Dest);
1504 }
1505 } else if (isa<CondBrInst>(Val: Latch->getTerminator())) {
1506 IterCounts.push_back(x: 0);
1507 CondLatches.push_back(x: Latch);
1508 CondLatchNexts.push_back(x: Headers[(I + 1) % Latches.size()]);
1509 }
1510 }
1511
1512 // Fix probabilities we contradicted above.
1513 if (ProbUpdateRequired) {
1514 fixProbContradiction(L, ULO, ORE, OriginalLoopProb, CompletelyUnroll,
1515 IterCounts, CondLatches, CondLatchNexts);
1516 }
1517
1518 // If there are partial reductions, create code in the exit block to compute
1519 // the final result and update users of the final result.
1520 if (!PartialReductions.empty()) {
1521 BasicBlock *ExitBlock = L->getExitBlock();
1522 assert(ExitBlock &&
1523 "Can only introduce parallel reduction phis with single exit block");
1524 assert(Reductions.size() == 1 &&
1525 "currently only a single reduction is supported");
1526 Value *FinalRdxValue = PartialReductions.back();
1527 Value *RdxResult = nullptr;
1528 for (PHINode &Phi : ExitBlock->phis()) {
1529 if (Phi.getIncomingValueForBlock(BB: L->getLoopLatch()) != FinalRdxValue)
1530 continue;
1531 if (!RdxResult) {
1532 RdxResult = PartialReductions.front();
1533 IRBuilder Builder(ExitBlock, ExitBlock->getFirstNonPHIIt());
1534 Builder.setFastMathFlags(Reductions.begin()->second.getFastMathFlags());
1535 RecurKind RK = Reductions.begin()->second.getRecurrenceKind();
1536 for (Instruction *RdxPart : drop_begin(RangeOrContainer&: PartialReductions)) {
1537 if (RecurrenceDescriptor::isMinMaxRecurrenceKind(Kind: RK))
1538 RdxResult = createMinMaxOp(Builder, RK, Left: RdxResult, Right: RdxPart);
1539 else
1540 RdxResult = Builder.CreateBinOp(
1541 Opc: (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode(Kind: RK),
1542 LHS: RdxPart, RHS: RdxResult, Name: "bin.rdx");
1543 }
1544 NeedToFixLCSSA = true;
1545 for (Instruction *RdxPart : PartialReductions)
1546 RdxPart->dropPoisonGeneratingFlags();
1547 }
1548
1549 Phi.replaceAllUsesWith(V: RdxResult);
1550 }
1551 }
1552
1553 if (DTUToUse) {
1554 // Apply updates to the DomTree.
1555 DT = &DTU.getDomTree();
1556 }
1557 assert(!UnrollVerifyDomtree ||
1558 DT->verify(DominatorTree::VerificationLevel::Fast));
1559
1560 Loop *OuterL = L->getParentLoop();
1561 std::vector<BasicBlock *> Blocks;
1562 // Update LoopInfo if the loop is completely removed.
1563 if (CompletelyUnroll) {
1564 Blocks = L->getBlocks();
1565 LI->erase(L);
1566 // We shouldn't try to use `L` anymore.
1567 L = nullptr;
1568 }
1569
1570 // At this point, the code is well formed. We now simplify the unrolled loop,
1571 // doing constant propagation and dead code elimination as we go.
1572 simplifyLoopAfterUnroll(
1573 L, SimplifyIVs: !CompletelyUnroll && ULO.Count > 1, LI, SE, DT, AC, TTI,
1574 Blocks: CompletelyUnroll ? ArrayRef<BasicBlock *>(Blocks) : L->getBlocks(), AA);
1575
1576 NumCompletelyUnrolled += CompletelyUnroll;
1577 ++NumUnrolled;
1578
1579 if (!CompletelyUnroll) {
1580 // Update metadata for the loop's branch weights and estimated trip count:
1581 // - If ULO.Runtime, UnrollRuntimeLoopRemainder sets the guard branch
1582 // weights, latch branch weights, and estimated trip count of the
1583 // remainder loop it creates. It also sets the branch weights for the
1584 // unrolled loop guard it creates. The branch weights for the unrolled
1585 // loop latch are adjusted below. FIXME: Handle prologue loops.
1586 // - Otherwise, if unrolled loop iteration latches become unconditional,
1587 // branch weights are adjusted by the fixProbContradiction call above.
1588 // - Otherwise, the original loop's branch weights are correct for the
1589 // unrolled loop, so do not adjust them.
1590 // - In all cases, the unrolled loop's estimated trip count is set below.
1591 //
1592 // As an example of the last case, consider what happens if the unroll count
1593 // is 4 for a loop with an estimated trip count of 10 when we do not create
1594 // a remainder loop and all iterations' latches remain conditional. Each
1595 // unrolled iteration's latch still has the same probability of exiting the
1596 // loop as it did when in the original loop, and thus it should still have
1597 // the same branch weights. Each unrolled iteration's non-zero probability
1598 // of exiting already appropriately reduces the probability of reaching the
1599 // remaining iterations just as it did in the original loop. Trying to also
1600 // adjust the branch weights of the final unrolled iteration's latch (i.e.,
1601 // the backedge for the unrolled loop as a whole) to reflect its new trip
1602 // count of 3 will erroneously further reduce its block frequencies.
1603 // However, in case an analysis later needs to estimate the trip count of
1604 // the unrolled loop as a whole without considering the branch weights for
1605 // each unrolled iteration's latch within it, we store the new trip count as
1606 // separate metadata.
1607 if (!OriginalLoopProb.isUnknown() && ULO.Runtime && EpilogProfitability) {
1608 assert((CondLatches.size() == 1 &&
1609 (ProbUpdateRequired || OriginalLoopProb.isOne())) &&
1610 "Expected ULO.Runtime to give unrolled loop 1 conditional latch, "
1611 "the backedge, requiring a probability update unless infinite");
1612 // Where p is always the probability of executing at least 1 more
1613 // iteration, the probability for at least n more iterations is p^n.
1614 setLoopProbability(L, P: OriginalLoopProb.pow(N: ULO.Count));
1615 }
1616 if (OriginalTripCount) {
1617 unsigned NewTripCount = *OriginalTripCount / ULO.Count;
1618 if (!ULO.Runtime && *OriginalTripCount % ULO.Count)
1619 ++NewTripCount;
1620 setLoopEstimatedTripCount(L, EstimatedTripCount: NewTripCount);
1621 }
1622 }
1623
1624 // LoopInfo should not be valid, confirm that.
1625 if (UnrollVerifyLoopInfo)
1626 LI->verify(DomTree: *DT);
1627
1628 // After complete unrolling most of the blocks should be contained in OuterL.
1629 // However, some of them might happen to be out of OuterL (e.g. if they
1630 // precede a loop exit). In this case we might need to insert PHI nodes in
1631 // order to preserve LCSSA form.
1632 // We don't need to check this if we already know that we need to fix LCSSA
1633 // form.
1634 // TODO: For now we just recompute LCSSA for the outer loop in this case, but
1635 // it should be possible to fix it in-place.
1636 if (PreserveLCSSA && OuterL && CompletelyUnroll && !NeedToFixLCSSA)
1637 NeedToFixLCSSA |= ::needToInsertPhisForLCSSA(L: OuterL, Blocks: UnrolledLoopBlocks, LI);
1638
1639 // Make sure that loop-simplify form is preserved. We want to simplify
1640 // at least one layer outside of the loop that was unrolled so that any
1641 // changes to the parent loop exposed by the unrolling are considered.
1642 if (OuterL) {
1643 // OuterL includes all loops for which we can break loop-simplify, so
1644 // it's sufficient to simplify only it (it'll recursively simplify inner
1645 // loops too).
1646 if (NeedToFixLCSSA) {
1647 // LCSSA must be performed on the outermost affected loop. The unrolled
1648 // loop's last loop latch is guaranteed to be in the outermost loop
1649 // after LoopInfo's been updated by LoopInfo::erase.
1650 Loop *LatchLoop = LI->getLoopFor(BB: Latches.back());
1651 Loop *FixLCSSALoop = OuterL;
1652 if (!FixLCSSALoop->contains(L: LatchLoop))
1653 while (FixLCSSALoop->getParentLoop() != LatchLoop)
1654 FixLCSSALoop = FixLCSSALoop->getParentLoop();
1655
1656 formLCSSARecursively(L&: *FixLCSSALoop, DT: *DT, LI, SE);
1657 } else if (PreserveLCSSA) {
1658 assert(OuterL->isLCSSAForm(*DT) &&
1659 "Loops should be in LCSSA form after loop-unroll.");
1660 }
1661
1662 // TODO: That potentially might be compile-time expensive. We should try
1663 // to fix the loop-simplified form incrementally.
1664 simplifyLoop(L: OuterL, DT, LI, SE, AC, MSSAU: nullptr, PreserveLCSSA);
1665 } else {
1666 // Simplify loops for which we might've broken loop-simplify form.
1667 for (Loop *SubLoop : LoopsToSimplify)
1668 simplifyLoop(L: SubLoop, DT, LI, SE, AC, MSSAU: nullptr, PreserveLCSSA);
1669 }
1670
1671 return CompletelyUnroll ? LoopUnrollResult::FullyUnrolled
1672 : LoopUnrollResult::PartiallyUnrolled;
1673}
1674
1675/// Given an llvm.loop loop id metadata node, returns the loop hint metadata
1676/// node with the given name (for example, "llvm.loop.unroll.count"). If no
1677/// such metadata node exists, then nullptr is returned.
1678MDNode *llvm::GetUnrollMetadata(MDNode *LoopID, StringRef Name) {
1679 // First operand should refer to the loop id itself.
1680 assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
1681 assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
1682
1683 for (const MDOperand &MDO : llvm::drop_begin(RangeOrContainer: LoopID->operands())) {
1684 MDNode *MD = dyn_cast<MDNode>(Val: MDO);
1685 if (!MD)
1686 continue;
1687
1688 MDString *S = dyn_cast<MDString>(Val: MD->getOperand(I: 0));
1689 if (!S)
1690 continue;
1691
1692 if (Name == S->getString())
1693 return MD;
1694 }
1695 return nullptr;
1696}
1697
1698// Returns the loop hint metadata node with the given name (for example,
1699// "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is
1700// returned.
1701MDNode *llvm::getUnrollMetadataForLoop(const Loop *L, StringRef Name) {
1702 if (MDNode *LoopID = L->getLoopID())
1703 return GetUnrollMetadata(LoopID, Name);
1704 return nullptr;
1705}
1706
1707std::optional<RecurrenceDescriptor>
1708llvm::canParallelizeReductionWhenUnrolling(PHINode &Phi, Loop *L,
1709 ScalarEvolution *SE) {
1710 RecurrenceDescriptor RdxDesc;
1711 if (!RecurrenceDescriptor::isReductionPHI(Phi: &Phi, TheLoop: L, RedDes&: RdxDesc,
1712 /*DemandedBits=*/DB: nullptr,
1713 /*AC=*/nullptr, /*DT=*/nullptr, SE))
1714 return std::nullopt;
1715 if (RdxDesc.hasUsesOutsideReductionChain())
1716 return std::nullopt;
1717 RecurKind RK = RdxDesc.getRecurrenceKind();
1718 // Skip unsupported reductions.
1719 // TODO: Handle any-of and find-last reductions.
1720 if (RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind: RK) ||
1721 RecurrenceDescriptor::isFindRecurrenceKind(Kind: RK))
1722 return std::nullopt;
1723
1724 if (RdxDesc.hasExactFPMath())
1725 return std::nullopt;
1726
1727 if (RdxDesc.IntermediateStore)
1728 return std::nullopt;
1729
1730 // Don't unroll reductions with constant ops; those can be folded to a
1731 // single induction update.
1732 if (any_of(Range: cast<Instruction>(Val: Phi.getIncomingValueForBlock(BB: L->getLoopLatch()))
1733 ->operands(),
1734 P: IsaPred<Constant>))
1735 return std::nullopt;
1736
1737 BasicBlock *Latch = L->getLoopLatch();
1738 if (!Latch ||
1739 !is_contained(
1740 Range: cast<Instruction>(Val: Phi.getIncomingValueForBlock(BB: Latch))->operands(),
1741 Element: &Phi))
1742 return std::nullopt;
1743
1744 return RdxDesc;
1745}
1746