//===-- LoopUtils.cpp - Loop Utility functions -------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines common loop utility functions.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/PriorityWorklist.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/MemorySSA.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/DIBuilder.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/ProfDataUtils.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "loop-utils"

static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced";
static const char *LLVMLoopDisableLICM = "llvm.licm.disable";
namespace llvm {
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
} // namespace llvm

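// A dedicated exit block has only in-loop predecessors. For example, if exit
// block %e is reached both from in-loop block %a and from a block outside the
// loop, the in-loop edges are redirected to a new block %e.loopexit (split
// off below via SplitBlockPredecessors) that branches to %e.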
bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI,
                                   MemorySSAUpdater *MSSAU,
                                   bool PreserveLCSSA) {
  bool Changed = false;

  // We reuse a vector for the in-loop predecessors.
  SmallVector<BasicBlock *, 4> InLoopPredecessors;

  auto RewriteExit = [&](BasicBlock *BB) {
    assert(InLoopPredecessors.empty() &&
           "Must start with an empty predecessors list!");
    llvm::scope_exit Cleanup([&] { InLoopPredecessors.clear(); });

    // See if there are any non-loop predecessors of this exit block and
    // keep track of the in-loop predecessors.
    bool IsDedicatedExit = true;
    for (auto *PredBB : predecessors(BB))
      if (L->contains(PredBB)) {
        if (isa<IndirectBrInst>(PredBB->getTerminator()))
          // We cannot rewrite exiting edges from an indirectbr.
          return false;

        InLoopPredecessors.push_back(PredBB);
      } else {
        IsDedicatedExit = false;
      }

    assert(!InLoopPredecessors.empty() && "Must have *some* loop predecessor!");

    // Nothing to do if this is already a dedicated exit.
    if (IsDedicatedExit)
      return false;

    auto *NewExitBB = SplitBlockPredecessors(
        BB, InLoopPredecessors, ".loopexit", DT, LI, MSSAU, PreserveLCSSA);

    if (!NewExitBB)
      LLVM_DEBUG(
          dbgs() << "WARNING: Can't create a dedicated exit block for loop: "
                 << *L << "\n");
    else
      LLVM_DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block "
                        << NewExitBB->getName() << "\n");
    return true;
  };

  // Walk the exit blocks directly rather than building up a data structure for
  // them, but only visit each one once.
  SmallPtrSet<BasicBlock *, 4> Visited;
  for (auto *BB : L->blocks())
    for (auto *SuccBB : successors(BB)) {
      // We're looking for exit blocks so skip in-loop successors.
      if (L->contains(SuccBB))
        continue;

      // Visit each exit block exactly once.
      if (!Visited.insert(SuccBB).second)
        continue;

      Changed |= RewriteExit(SuccBB);
    }

  return Changed;
}

/// Returns the instructions that use values defined in the loop.
SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) {
  SmallVector<Instruction *, 8> UsedOutside;

  for (auto *Block : L->getBlocks())
    // FIXME: I believe that this could use copy_if if the Inst reference could
    // be adapted into a pointer.
    for (auto &Inst : *Block) {
      auto Users = Inst.users();
      if (any_of(Users, [&](User *U) {
            auto *Use = cast<Instruction>(U);
            return !L->contains(Use->getParent());
          }))
        UsedOutside.push_back(&Inst);
    }

  return UsedOutside;
}

void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) {
  // By definition, all loop passes need the LoopInfo analysis and the
  // Dominator tree it depends on. Because they all participate in the loop
  // pass manager, they must also preserve these.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addPreserved<LoopInfoWrapperPass>();

  // We must also preserve LoopSimplify and LCSSA. We locally access their IDs
  // here because users shouldn't directly get them from this header.
  extern char &LoopSimplifyID;
  extern char &LCSSAID;
  AU.addRequiredID(LoopSimplifyID);
  AU.addPreservedID(LoopSimplifyID);
  AU.addRequiredID(LCSSAID);
  AU.addPreservedID(LCSSAID);
  // This is used in the LPPassManager to perform LCSSA verification on passes
  // which preserve LCSSA form.
  AU.addRequired<LCSSAVerificationPass>();
  AU.addPreserved<LCSSAVerificationPass>();

  // Loop passes are designed to run inside of a loop pass manager which means
  // that any function analyses they require must be required by the first loop
  // pass in the manager (so that it is computed before the loop pass manager
  // runs) and preserved by all loop passes in the manager. To make this
  // reasonably robust, the set needed for most loop passes is maintained here.
  // If your loop pass requires an analysis not listed here, you will need to
  // carefully audit the loop pass manager nesting structure that results.
  AU.addRequired<AAResultsWrapperPass>();
  AU.addPreserved<AAResultsWrapperPass>();
  AU.addPreserved<BasicAAWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
  AU.addPreserved<SCEVAAWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  // FIXME: When all loop passes preserve MemorySSA, it can be required and
  // preserved here instead of the individual handling in each pass.
}

/// Manually defined generic "LoopPass" dependency initialization. This is used
/// to initialize the exact set of passes from above in \c
/// getLoopAnalysisUsage. It can be used within a loop pass's initialization
/// with:
///
///   INITIALIZE_PASS_DEPENDENCY(LoopPass)
///
/// As-if "LoopPass" were a pass.
void llvm::initializeLoopPassPass(PassRegistry &Registry) {
  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
  INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
}

/// Create an MDNode for the input string and value.
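/// For example, Name = "llvm.loop.unroll.count" with V = 4 yields the node
/// !{!"llvm.loop.unroll.count", i32 4}.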
static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
  LLVMContext &Context = TheLoop->getHeader()->getContext();
  Metadata *MDs[] = {
      MDString::get(Context, Name),
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))};
  return MDNode::get(Context, MDs);
}

/// Set the input string into the loop metadata, keeping other values intact.
/// If the string is already present in the loop metadata, update its value
/// if it differs.
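/// Loop metadata is a self-referential node whose first operand is the node
/// itself, e.g.:
///   !0 = distinct !{!0, !1}
///   !1 = !{!"llvm.loop.unroll.count", i32 4}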
void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
                                   unsigned V) {
  SmallVector<Metadata *, 4> MDs(1);
  // If the loop already has metadata, retain it.
  MDNode *LoopID = TheLoop->getLoopID();
  if (LoopID) {
    for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
      MDNode *Node = cast<MDNode>(LoopID->getOperand(i));
      // If it is of the form key = value, try to parse it.
      if (Node->getNumOperands() == 2) {
        MDString *S = dyn_cast<MDString>(Node->getOperand(0));
        if (S && S->getString() == StringMD) {
          ConstantInt *IntMD =
              mdconst::extract_or_null<ConstantInt>(Node->getOperand(1));
          if (IntMD && IntMD->getSExtValue() == V)
            // It is already in place. Do nothing.
            return;
          // We need to update the value, so just skip it here and it will
          // be added after copying the other existing nodes.
          continue;
        }
      }
      MDs.push_back(Node);
    }
  }
  // Add new metadata.
  MDs.push_back(createStringMetadata(TheLoop, StringMD, V));
  // Replace the current metadata node with the new one.
  LLVMContext &Context = TheLoop->getHeader()->getContext();
  MDNode *NewLoopID = MDNode::get(Context, MDs);
  // Set operand 0 to refer to the loop id itself.
  NewLoopID->replaceOperandWith(0, NewLoopID);
  TheLoop->setLoopID(NewLoopID);
}

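/// For example, the metadata pair !{!"llvm.loop.vectorize.width", i32 4}
/// combined with !{!"llvm.loop.vectorize.scalable.enable", i1 true} yields an
/// ElementCount of "vscale x 4".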
std::optional<ElementCount>
llvm::getOptionalElementCountLoopAttribute(const Loop *TheLoop) {
  std::optional<int> Width =
      getOptionalIntLoopAttribute(TheLoop, "llvm.loop.vectorize.width");

  if (Width) {
    std::optional<int> IsScalable = getOptionalIntLoopAttribute(
        TheLoop, "llvm.loop.vectorize.scalable.enable");
    return ElementCount::get(*Width, IsScalable.value_or(false));
  }

  return std::nullopt;
}

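// Followup attributes are supplied as option nodes in the original loop's
// metadata, e.g.:
//   !0 = distinct !{!0, !1}
//   !1 = !{!"llvm.loop.vectorize.followup_vectorized", !2}
//   !2 = !{!"llvm.loop.isvectorized"}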
std::optional<MDNode *> llvm::makeFollowupLoopID(
    MDNode *OrigLoopID, ArrayRef<StringRef> FollowupOptions,
    const char *InheritOptionsExceptPrefix, bool AlwaysNew) {
  if (!OrigLoopID) {
    if (AlwaysNew)
      return nullptr;
    return std::nullopt;
  }

  assert(OrigLoopID->getOperand(0) == OrigLoopID);

  bool InheritAllAttrs = !InheritOptionsExceptPrefix;
  bool InheritSomeAttrs =
      InheritOptionsExceptPrefix && InheritOptionsExceptPrefix[0] != '\0';
  SmallVector<Metadata *, 8> MDs;
  MDs.push_back(nullptr);

  bool Changed = false;
  if (InheritAllAttrs || InheritSomeAttrs) {
    for (const MDOperand &Existing : drop_begin(OrigLoopID->operands())) {
      MDNode *Op = cast<MDNode>(Existing.get());

      auto InheritThisAttribute = [InheritSomeAttrs,
                                   InheritOptionsExceptPrefix](MDNode *Op) {
        if (!InheritSomeAttrs)
          return false;

        // Skip malformed attribute metadata nodes.
        if (Op->getNumOperands() == 0)
          return true;
        Metadata *NameMD = Op->getOperand(0).get();
        if (!isa<MDString>(NameMD))
          return true;
        StringRef AttrName = cast<MDString>(NameMD)->getString();

        // Do not inherit excluded attributes.
        return !AttrName.starts_with(InheritOptionsExceptPrefix);
      };

      if (InheritThisAttribute(Op))
        MDs.push_back(Op);
      else
        Changed = true;
    }
  } else {
    // Modified if we dropped at least one attribute.
    Changed = OrigLoopID->getNumOperands() > 1;
  }

  bool HasAnyFollowup = false;
  for (StringRef OptionName : FollowupOptions) {
    MDNode *FollowupNode = findOptionMDForLoopID(OrigLoopID, OptionName);
    if (!FollowupNode)
      continue;

    HasAnyFollowup = true;
    for (const MDOperand &Option : drop_begin(FollowupNode->operands())) {
      MDs.push_back(Option.get());
      Changed = true;
    }
  }

  // No attributes of the followup loop were specified explicitly, so signal
  // to the transformation pass to add suitable attributes.
  if (!AlwaysNew && !HasAnyFollowup)
    return std::nullopt;

  // If no attributes were added or removed, the previous loop ID can be
  // reused.
  if (!AlwaysNew && !Changed)
    return OrigLoopID;

  // No attributes is equivalent to having no !llvm.loop metadata at all.
  if (MDs.size() == 1)
    return nullptr;

  // Build the new loop ID.
  MDTuple *FollowupLoopID = MDNode::get(OrigLoopID->getContext(), MDs);
  FollowupLoopID->replaceOperandWith(0, FollowupLoopID);
  return FollowupLoopID;
}

bool llvm::hasDisableAllTransformsHint(const Loop *L) {
  return getBooleanLoopAttribute(L, LLVMLoopDisableNonforced);
}

bool llvm::hasDisableLICMTransformsHint(const Loop *L) {
  return getBooleanLoopAttribute(L, LLVMLoopDisableLICM);
}

TransformationMode llvm::hasUnrollTransformation(const Loop *L) {
  if (getBooleanLoopAttribute(L, "llvm.loop.unroll.disable"))
    return TM_SuppressedByUser;

  std::optional<int> Count =
      getOptionalIntLoopAttribute(L, "llvm.loop.unroll.count");
  if (Count)
    return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;

  if (getBooleanLoopAttribute(L, "llvm.loop.unroll.enable"))
    return TM_ForcedByUser;

  if (getBooleanLoopAttribute(L, "llvm.loop.unroll.full"))
    return TM_ForcedByUser;

  if (hasDisableAllTransformsHint(L))
    return TM_Disable;

  return TM_Unspecified;
}

TransformationMode llvm::hasUnrollAndJamTransformation(const Loop *L) {
  if (getBooleanLoopAttribute(L, "llvm.loop.unroll_and_jam.disable"))
    return TM_SuppressedByUser;

  std::optional<int> Count =
      getOptionalIntLoopAttribute(L, "llvm.loop.unroll_and_jam.count");
  if (Count)
    return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;

  if (getBooleanLoopAttribute(L, "llvm.loop.unroll_and_jam.enable"))
    return TM_ForcedByUser;

  if (hasDisableAllTransformsHint(L))
    return TM_Disable;

  return TM_Unspecified;
}

TransformationMode llvm::hasVectorizeTransformation(const Loop *L) {
  std::optional<bool> Enable =
      getOptionalBoolLoopAttribute(L, "llvm.loop.vectorize.enable");

  if (Enable == false)
    return TM_SuppressedByUser;

  std::optional<ElementCount> VectorizeWidth =
      getOptionalElementCountLoopAttribute(L);
  std::optional<int> InterleaveCount =
      getOptionalIntLoopAttribute(L, "llvm.loop.interleave.count");

  // 'Forcing' vector width and interleave count to one effectively disables
  // this transformation.
  if (Enable == true && VectorizeWidth && VectorizeWidth->isScalar() &&
      InterleaveCount == 1)
    return TM_SuppressedByUser;

  if (getBooleanLoopAttribute(L, "llvm.loop.isvectorized"))
    return TM_Disable;

  if (Enable == true)
    return TM_ForcedByUser;

  if ((VectorizeWidth && VectorizeWidth->isScalar()) && InterleaveCount == 1)
    return TM_Disable;

  if ((VectorizeWidth && VectorizeWidth->isVector()) || InterleaveCount > 1)
    return TM_Enable;

  if (hasDisableAllTransformsHint(L))
    return TM_Disable;

  return TM_Unspecified;
}

TransformationMode llvm::hasDistributeTransformation(const Loop *L) {
  if (getBooleanLoopAttribute(L, "llvm.loop.distribute.enable"))
    return TM_ForcedByUser;

  if (hasDisableAllTransformsHint(L))
    return TM_Disable;

  return TM_Unspecified;
}

TransformationMode llvm::hasLICMVersioningTransformation(const Loop *L) {
  if (getBooleanLoopAttribute(L, "llvm.loop.licm_versioning.disable"))
    return TM_SuppressedByUser;

  if (hasDisableAllTransformsHint(L))
    return TM_Disable;

  return TM_Unspecified;
}

/// Does a BFS from a given node to all of its children inside a given loop.
/// The returned vector of basic blocks includes the starting point.
SmallVector<BasicBlock *, 16> llvm::collectChildrenInLoop(DominatorTree *DT,
                                                          DomTreeNode *N,
                                                          const Loop *CurLoop) {
  SmallVector<BasicBlock *, 16> Worklist;
  auto AddRegionToWorklist = [&](DomTreeNode *DTN) {
    // Only include subregions in the top level loop.
    BasicBlock *BB = DTN->getBlock();
    if (CurLoop->contains(BB))
      Worklist.push_back(DTN->getBlock());
  };

  AddRegionToWorklist(N);

  for (size_t I = 0; I < Worklist.size(); I++) {
    for (DomTreeNode *Child : DT->getNode(Worklist[I])->children())
      AddRegionToWorklist(Child);
  }

  return Worklist;
}

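/// An "almost dead" IV has no users other than its own increment and the loop
/// exit condition, e.g.:
///   %iv = phi i32 [ 0, %preheader ], [ %inc, %latch ]
///   %inc = add nuw i32 %iv, 1
///   %cond = icmp ult i32 %inc, %n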
bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) {
  int LatchIdx = PN->getBasicBlockIndex(LatchBlock);
  assert(LatchIdx != -1 && "LatchBlock is not a case in this PHINode");
  Value *IncV = PN->getIncomingValue(LatchIdx);

  for (User *U : PN->users())
    if (U != Cond && U != IncV) return false;

  for (User *U : IncV->users())
    if (U != Cond && U != PN) return false;
  return true;
}


void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
                          LoopInfo *LI, MemorySSA *MSSA) {
  assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
  auto *Preheader = L->getLoopPreheader();
  assert(Preheader && "Preheader should exist!");

  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);

  // Now that we know the removal is safe, remove the loop by changing the
  // branch from the preheader to go to the single exit block.
  //
  // Because we're deleting a large chunk of code at once, the sequence in
  // which we remove things is very important to avoid invalidation issues.

  // Tell ScalarEvolution that the loop is deleted. Do this before
  // deleting the loop so that ScalarEvolution can look at the loop
  // to determine what it needs to clean up.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetBlockAndLoopDispositions();
  }

  Instruction *OldTerm = Preheader->getTerminator();
  assert(!OldTerm->mayHaveSideEffects() &&
         "Preheader must end with a side-effect-free terminator");
  assert(OldTerm->getNumSuccessors() == 1 &&
         "Preheader must have a single successor");
  // Connect the preheader to the exit block. Keep the old edge to the header
  // around to perform the dominator tree update in two separate steps
  // -- #1 insertion of the edge preheader -> exit and #2 deletion of the edge
  // preheader -> header.
  //
  //
  // 0. Preheader          1. Preheader          2. Preheader
  //        |                    |   |                 |
  //        V                    |   V                 |
  //      Header <--\            | Header <--\         | Header <--\
  //       |  |     |            |  |  |     |         |  |  |     |
  //       |  V     |            |  |  V     |         |  |  V     |
  //       | Body --/            |  | Body --/         |  | Body --/
  //       V                     V  V                  V  V
  //      Exit                   Exit                  Exit
  //
  // By doing this in two separate steps we can perform the dominator tree
  // update without using the batch update API.
  //
  // Even when the loop is never executed, we cannot remove the edge from the
  // source block to the exit block. Consider the case where the unexecuted loop
  // branches back to an outer loop. If we deleted the loop and removed the edge
  // coming to this inner loop, this will break the outer loop structure (by
  // deleting the backedge of the outer loop). If the outer loop is indeed a
  // non-loop, it will be deleted in a future iteration of loop deletion pass.
  IRBuilder<> Builder(OldTerm);

  auto *ExitBlock = L->getUniqueExitBlock();
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
  if (ExitBlock) {
    assert(ExitBlock && "Should have a unique exit block!");
    assert(L->hasDedicatedExits() && "Loop should have dedicated exits!");

    Builder.CreateCondBr(Builder.getFalse(), L->getHeader(), ExitBlock);
    // Remove the old branch. The conditional branch becomes a new terminator.
    OldTerm->eraseFromParent();

    // Rewrite phis in the exit block to get their inputs from the Preheader
    // instead of the exiting block.
    for (PHINode &P : ExitBlock->phis()) {
      // Set the zero'th element of Phi to be from the preheader and remove all
      // other incoming values. Given the loop has dedicated exits, all other
      // incoming values must be from the exiting blocks.
      int PredIndex = 0;
      P.setIncomingBlock(PredIndex, Preheader);
      // Removes all incoming values from all other exiting blocks (including
      // duplicate values from an exiting block).
      // Nuke all entries except the zero'th entry which is the preheader entry.
      P.removeIncomingValueIf([](unsigned Idx) { return Idx != 0; },
                              /* DeletePHIIfEmpty */ false);

      assert((P.getNumIncomingValues() == 1 &&
              P.getIncomingBlock(PredIndex) == Preheader) &&
             "Should have exactly one value and that's from the preheader!");
    }

    if (DT) {
      DTU.applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}});
      if (MSSA) {
        MSSAU->applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}},
                            *DT);
        if (VerifyMemorySSA)
          MSSA->verifyMemorySSA();
      }
    }

    // Disconnect the loop body by branching directly to its exit.
    Builder.SetInsertPoint(Preheader->getTerminator());
    Builder.CreateBr(ExitBlock);
    // Remove the old branch.
    Preheader->getTerminator()->eraseFromParent();
  } else {
    assert(L->hasNoExitBlocks() &&
           "Loop should have either zero or one exit blocks.");

    Builder.SetInsertPoint(OldTerm);
    Builder.CreateUnreachable();
    Preheader->getTerminator()->eraseFromParent();
  }

  if (DT) {
    DTU.applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}});
    if (MSSA) {
      MSSAU->applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}},
                          *DT);
      SmallSetVector<BasicBlock *, 8> DeadBlockSet(L->block_begin(),
                                                   L->block_end());
      MSSAU->removeBlocks(DeadBlockSet);
      if (VerifyMemorySSA)
        MSSA->verifyMemorySSA();
    }
  }

  // Use a set to unique and a vector to guarantee deterministic ordering.
  llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet;
  llvm::SmallVector<DbgVariableRecord *, 4> DeadDbgVariableRecords;

  // Given LCSSA form is satisfied, we should not have users of instructions
  // within the dead loop outside of the loop. However, LCSSA doesn't take
  // unreachable uses into account. We handle them here.
  // We could do this after dropping all references (in which case all users
  // inside the loop would already be eliminated and we would have less work to
  // do), but according to the API doc of User::dropAllReferences the only
  // valid operation after dropping references is deletion. So let's substitute
  // all uses of instructions from the loop with a poison value of the
  // corresponding type first.
  for (auto *Block : L->blocks())
    for (Instruction &I : *Block) {
      auto *Poison = PoisonValue::get(I.getType());
      for (Use &U : llvm::make_early_inc_range(I.uses())) {
        if (auto *Usr = dyn_cast<Instruction>(U.getUser()))
          if (L->contains(Usr->getParent()))
            continue;
        // If we have a DT, we can check that uses outside the loop occur only
        // in unreachable blocks.
        if (DT)
          assert(!DT->isReachableFromEntry(U) &&
                 "Unexpected user in reachable block");
        U.set(Poison);
      }

      if (ExitBlock) {
        // For one of each variable encountered, preserve a debug record (set
        // to poison) and transfer it to the loop exit. This terminates any
        // variable locations that were set during the loop.
        for (DbgVariableRecord &DVR :
             llvm::make_early_inc_range(filterDbgVars(I.getDbgRecordRange()))) {
          DebugVariable Key(DVR.getVariable(), DVR.getExpression(),
                            DVR.getDebugLoc().get());
          if (!DeadDebugSet.insert(Key).second)
            continue;
          // Unlinks the DVR from its container, for later insertion.
          DVR.removeFromParent();
          DeadDbgVariableRecords.push_back(&DVR);
        }
      }
    }

  if (ExitBlock) {
    // After the loop has been deleted all the values defined and modified
    // inside the loop are going to be unavailable. Values computed in the
    // loop will have been deleted, automatically causing their debug uses
    // to be replaced with undef. Loop invariant values will still be available.
    // Move dbg.values out of the loop so that earlier location ranges are
    // still terminated and loop invariant assignments are preserved.
    DIBuilder DIB(*ExitBlock->getModule());
    BasicBlock::iterator InsertDbgValueBefore =
        ExitBlock->getFirstInsertionPt();
    assert(InsertDbgValueBefore != ExitBlock->end() &&
           "There should be a non-PHI instruction in exit block, else these "
           "instructions will have no parent.");

    // Due to the "head" bit in BasicBlock::iterator, we're going to insert
    // each DbgVariableRecord right at the start of the block, whereas
    // dbg.values would be repeatedly inserted before the first instruction.
    // To replicate this behaviour, do it backwards.
    for (DbgVariableRecord *DVR : llvm::reverse(DeadDbgVariableRecords))
      ExitBlock->insertDbgRecordBefore(DVR, InsertDbgValueBefore);
  }

  // Remove the block from the reference counting scheme, so that we can
  // delete it freely later.
  for (auto *Block : L->blocks())
    Block->dropAllReferences();

  if (MSSA && VerifyMemorySSA)
    MSSA->verifyMemorySSA();

  if (LI) {
    // Erase the instructions and the blocks without having to worry
    // about ordering because we already dropped the references.
    // NOTE: This iteration is safe because erasing the block does not remove
    // its entry from the loop's block list. We do that in the next section.
    for (BasicBlock *BB : L->blocks())
      BB->eraseFromParent();

    // Finally, remove the blocks from LoopInfo. This has to happen late
    // because otherwise our loop iterators won't work.

    SmallPtrSet<BasicBlock *, 8> blocks(llvm::from_range, L->blocks());
    for (BasicBlock *BB : blocks)
      LI->removeBlock(BB);

    // The last step is to update LoopInfo now that we've eliminated this loop.
    // Note: LoopInfo::erase removes the given loop and relinks its subloops
    // with its parent, while removeLoop/removeChildLoop remove the given loop
    // but do not relink its subloops, which is what we want.
    if (Loop *ParentLoop = L->getParentLoop()) {
      Loop::iterator I = find(*ParentLoop, L);
      assert(I != ParentLoop->end() && "Couldn't find loop");
      ParentLoop->removeChildLoop(I);
    } else {
      Loop::iterator I = find(*LI, L);
      assert(I != LI->end() && "Couldn't find loop");
      LI->removeLoop(I);
    }
    LI->destroy(L);
  }
}

void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
                             LoopInfo &LI, MemorySSA *MSSA) {
  auto *Latch = L->getLoopLatch();
  assert(Latch && "multiple latches not yet supported");
  auto *Header = L->getHeader();
  Loop *OutermostLoop = L->getOutermostLoop();

  SE.forgetLoop(L);
  SE.forgetBlockAndLoopDispositions();

  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);

  // Update the CFG and domtree. We chose to special case a couple of
  // common cases for code quality and test readability reasons.
  [&]() -> void {
    if (auto *BI = dyn_cast<BranchInst>(Latch->getTerminator())) {
      if (!BI->isConditional()) {
        DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
        (void)changeToUnreachable(BI, /*PreserveLCSSA*/ true, &DTU,
                                  MSSAU.get());
        return;
      }

      // Conditional latch/exit - note that the latch can be shared by the
      // inner and outer loop, so the other target doesn't need to be an exit.
      if (L->isLoopExiting(Latch)) {
        // TODO: Generalize ConstantFoldTerminator so that it can be used
        // here without invalidating LCSSA or MemorySSA. (Tricky case for
        // LCSSA: header is an exit block of a preceding sibling loop w/o
        // dedicated exits.)
        const unsigned ExitIdx = L->contains(BI->getSuccessor(0)) ? 1 : 0;
        BasicBlock *ExitBB = BI->getSuccessor(ExitIdx);

        DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
        Header->removePredecessor(Latch, /*KeepOneInputPHIs=*/true);

        IRBuilder<> Builder(BI);
        auto *NewBI = Builder.CreateBr(ExitBB);
        // Transfer the metadata to the new branch instruction (minus the
        // loop info since this is no longer a loop).
        NewBI->copyMetadata(*BI, {LLVMContext::MD_dbg,
                                  LLVMContext::MD_annotation});

        BI->eraseFromParent();
        DTU.applyUpdates({{DominatorTree::Delete, Latch, Header}});
        if (MSSA)
          MSSAU->applyUpdates({{DominatorTree::Delete, Latch, Header}}, DT);
        return;
      }
    }

    // General case. By splitting the backedge and then explicitly making it
    // unreachable, we gracefully handle corner cases such as switch and invoke
    // terminators.
    auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get());

    DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
    (void)changeToUnreachable(BackedgeBB->getTerminator(),
                              /*PreserveLCSSA*/ true, &DTU, MSSAU.get());
  }();

  // Erase (and destroy) this loop instance. Handles relinking sub-loops
  // and blocks within the loop as needed.
  LI.erase(L);

  // If the loop we broke had a parent, then changeToUnreachable might have
  // caused a block to be removed from the parent loop (see loop_nest_lcssa
  // test case in zero-btc.ll for an example), thus changing the parent's
  // exit blocks. If that happened, we need to rebuild LCSSA on the outermost
  // loop which might have had a block removed.
  if (OutermostLoop != L)
    formLCSSARecursively(*OutermostLoop, DT, &LI, &SE);
}


/// Checks if \p L has an exiting latch branch. There may also be other
/// exiting blocks. Returns the branch instruction terminating the loop
/// latch if the above check is successful, nullptr otherwise.
static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) {
  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return nullptr;

  BranchInst *LatchBR = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBR || LatchBR->getNumSuccessors() != 2 || !L->isLoopExiting(Latch))
    return nullptr;

  assert((LatchBR->getSuccessor(0) == L->getHeader() ||
          LatchBR->getSuccessor(1) == L->getHeader()) &&
         "At least one edge out of the latch must go to the header");

  return LatchBR;
}

struct DbgLoop {
  const Loop *L;
  explicit DbgLoop(const Loop *L) : L(L) {}
};

#ifndef NDEBUG
static inline raw_ostream &operator<<(raw_ostream &OS, DbgLoop D) {
  OS << "function ";
  D.L->getHeader()->getParent()->printAsOperand(OS, /*PrintType=*/false);
  return OS << " " << *D.L;
}
#endif // NDEBUG

static std::optional<unsigned> estimateLoopTripCount(Loop *L) {
  // Currently we take the estimated exit count only from the loop latch,
  // ignoring other exiting blocks. This can overestimate the trip count
  // if we exit through another exit, but can never underestimate it.
  // TODO: incorporate information from other exits
  BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
  if (!ExitingBranch) {
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to find exiting "
                      << "latch branch of required form in " << DbgLoop(L)
                      << "\n");
    return std::nullopt;
  }

  // To estimate the number of times the loop body was executed, we want to
  // know the number of times the backedge was taken, vs. the number of times
  // we exited the loop.
  uint64_t LoopWeight, ExitWeight;
  if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) {
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to extract branch "
                      << "weights for " << DbgLoop(L) << "\n");
    return std::nullopt;
  }

  if (L->contains(ExitingBranch->getSuccessor(1)))
    std::swap(LoopWeight, ExitWeight);

  if (!ExitWeight) {
    // We don't have a way to return a predicated infinite trip count.
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed because of zero exit "
                      << "probability for " << DbgLoop(L) << "\n");
    return std::nullopt;
  }

  // The estimated exit count is the ratio of the loop weight to the weight of
  // the edge exiting the loop, rounded to nearest.
  uint64_t ExitCount = llvm::divideNearest(LoopWeight, ExitWeight);
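  // For example, latch metadata !{!"branch_weights", i32 127, i32 1} (with
  // successor 0 staying in the loop) gives ExitCount = 127 and thus an
  // estimated trip count of 128.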

  // When ExitCount + 1 would wrap in unsigned, saturate at UINT_MAX.
  if (ExitCount >= std::numeric_limits<unsigned>::max())
    return std::numeric_limits<unsigned>::max();

  // Estimated trip count is one plus estimated exit count.
  uint64_t TC = ExitCount + 1;
  LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Estimated trip count of " << TC
                    << " for " << DbgLoop(L) << "\n");
  return TC;
}

std::optional<unsigned>
llvm::getLoopEstimatedTripCount(Loop *L,
                                unsigned *EstimatedLoopInvocationWeight) {
  // If EstimatedLoopInvocationWeight is requested, we do not support this
  // loop if getExpectedExitLoopLatchBranch returns nullptr.
  //
  // FIXME: Also, this is a stop-gap solution for nested loops. It avoids
  // mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
  // it was created for an inner loop. The problem is that loop metadata is
  // attached to the branch instruction in the loop latch block, but that can
  // be shared by the loops. A solution is to attach loop metadata to loop
  // headers instead, but that would be a large change to LLVM.
  //
  // Until that happens, we work around the problem as follows.
  // getExpectedExitLoopLatchBranch (which also guards
  // setLoopEstimatedTripCount) returns nullptr for a loop unless the loop has
  // one latch and that latch has exactly two successors, one of which is an
  // exit from the loop. If the latch is shared by nested loops, then that
  // condition might hold for the inner loop but cannot hold for the outer
  // loop:
  // - Because the latch is shared, it must have at least two successors: the
  //   inner loop header and the outer loop header, which is also an exit for
  //   the inner loop. That satisfies the condition for the inner loop.
  // - To satisfy the condition for the outer loop, the latch must have a third
  //   successor that is an exit for the outer loop. But that violates the
  //   condition for both loops.
  BranchInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
  if (!ExitingBranch)
    return std::nullopt;

  // If requested, either compute *EstimatedLoopInvocationWeight or return
  // std::nullopt if we cannot.
  //
  // TODO: Eventually, once all passes have migrated away from setting branch
  // weights to indicate estimated trip counts, this function will drop the
  // EstimatedLoopInvocationWeight parameter.
  if (EstimatedLoopInvocationWeight) {
    uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
    if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight))
      return std::nullopt;
    if (L->contains(ExitingBranch->getSuccessor(1)))
      std::swap(LoopWeight, ExitWeight);
    if (!ExitWeight)
      return std::nullopt;
    *EstimatedLoopInvocationWeight = ExitWeight;
  }

  // Return the estimated trip count from metadata unless the metadata is
  // missing or has no value.
  //
  // Some passes set llvm.loop.estimated_trip_count to 0. For example, after
  // peeling 10 or more iterations from a loop with an estimated trip count of
  // 10, llvm.loop.estimated_trip_count becomes 0 on the remaining loop. It
  // indicates that, each time execution reaches the peeled iterations,
  // execution is estimated to exit them without reaching the remaining loop's
  // header.
  //
  // Even if the probability of reaching a loop's header is low, if it is
  // reached, it is the start of an iteration. Consequently, some passes
  // historically assume that llvm::getLoopEstimatedTripCount always returns a
  // positive count or std::nullopt. Thus, return std::nullopt when
  // llvm.loop.estimated_trip_count is 0.
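  //
  // For example, !{!"llvm.loop.estimated_trip_count", i32 8} in the loop's
  // metadata yields a trip count of 8 here.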
  if (auto TC = getOptionalIntLoopAttribute(L, LLVMLoopEstimatedTripCount)) {
    LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: "
                      << LLVMLoopEstimatedTripCount << " metadata has trip "
                      << "count of " << *TC
                      << (*TC == 0 ? " (returning std::nullopt)" : "")
                      << " for " << DbgLoop(L) << "\n");
    return *TC == 0 ? std::nullopt : std::optional(*TC);
  }

  // Estimate the trip count from latch branch weights.
  return estimateLoopTripCount(L);
}

bool llvm::setLoopEstimatedTripCount(
    Loop *L, unsigned EstimatedTripCount,
    std::optional<unsigned> EstimatedloopInvocationWeight) {
  // If EstimatedloopInvocationWeight is requested, we do not support this
  // loop if getExpectedExitLoopLatchBranch returns nullptr.
  //
  // FIXME: See comments in getLoopEstimatedTripCount for why this is required
  // here regardless of EstimatedloopInvocationWeight.
  BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
  if (!LatchBranch)
    return false;

  // Set the metadata.
  addStringMetadataToLoop(L, LLVMLoopEstimatedTripCount, EstimatedTripCount);

  // At the moment, we support changing the estimated trip count only via the
  // latch branch's branch weights. We could extend this API to manipulate
  // estimated trip counts for any exit.
  //
  // TODO: Eventually, once all passes have migrated away from setting branch
  // weights to indicate estimated trip counts, we will not set branch weights
  // here at all.
  if (!EstimatedloopInvocationWeight)
    return true;

  // Calculate taken and exit weights.
  unsigned LatchExitWeight = ProfcheckDisableMetadataFixes ? 0 : 1;
  unsigned BackedgeTakenWeight = 0;

  if (EstimatedTripCount != 0) {
    LatchExitWeight = *EstimatedloopInvocationWeight;
    BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight;
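    // For example, EstimatedTripCount = 100 with an invocation weight of 1
    // yields BackedgeTakenWeight = 99 and LatchExitWeight = 1.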
  }

  // Swap if the backedge is taken when the condition is "false".
  if (LatchBranch->getSuccessor(0) != L->getHeader())
    std::swap(BackedgeTakenWeight, LatchExitWeight);

  // Set/Update profile metadata.
  setBranchWeights(*LatchBranch, {BackedgeTakenWeight, LatchExitWeight},
                   /*IsExpected=*/false);

  return true;
}

BranchProbability llvm::getLoopProbability(Loop *L) {
  BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
  if (!LatchBranch)
    return BranchProbability::getUnknown();
  bool FirstTargetIsLoop = LatchBranch->getSuccessor(0) == L->getHeader();
  return getBranchProbability(LatchBranch, FirstTargetIsLoop);
}

bool llvm::setLoopProbability(Loop *L, BranchProbability P) {
  BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
  if (!LatchBranch)
    return false;
  bool FirstTargetIsLoop = LatchBranch->getSuccessor(0) == L->getHeader();
  return setBranchProbability(LatchBranch, P, FirstTargetIsLoop);
}

BranchProbability llvm::getBranchProbability(BranchInst *B,
                                             bool ForFirstTarget) {
  if (B->getNumSuccessors() != 2)
    return BranchProbability::getUnknown();
  uint64_t Weight0, Weight1;
  if (!extractBranchWeights(*B, Weight0, Weight1))
    return BranchProbability::getUnknown();
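  // For example, !{!"branch_weights", i32 3, i32 1} yields a probability of
  // 3/4 for the first target and 1/4 for the second.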
  uint64_t Denominator = Weight0 + Weight1;
  if (Denominator == 0)
    return BranchProbability::getUnknown();
  if (!ForFirstTarget)
    std::swap(Weight0, Weight1);
  return BranchProbability::getBranchProbability(Weight0, Denominator);
}

BranchProbability llvm::getBranchProbability(BasicBlock *Src, BasicBlock *Dst) {
  assert(Src != Dst && "Passed in same source as destination");

  Instruction *TI = Src->getTerminator();
  if (!TI || TI->getNumSuccessors() == 0)
    return BranchProbability::getZero();

  SmallVector<uint32_t, 4> Weights;

  if (!extractBranchWeights(*TI, Weights)) {
    // No metadata
    return BranchProbability::getUnknown();
  }
  assert(TI->getNumSuccessors() == Weights.size() &&
         "Missing weights in branch_weights");

  uint64_t Total = 0;
  uint32_t Numerator = 0;
  for (auto [i, Weight] : llvm::enumerate(Weights)) {
    if (TI->getSuccessor(i) == Dst)
      Numerator += Weight;
    Total += Weight;
  }

  // The total of the edge weights might be 0 if the metadata is incorrect,
  // set by hand, or missing. In that case return here to avoid a division by
  // 0 later on. The value of Total might also not fit into uint32_t; in that
  // case, just bail out.
  if (Total == 0 || Total > std::numeric_limits<uint32_t>::max())
    return BranchProbability::getUnknown();

  return BranchProbability(Numerator, Total);
}

bool llvm::setBranchProbability(BranchInst *B, BranchProbability P,
                                bool ForFirstTarget) {
  if (B->getNumSuccessors() != 2)
    return false;
  BranchProbability Prob0 = P;
  BranchProbability Prob1 = P.getCompl();
  if (!ForFirstTarget)
    std::swap(Prob0, Prob1);
  setBranchWeights(*B, {Prob0.getNumerator(), Prob1.getNumerator()},
                   /*IsExpected=*/false);
  return true;
}

bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
                                              ScalarEvolution &SE) {
  Loop *OuterL = InnerLoop->getParentLoop();
  if (!OuterL)
    return true;

  // Get the backedge taken count for the inner loop
  BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch();
  const SCEV *InnerLoopBECountSC = SE.getExitCount(InnerLoop, InnerLoopLatch);
  if (isa<SCEVCouldNotCompute>(InnerLoopBECountSC) ||
      !InnerLoopBECountSC->getType()->isIntegerTy())
    return false;

  // Get whether count is invariant to the outer loop
  ScalarEvolution::LoopDisposition LD =
      SE.getLoopDisposition(InnerLoopBECountSC, OuterL);
  if (LD != ScalarEvolution::LoopInvariant)
    return false;

  return true;
}

constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unexpected recurrence kind");
  case RecurKind::AddChainWithSubs:
  case RecurKind::Sub:
  case RecurKind::Add:
    return Intrinsic::vector_reduce_add;
  case RecurKind::Mul:
    return Intrinsic::vector_reduce_mul;
  case RecurKind::And:
    return Intrinsic::vector_reduce_and;
  case RecurKind::Or:
    return Intrinsic::vector_reduce_or;
  case RecurKind::Xor:
    return Intrinsic::vector_reduce_xor;
  case RecurKind::FMulAdd:
  case RecurKind::FAdd:
    return Intrinsic::vector_reduce_fadd;
  case RecurKind::FMul:
    return Intrinsic::vector_reduce_fmul;
  case RecurKind::SMax:
    return Intrinsic::vector_reduce_smax;
  case RecurKind::SMin:
    return Intrinsic::vector_reduce_smin;
  case RecurKind::UMax:
    return Intrinsic::vector_reduce_umax;
  case RecurKind::UMin:
    return Intrinsic::vector_reduce_umin;
  case RecurKind::FMax:
  case RecurKind::FMaxNum:
    return Intrinsic::vector_reduce_fmax;
  case RecurKind::FMin:
  case RecurKind::FMinNum:
    return Intrinsic::vector_reduce_fmin;
  case RecurKind::FMaximum:
    return Intrinsic::vector_reduce_fmaximum;
  case RecurKind::FMinimum:
    return Intrinsic::vector_reduce_fminimum;
  case RecurKind::FMaximumNum:
    return Intrinsic::vector_reduce_fmax;
  case RecurKind::FMinimumNum:
    return Intrinsic::vector_reduce_fmin;
  }
}

Intrinsic::ID llvm::getMinMaxReductionIntrinsicID(Intrinsic::ID IID) {
  switch (IID) {
  default:
    llvm_unreachable("Unexpected intrinsic id");
  case Intrinsic::umin:
    return Intrinsic::vector_reduce_umin;
  case Intrinsic::umax:
    return Intrinsic::vector_reduce_umax;
  case Intrinsic::smin:
    return Intrinsic::vector_reduce_smin;
  case Intrinsic::smax:
    return Intrinsic::vector_reduce_smax;
  }
}

// This is the inverse to getReductionForBinop
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
  switch (RdxID) {
  case Intrinsic::vector_reduce_fadd:
    return Instruction::FAdd;
  case Intrinsic::vector_reduce_fmul:
    return Instruction::FMul;
  case Intrinsic::vector_reduce_add:
    return Instruction::Add;
  case Intrinsic::vector_reduce_mul:
    return Instruction::Mul;
  case Intrinsic::vector_reduce_and:
    return Instruction::And;
  case Intrinsic::vector_reduce_or:
    return Instruction::Or;
  case Intrinsic::vector_reduce_xor:
    return Instruction::Xor;
  case Intrinsic::vector_reduce_smax:
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_umin:
    return Instruction::ICmp;
  case Intrinsic::vector_reduce_fmax:
  case Intrinsic::vector_reduce_fmin:
    return Instruction::FCmp;
  default:
    llvm_unreachable("Unexpected ID");
  }
}

// This is the inverse to getArithmeticReductionInstruction
Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
  switch (Opc) {
  default:
    break;
  case Instruction::Add:
    return Intrinsic::vector_reduce_add;
  case Instruction::Mul:
    return Intrinsic::vector_reduce_mul;
  case Instruction::And:
    return Intrinsic::vector_reduce_and;
  case Instruction::Or:
    return Intrinsic::vector_reduce_or;
  case Instruction::Xor:
    return Intrinsic::vector_reduce_xor;
  }
  return Intrinsic::not_intrinsic;
}

Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
  switch (RdxID) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case Intrinsic::vector_reduce_umin:
    return Intrinsic::umin;
  case Intrinsic::vector_reduce_umax:
    return Intrinsic::umax;
  case Intrinsic::vector_reduce_smin:
    return Intrinsic::smin;
  case Intrinsic::vector_reduce_smax:
    return Intrinsic::smax;
  case Intrinsic::vector_reduce_fmin:
    return Intrinsic::minnum;
  case Intrinsic::vector_reduce_fmax:
    return Intrinsic::maxnum;
  case Intrinsic::vector_reduce_fminimum:
    return Intrinsic::minimum;
  case Intrinsic::vector_reduce_fmaximum:
    return Intrinsic::maximum;
  }
}

Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case RecurKind::UMin:
    return Intrinsic::umin;
  case RecurKind::UMax:
    return Intrinsic::umax;
  case RecurKind::SMin:
    return Intrinsic::smin;
  case RecurKind::SMax:
    return Intrinsic::smax;
  case RecurKind::FMin:
  case RecurKind::FMinNum:
    return Intrinsic::minnum;
  case RecurKind::FMax:
  case RecurKind::FMaxNum:
    return Intrinsic::maxnum;
  case RecurKind::FMinimum:
    return Intrinsic::minimum;
  case RecurKind::FMaximum:
    return Intrinsic::maximum;
  case RecurKind::FMinimumNum:
    return Intrinsic::minimumnum;
  case RecurKind::FMaximumNum:
    return Intrinsic::maximumnum;
  }
}

RecurKind llvm::getMinMaxReductionRecurKind(Intrinsic::ID RdxID) {
  switch (RdxID) {
  case Intrinsic::vector_reduce_smax:
    return RecurKind::SMax;
  case Intrinsic::vector_reduce_smin:
    return RecurKind::SMin;
  case Intrinsic::vector_reduce_umax:
    return RecurKind::UMax;
  case Intrinsic::vector_reduce_umin:
    return RecurKind::UMin;
  case Intrinsic::vector_reduce_fmax:
    return RecurKind::FMax;
  case Intrinsic::vector_reduce_fmin:
    return RecurKind::FMin;
  default:
    return RecurKind::None;
  }
}

CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case RecurKind::UMin:
    return CmpInst::ICMP_ULT;
  case RecurKind::UMax:
    return CmpInst::ICMP_UGT;
  case RecurKind::SMin:
    return CmpInst::ICMP_SLT;
  case RecurKind::SMax:
    return CmpInst::ICMP_SGT;
  case RecurKind::FMin:
    return CmpInst::FCMP_OLT;
  case RecurKind::FMax:
    return CmpInst::FCMP_OGT;
  // We do not add FMinimum/FMaximum recurrence kind here since there is no
  // equivalent predicate which compares signed zeroes according to the
  // semantics of the intrinsics (llvm.minimum/maximum).
  }
}

Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
                            Value *Right) {
  Type *Ty = Left->getType();
  if (Ty->isIntOrIntVectorTy() ||
      (RK == RecurKind::FMinNum || RK == RecurKind::FMaxNum ||
       RK == RecurKind::FMinimum || RK == RecurKind::FMaximum ||
       RK == RecurKind::FMinimumNum || RK == RecurKind::FMaximumNum)) {
    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RK);
    return Builder.CreateIntrinsic(Ty, Id, {Left, Right},
                                   /*FMFSource=*/nullptr, "rdx.minmax");
  }
  CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
  Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp");
  Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select");
  return Select;
}

// Helper to generate an ordered reduction.
Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
                                 unsigned Op, RecurKind RdxKind) {
  unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();

  // Extract and apply reduction ops in ascending order:
  // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ...) + Scl[VF-1]
  Value *Result = Acc;
  for (unsigned ExtractIdx = 0; ExtractIdx != VF; ++ExtractIdx) {
    Value *Ext =
        Builder.CreateExtractElement(Src, Builder.getInt32(ExtractIdx));

    if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
      Result = Builder.CreateBinOp((Instruction::BinaryOps)Op, Result, Ext,
                                   "bin.rdx");
    } else {
      assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
             "Invalid min/max");
      Result = createMinMaxOp(Builder, RdxKind, Result, Ext);
    }
  }

  return Result;
}

// Helper to generate a log2 shuffle reduction.
Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
                                 unsigned Op,
                                 TargetTransformInfo::ReductionShuffle RS,
                                 RecurKind RdxKind) {
  unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();
  // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
  // and vector ops, reducing the set of values being computed by half each
  // round.
  assert(isPowerOf2_32(VF) &&
         "Reduction emission only supported for pow2 vectors!");
  // Note: fast-math flags are controlled by the builder configuration
  // and are assumed to apply to all generated arithmetic instructions. Other
  // poison generating flags (nsw/nuw/inbounds/inrange/exact) are not part
  // of the builder configuration, and since they're not passed explicitly,
  // will never be relevant here. Note that it would be generally unsound to
  // propagate these from an intrinsic call to the expansion anyways as we
  // change the order of operations.
  auto BuildShuffledOp = [&Builder, &Op,
                          &RdxKind](SmallVectorImpl<int> &ShuffleMask,
                                    Value *&TmpVec) -> void {
    Value *Shuf = Builder.CreateShuffleVector(TmpVec, ShuffleMask, "rdx.shuf");
    if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
      TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf,
                                   "bin.rdx");
    } else {
      assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
             "Invalid min/max");
      TmpVec = createMinMaxOp(Builder, RdxKind, TmpVec, Shuf);
    }
  };

  Value *TmpVec = Src;
  if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) {
    SmallVector<int, 32> ShuffleMask(VF);
    for (unsigned stride = 1; stride < VF; stride <<= 1) {
      // Initialise the mask with undef.
      llvm::fill(ShuffleMask, -1);
      for (unsigned j = 0; j < VF; j += stride << 1) {
        ShuffleMask[j] = j + stride;
      }
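      // e.g. for VF = 4: stride 1 gives mask <1,-1,3,-1> and stride 2 gives
      // mask <2,-1,-1,-1>, combining adjacent lanes each round.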
      BuildShuffledOp(ShuffleMask, TmpVec);
    }
  } else {
    SmallVector<int, 32> ShuffleMask(VF);
    for (unsigned i = VF; i != 1; i >>= 1) {
      // Move the upper half of the vector to the lower half.
      for (unsigned j = 0; j != i / 2; ++j)
        ShuffleMask[j] = i / 2 + j;

      // Fill the rest of the mask with undef.
      std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
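      // e.g. for VF = 4: the first round uses mask <2,3,-1,-1>, the second
      // <1,-1,-1,-1>, halving the live lanes each time.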
      BuildShuffledOp(ShuffleMask, TmpVec);
    }
  }
  // The result is in the first element of the vector.
  return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
}

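// The expected any-of pattern selects a new value once any iteration's
// predicate fires, e.g. in IR (names are illustrative):
//   %any = phi i32 [ %init, %preheader ], [ %sel, %latch ]
//   %cmp = icmp eq i32 %val, %key
//   %sel = select i1 %cmp, i32 %new, i32 %any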
1396Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
1397 Value *InitVal, PHINode *OrigPhi) {
1398 Value *NewVal = nullptr;
1399
1400 // First use the original phi to determine the new value we're trying to
1401 // select from in the loop.
1402 SelectInst *SI = nullptr;
1403 for (auto *U : OrigPhi->users()) {
1404 if ((SI = dyn_cast<SelectInst>(Val: U)))
1405 break;
1406 }
1407 assert(SI && "One user of the original phi should be a select");
1408
1409 if (SI->getTrueValue() == OrigPhi)
1410 NewVal = SI->getFalseValue();
1411 else {
1412 assert(SI->getFalseValue() == OrigPhi &&
1413 "At least one input to the select should be the original Phi");
1414 NewVal = SI->getTrueValue();
1415 }
1416
1417 // If any predicate is true it means that we want to select the new value.
1418 Value *AnyOf =
1419 Src->getType()->isVectorTy() ? Builder.CreateOrReduce(Src) : Src;
1420 // The compares in the loop may yield poison, which propagates through the
1421 // bitwise ORs. Freeze it here before the condition is used.
1422 AnyOf = Builder.CreateFreeze(V: AnyOf);
1423 return Builder.CreateSelect(C: AnyOf, True: NewVal, False: InitVal, Name: "rdx.select");
1424}
1425
1426Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
1427 FastMathFlags Flags) {
1428 bool Negative = false;
1429 switch (RdxID) {
1430 default:
1431 llvm_unreachable("Expecting a reduction intrinsic");
1432 case Intrinsic::vector_reduce_add:
1433 case Intrinsic::vector_reduce_mul:
1434 case Intrinsic::vector_reduce_or:
1435 case Intrinsic::vector_reduce_xor:
1436 case Intrinsic::vector_reduce_and:
1437 case Intrinsic::vector_reduce_fadd:
1438 case Intrinsic::vector_reduce_fmul: {
1439 unsigned Opc = getArithmeticReductionInstruction(RdxID);
1440 return ConstantExpr::getBinOpIdentity(Opcode: Opc, Ty, AllowRHSConstant: false,
1441 NSZ: Flags.noSignedZeros());
1442 }
1443 case Intrinsic::vector_reduce_umax:
1444 case Intrinsic::vector_reduce_umin:
1445 case Intrinsic::vector_reduce_smin:
1446 case Intrinsic::vector_reduce_smax: {
1447 Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID);
1448 return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty);
1449 }
1450 case Intrinsic::vector_reduce_fmax:
1451 case Intrinsic::vector_reduce_fmaximum:
1452 Negative = true;
1453 [[fallthrough]];
1454 case Intrinsic::vector_reduce_fmin:
1455 case Intrinsic::vector_reduce_fminimum: {
1456 bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum ||
1457 RdxID == Intrinsic::vector_reduce_fmaximum;
1458 const fltSemantics &Semantics = Ty->getFltSemantics();
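    // For example, for vector_reduce_fmax (maxnum semantics, so NaN inputs
    // are ignored) the neutral element chosen below is:
    //   default flags: -qnan  (since maxnum(qnan, x) == x)
    //   nnan:          -inf
    //   nnan ninf:     -largest
    // fmaximum/fminimum propagate NaN, so a NaN identity would be wrong for
    // them; they start from +/-inf (or +/-largest under ninf) instead.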
1459 return (!Flags.noNaNs() && !PropagatesNaN)
1460 ? ConstantFP::getQNaN(Ty, Negative)
1461 : !Flags.noInfs()
1462 ? ConstantFP::getInfinity(Ty, Negative)
1463 : ConstantFP::get(Ty, V: APFloat::getLargest(Sem: Semantics, Negative));
1464 }
1465 }
1466}
1467
1468Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) {
1469 assert((!(K == RecurKind::FMin || K == RecurKind::FMax) ||
1470 (FMF.noNaNs() && FMF.noSignedZeros())) &&
1471 "nnan, nsz is expected to be set for FP min/max reduction.");
1472 Intrinsic::ID RdxID = getReductionIntrinsicID(RK: K);
1473 return getReductionIdentity(RdxID, Ty: Tp, Flags: FMF);
1474}
1475
1476Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
1477 RecurKind RdxKind) {
1478 auto *SrcVecEltTy = cast<VectorType>(Val: Src->getType())->getElementType();
1479 auto getIdentity = [&]() {
1480 return getRecurrenceIdentity(K: RdxKind, Tp: SrcVecEltTy,
1481 FMF: Builder.getFastMathFlags());
1482 };
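  // For example, RecurKind::Add over a <4 x i32> source becomes
  //   %rdx = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %src)
  // while RecurKind::FAdd is seeded with its neutral element (-0.0, or +0.0
  // under nsz):
  //   %rdx = call float @llvm.vector.reduce.fadd.v4f32(float -0.0,
  //                                                    <4 x float> %src)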
1483 switch (RdxKind) {
1484 case RecurKind::AddChainWithSubs:
1485 case RecurKind::Sub:
1486 case RecurKind::Add:
1487 case RecurKind::Mul:
1488 case RecurKind::And:
1489 case RecurKind::Or:
1490 case RecurKind::Xor:
1491 case RecurKind::SMax:
1492 case RecurKind::SMin:
1493 case RecurKind::UMax:
1494 case RecurKind::UMin:
1495 case RecurKind::FMax:
1496 case RecurKind::FMin:
1497 case RecurKind::FMinNum:
1498 case RecurKind::FMaxNum:
1499 case RecurKind::FMinimum:
1500 case RecurKind::FMaximum:
1501 case RecurKind::FMinimumNum:
1502 case RecurKind::FMaximumNum:
1503 return Builder.CreateUnaryIntrinsic(ID: getReductionIntrinsicID(RK: RdxKind), V: Src);
1504 case RecurKind::FMulAdd:
1505 case RecurKind::FAdd:
1506 return Builder.CreateFAddReduce(Acc: getIdentity(), Src);
1507 case RecurKind::FMul:
1508 return Builder.CreateFMulReduce(Acc: getIdentity(), Src);
1509 default:
    llvm_unreachable("Unhandled recurrence kind");
1511 }
1512}
1513
1514Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
1515 RecurKind Kind, Value *Mask, Value *EVL) {
1516 assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
1517 !RecurrenceDescriptor::isFindRecurrenceKind(Kind) &&
1518 "AnyOf and FindIV reductions are not supported.");
1519 Intrinsic::ID Id = getReductionIntrinsicID(RK: Kind);
1520 auto VPID = VPIntrinsic::getForIntrinsic(Id);
1521 assert(VPReductionIntrinsic::isVPReduction(VPID) &&
1522 "No VPIntrinsic for this reduction");
1523 auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
1524 Value *Iden = getRecurrenceIdentity(K: Kind, Tp: EltTy, FMF: Builder.getFastMathFlags());
1525 Value *Ops[] = {Iden, Src, Mask, EVL};
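  // E.g. for RecurKind::Add this emits (operand names illustrative):
  //   %rdx = call i32 @llvm.vp.reduce.add.v4i32(i32 0, <4 x i32> %src,
  //                                             <4 x i1> %mask, i32 %evl)
  // Masked-off lanes and lanes at or beyond %evl do not participate.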
1526 return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
1527}
1528
1529Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
1530 Value *Src, Value *Start) {
1531 assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
1532 "Unexpected reduction kind");
1533 assert(Src->getType()->isVectorTy() && "Expected a vector type");
1534 assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
1535
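  // Without a 'reassoc' flag on the call, llvm.vector.reduce.fadd is
  // sequential: the result is ((((Start + v0) + v1) + v2) + ...), matching
  // the FP semantics of the original scalar loop.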
1536 return B.CreateFAddReduce(Acc: Start, Src);
1537}
1538
1539Value *llvm::createOrderedReduction(IRBuilderBase &Builder, RecurKind Kind,
1540 Value *Src, Value *Start, Value *Mask,
1541 Value *EVL) {
1542 assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
1543 "Unexpected reduction kind");
1544 assert(Src->getType()->isVectorTy() && "Expected a vector type");
1545 assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
1546
1547 Intrinsic::ID Id = getReductionIntrinsicID(RK: RecurKind::FAdd);
1548 auto VPID = VPIntrinsic::getForIntrinsic(Id);
1549 assert(VPReductionIntrinsic::isVPReduction(VPID) &&
1550 "No VPIntrinsic for this reduction");
1551 auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
1552 Value *Ops[] = {Start, Src, Mask, EVL};
1553 return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
1554}
1555
1556void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
1557 bool IncludeWrapFlags) {
1558 auto *VecOp = dyn_cast<Instruction>(Val: I);
1559 if (!VecOp)
1560 return;
1561 auto *Intersection = (OpValue == nullptr) ? dyn_cast<Instruction>(Val: VL[0])
1562 : dyn_cast<Instruction>(Val: OpValue);
1563 if (!Intersection)
1564 return;
1565 const unsigned Opcode = Intersection->getOpcode();
1566 VecOp->copyIRFlags(V: Intersection, IncludeWrapFlags);
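  // E.g. if VL holds 'add nuw nsw i32 %a, 1' and 'add nsw i32 %b, 1', the
  // loop below intersects the flags, so the vectorized add keeps only 'nsw'.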
1567 for (auto *V : VL) {
1568 auto *Instr = dyn_cast<Instruction>(Val: V);
1569 if (!Instr)
1570 continue;
1571 if (OpValue == nullptr || Opcode == Instr->getOpcode())
1572 VecOp->andIRFlags(V);
1573 }
1574}
1575
1576bool llvm::isKnownNegativeInLoop(const SCEV *S, const Loop *L,
1577 ScalarEvolution &SE) {
1578 const SCEV *Zero = SE.getZero(Ty: S->getType());
1579 return SE.isAvailableAtLoopEntry(S, L) &&
1580 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLT, LHS: S, RHS: Zero);
1581}
1582
1583bool llvm::isKnownNonNegativeInLoop(const SCEV *S, const Loop *L,
1584 ScalarEvolution &SE) {
1585 const SCEV *Zero = SE.getZero(Ty: S->getType());
1586 return SE.isAvailableAtLoopEntry(S, L) &&
1587 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGE, LHS: S, RHS: Zero);
1588}
1589
1590bool llvm::isKnownPositiveInLoop(const SCEV *S, const Loop *L,
1591 ScalarEvolution &SE) {
1592 const SCEV *Zero = SE.getZero(Ty: S->getType());
1593 return SE.isAvailableAtLoopEntry(S, L) &&
1594 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGT, LHS: S, RHS: Zero);
1595}
1596
1597bool llvm::isKnownNonPositiveInLoop(const SCEV *S, const Loop *L,
1598 ScalarEvolution &SE) {
1599 const SCEV *Zero = SE.getZero(Ty: S->getType());
1600 return SE.isAvailableAtLoopEntry(S, L) &&
1601 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLE, LHS: S, RHS: Zero);
1602}
1603
1604bool llvm::cannotBeMinInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1605 bool Signed) {
1606 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1607 APInt Min = Signed ? APInt::getSignedMinValue(numBits: BitWidth) :
1608 APInt::getMinValue(numBits: BitWidth);
1609 auto Predicate = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
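  // E.g. with Signed == true, a loop entry guarded by 'n > INT_MIN' proves
  // that n cannot be the signed minimum, so an expression like 'n - 1'
  // cannot wrap below INT_MIN inside the loop.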
1610 return SE.isAvailableAtLoopEntry(S, L) &&
1611 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1612 RHS: SE.getConstant(Val: Min));
1613}
1614
1615bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1616 bool Signed) {
1617 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1618 APInt Max = Signed ? APInt::getSignedMaxValue(numBits: BitWidth) :
1619 APInt::getMaxValue(numBits: BitWidth);
1620 auto Predicate = Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
1621 return SE.isAvailableAtLoopEntry(S, L) &&
1622 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1623 RHS: SE.getConstant(Val: Max));
1624}
1625
1626//===----------------------------------------------------------------------===//
1627// rewriteLoopExitValues - Optimize IV users outside the loop.
1628// As a side effect, reduces the amount of IV processing within the loop.
1629//===----------------------------------------------------------------------===//
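// For example, given
//   for (i = 0; i != n; ++i) { ... }
//   use(i);
// the LCSSA phi for 'i' in the exit block merges the recurrence {0,+,1}<L>,
// whose value when the loop exits is simply 'n'. Rewriting the phi's
// incoming value to an expansion of 'n' removes the out-of-loop use of the
// IV, which may in turn allow the IV (or the whole loop) to be deleted.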
1630
1631static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
1632 SmallPtrSet<const Instruction *, 8> Visited;
1633 SmallVector<const Instruction *, 8> WorkList;
1634 Visited.insert(Ptr: I);
1635 WorkList.push_back(Elt: I);
1636 while (!WorkList.empty()) {
1637 const Instruction *Curr = WorkList.pop_back_val();
1638 // This use is outside the loop, nothing to do.
1639 if (!L->contains(Inst: Curr))
1640 continue;
1641 // Do we assume it is a "hard" use which will not be eliminated easily?
1642 if (Curr->mayHaveSideEffects())
1643 return true;
1644 // Otherwise, add all its users to worklist.
1645 for (const auto *U : Curr->users()) {
1646 auto *UI = cast<Instruction>(Val: U);
1647 if (Visited.insert(Ptr: UI).second)
1648 WorkList.push_back(Elt: UI);
1649 }
1650 }
1651 return false;
1652}
1653
1654// Collect information about PHI nodes which can be transformed in
1655// rewriteLoopExitValues.
1656struct RewritePhi {
1657 PHINode *PN; // For which PHI node is this replacement?
1658 unsigned Ith; // For which incoming value?
1659 const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting.
  Instruction *ExpansionPoint; // Where we'd like to expand that SCEV.
  bool HighCost; // Is this expansion high-cost?
1662
1663 RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
1664 bool H)
1665 : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
1666 HighCost(H) {}
1667};
1668
1669// Check whether it is possible to delete the loop after rewriting exit
1670// value. If it is possible, ignore ReplaceExitValue and do rewriting
1671// aggressively.
1672static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) {
1673 BasicBlock *Preheader = L->getLoopPreheader();
1674 // If there is no preheader, the loop will not be deleted.
1675 if (!Preheader)
1676 return false;
1677
  // The LoopDeletion pass can delete a loop even when it has more than one
  // exiting block; we forgo the multiple-exiting-blocks case for simplicity.
  // TODO: If we see a testcase where a loop with multiple exiting blocks can
  // be deleted after exit value rewriting, we can enhance the logic here.
1682 SmallVector<BasicBlock *, 4> ExitingBlocks;
1683 L->getExitingBlocks(ExitingBlocks);
1684 SmallVector<BasicBlock *, 8> ExitBlocks;
1685 L->getUniqueExitBlocks(ExitBlocks);
1686 if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1)
1687 return false;
1688
1689 BasicBlock *ExitBlock = ExitBlocks[0];
1690 BasicBlock::iterator BI = ExitBlock->begin();
1691 while (PHINode *P = dyn_cast<PHINode>(Val&: BI)) {
1692 Value *Incoming = P->getIncomingValueForBlock(BB: ExitingBlocks[0]);
1693
    // If the Incoming value of P is found in RewritePhiSet, we know it
    // could be rewritten to use a loop-invariant value in the transformation
    // phase later. Skip it in the loop-invariant check below.
1697 bool found = false;
1698 for (const RewritePhi &Phi : RewritePhiSet) {
1699 unsigned i = Phi.Ith;
1700 if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
1701 found = true;
1702 break;
1703 }
1704 }
1705
1706 Instruction *I;
1707 if (!found && (I = dyn_cast<Instruction>(Val: Incoming)))
1708 if (!L->hasLoopInvariantOperands(I))
1709 return false;
1710
1711 ++BI;
1712 }
1713
1714 for (auto *BB : L->blocks())
1715 if (llvm::any_of(Range&: *BB, P: [](Instruction &I) {
1716 return I.mayHaveSideEffects();
1717 }))
1718 return false;
1719
1720 return true;
1721}
1722
1723/// Checks if it is safe to call InductionDescriptor::isInductionPHI for \p Phi,
1724/// and returns true if this Phi is an induction phi in the loop. When
/// isInductionPHI returns true, \p ID will also be set by isInductionPHI.
1726static bool checkIsIndPhi(PHINode *Phi, Loop *L, ScalarEvolution *SE,
1727 InductionDescriptor &ID) {
1728 if (!Phi)
1729 return false;
1730 if (!L->getLoopPreheader())
1731 return false;
1732 if (Phi->getParent() != L->getHeader())
1733 return false;
1734 return InductionDescriptor::isInductionPHI(Phi, L, SE, D&: ID);
1735}
1736
1737int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
1738 ScalarEvolution *SE,
1739 const TargetTransformInfo *TTI,
1740 SCEVExpander &Rewriter, DominatorTree *DT,
1741 ReplaceExitVal ReplaceExitValue,
1742 SmallVector<WeakTrackingVH, 16> &DeadInsts) {
1743 // Check a pre-condition.
1744 assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
1745 "Indvars did not preserve LCSSA!");
1746
1747 SmallVector<BasicBlock*, 8> ExitBlocks;
1748 L->getUniqueExitBlocks(ExitBlocks);
1749
1750 SmallVector<RewritePhi, 8> RewritePhiSet;
1751 // Find all values that are computed inside the loop, but used outside of it.
1752 // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan
1753 // the exit blocks of the loop to find them.
1754 for (BasicBlock *ExitBB : ExitBlocks) {
    // If there are no PHI nodes in this exit block, then no values defined
    // inside the loop are used on this path; skip it.
1757 PHINode *PN = dyn_cast<PHINode>(Val: ExitBB->begin());
1758 if (!PN) continue;
1759
1760 unsigned NumPreds = PN->getNumIncomingValues();
1761
1762 // Iterate over all of the PHI nodes.
1763 BasicBlock::iterator BBI = ExitBB->begin();
1764 while ((PN = dyn_cast<PHINode>(Val: BBI++))) {
1765 if (PN->use_empty())
1766 continue; // dead use, don't replace it
1767
1768 if (!SE->isSCEVable(Ty: PN->getType()))
1769 continue;
1770
1771 // Iterate over all of the values in all the PHI nodes.
1772 for (unsigned i = 0; i != NumPreds; ++i) {
        // If the value being merged in is not an instruction, it cannot be
        // defined inside the loop; skip it.
1775 Value *InVal = PN->getIncomingValue(i);
1776 if (!isa<Instruction>(Val: InVal))
1777 continue;
1778
        // If this pred is for a subloop (not L itself), skip it.
        if (LI->getLoopFor(BB: PN->getIncomingBlock(i)) != L)
          continue;
1782
1783 // Check that InVal is defined in the loop.
1784 Instruction *Inst = cast<Instruction>(Val: InVal);
1785 if (!L->contains(Inst))
1786 continue;
1787
        // Find exit values which are induction variables in the loop and are
        // otherwise unused within it, i.e. whose only uses are the exit block
        // PHI node and the induction variable's update binary operator.
        // The exit value can be replaced with the final value when it is
        // cheap to do so.
1793 if (ReplaceExitValue == UnusedIndVarInLoop) {
1794 InductionDescriptor ID;
1795 PHINode *IndPhi = dyn_cast<PHINode>(Val: Inst);
1796 if (IndPhi) {
1797 if (!checkIsIndPhi(Phi: IndPhi, L, SE, ID))
1798 continue;
1799 // This is an induction PHI. Check that the only users are PHI
1800 // nodes, and induction variable update binary operators.
1801 if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
1802 if (!isa<PHINode>(Val: U) && !isa<BinaryOperator>(Val: U))
1803 return true;
1804 BinaryOperator *B = dyn_cast<BinaryOperator>(Val: U);
1805 if (B && B != ID.getInductionBinOp())
1806 return true;
1807 return false;
1808 }))
1809 continue;
1810 } else {
1811 // If it is not an induction phi, it must be an induction update
1812 // binary operator with an induction phi user.
1813 BinaryOperator *B = dyn_cast<BinaryOperator>(Val: Inst);
1814 if (!B)
1815 continue;
1816 if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
1817 PHINode *Phi = dyn_cast<PHINode>(Val: U);
1818 if (Phi != PN && !checkIsIndPhi(Phi, L, SE, ID))
1819 return true;
1820 return false;
1821 }))
1822 continue;
1823 if (B != ID.getInductionBinOp())
1824 continue;
1825 }
1826 }
1827
1828 // Okay, this instruction has a user outside of the current loop
1829 // and varies predictably *inside* the loop. Evaluate the value it
1830 // contains when the loop exits, if possible. We prefer to start with
1831 // expressions which are true for all exits (so as to maximize
1832 // expression reuse by the SCEVExpander), but resort to per-exit
1833 // evaluation if that fails.
1834 const SCEV *ExitValue = SE->getSCEVAtScope(V: Inst, L: L->getParentLoop());
1835 if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
1836 !SE->isLoopInvariant(S: ExitValue, L) ||
1837 !Rewriter.isSafeToExpand(S: ExitValue)) {
1838 // TODO: This should probably be sunk into SCEV in some way; maybe a
1839 // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for
1840 // most SCEV expressions and other recurrence types (e.g. shift
1841 // recurrences). Is there existing code we can reuse?
1842 const SCEV *ExitCount = SE->getExitCount(L, ExitingBlock: PN->getIncomingBlock(i));
1843 if (isa<SCEVCouldNotCompute>(Val: ExitCount))
1844 continue;
1845 if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: Inst)))
1846 if (AddRec->getLoop() == L)
1847 ExitValue = AddRec->evaluateAtIteration(It: ExitCount, SE&: *SE);
1848 if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
1849 !SE->isLoopInvariant(S: ExitValue, L) ||
1850 !Rewriter.isSafeToExpand(S: ExitValue))
1851 continue;
1852 }
1853
1854 // Computing the value outside of the loop brings no benefit if it is
        // definitely used inside the loop in a way which cannot be optimized
1856 // away. Avoid doing so unless we know we have a value which computes
1857 // the ExitValue already. TODO: This should be merged into SCEV
1858 // expander to leverage its knowledge of existing expressions.
1859 if (ReplaceExitValue != AlwaysRepl && !isa<SCEVConstant>(Val: ExitValue) &&
1860 !isa<SCEVUnknown>(Val: ExitValue) && hasHardUserWithinLoop(L, I: Inst))
1861 continue;
1862
1863 // Check if expansions of this SCEV would count as being high cost.
1864 bool HighCost = Rewriter.isHighCostExpansion(
1865 Exprs: ExitValue, L, Budget: SCEVCheapExpansionBudget, TTI, At: Inst);
1866
        // Note that we must not perform expansions until after
        // we query *all* the costs, because if we performed a temporary
        // expansion in between, one that we might not intend to keep, said
        // expansion *may* affect the cost calculation of the next SCEVs we
        // query, and the next SCEV may erroneously get a smaller cost.
1872
1873 // Collect all the candidate PHINodes to be rewritten.
1874 Instruction *InsertPt =
1875 (isa<PHINode>(Val: Inst) || isa<LandingPadInst>(Val: Inst)) ?
1876 &*Inst->getParent()->getFirstInsertionPt() : Inst;
1877 RewritePhiSet.emplace_back(Args&: PN, Args&: i, Args&: ExitValue, Args&: InsertPt, Args&: HighCost);
1878 }
1879 }
1880 }
1881
1882 // TODO: evaluate whether it is beneficial to change how we calculate
1883 // high-cost: if we have SCEV 'A' which we know we will expand, should we
// calculate the cost of other SCEVs after expanding SCEV 'A', thus
// potentially giving a cost bonus to those other SCEVs?
1886
1887 bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
1888 int NumReplaced = 0;
1889
1890 // Transformation.
1891 for (const RewritePhi &Phi : RewritePhiSet) {
1892 PHINode *PN = Phi.PN;
1893
1894 // Only do the rewrite when the ExitValue can be expanded cheaply.
1895 // If LoopCanBeDel is true, rewrite exit value aggressively.
1896 if ((ReplaceExitValue == OnlyCheapRepl ||
1897 ReplaceExitValue == UnusedIndVarInLoop) &&
1898 !LoopCanBeDel && Phi.HighCost)
1899 continue;
1900
1901 Value *ExitVal = Rewriter.expandCodeFor(
1902 SH: Phi.ExpansionSCEV, Ty: Phi.PN->getType(), I: Phi.ExpansionPoint);
1903
1904 LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " << *ExitVal
1905 << '\n'
1906 << " LoopVal = " << *(Phi.ExpansionPoint) << "\n");
1907
1908#ifndef NDEBUG
1909 // If we reuse an instruction from a loop which is neither L nor one of
1910 // its containing loops, we end up breaking LCSSA form for this loop by
1911 // creating a new use of its instruction.
1912 if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
1913 if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
1914 if (EVL != L)
1915 assert(EVL->contains(L) && "LCSSA breach detected!");
1916#endif
1917
1918 NumReplaced++;
1919 Instruction *Inst = cast<Instruction>(Val: PN->getIncomingValue(i: Phi.Ith));
1920 PN->setIncomingValue(i: Phi.Ith, V: ExitVal);
1921 // It's necessary to tell ScalarEvolution about this explicitly so that
1922 // it can walk the def-use list and forget all SCEVs, as it may not be
1923 // watching the PHI itself. Once the new exit value is in place, there
1924 // may not be a def-use connection between the loop and every instruction
1925 // which got a SCEVAddRecExpr for that loop.
1926 SE->forgetValue(V: PN);
1927
1928 // If this instruction is dead now, delete it. Don't do it now to avoid
1929 // invalidating iterators.
1930 if (isInstructionTriviallyDead(I: Inst, TLI))
1931 DeadInsts.push_back(Elt: Inst);
1932
1933 // Replace PN with ExitVal if that is legal and does not break LCSSA.
1934 if (PN->getNumIncomingValues() == 1 &&
1935 LI->replacementPreservesLCSSAForm(From: PN, To: ExitVal)) {
1936 PN->replaceAllUsesWith(V: ExitVal);
1937 PN->eraseFromParent();
1938 }
1939 }
1940
1941 // The insertion point instruction may have been deleted; clear it out
1942 // so that the rewriter doesn't trip over it later.
1943 Rewriter.clearInsertPoint();
1944 return NumReplaced;
1945}
1946
1947/// Utility that implements appending of loops onto a worklist.
/// Loops are added in preorder (analogous to reverse postorder for trees),
1949/// and the worklist is processed LIFO.
1950template <typename RangeT>
1951void llvm::appendReversedLoopsToWorklist(
1952 RangeT &&Loops, SmallPriorityWorklist<Loop *, 4> &Worklist) {
1953 // We use an internal worklist to build up the preorder traversal without
1954 // recursion.
1955 SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist;
1956
1957 // We walk the initial sequence of loops in reverse because we generally want
1958 // to visit defs before uses and the worklist is LIFO.
1959 for (Loop *RootL : Loops) {
1960 assert(PreOrderLoops.empty() && "Must start with an empty preorder walk.");
1961 assert(PreOrderWorklist.empty() &&
1962 "Must start with an empty preorder walk worklist.");
1963 PreOrderWorklist.push_back(Elt: RootL);
1964 do {
1965 Loop *L = PreOrderWorklist.pop_back_val();
1966 PreOrderWorklist.append(in_start: L->begin(), in_end: L->end());
1967 PreOrderLoops.push_back(Elt: L);
1968 } while (!PreOrderWorklist.empty());
1969
1970 Worklist.insert(Input: std::move(PreOrderLoops));
1971 PreOrderLoops.clear();
1972 }
1973}
1974
1975template <typename RangeT>
1976void llvm::appendLoopsToWorklist(RangeT &&Loops,
1977 SmallPriorityWorklist<Loop *, 4> &Worklist) {
1978 appendReversedLoopsToWorklist(reverse(Loops), Worklist);
1979}
1980
1981template LLVM_EXPORT_TEMPLATE void
1982llvm::appendLoopsToWorklist<ArrayRef<Loop *> &>(
1983 ArrayRef<Loop *> &Loops, SmallPriorityWorklist<Loop *, 4> &Worklist);
1984
1985template LLVM_EXPORT_TEMPLATE void
1986llvm::appendLoopsToWorklist<Loop &>(Loop &L,
1987 SmallPriorityWorklist<Loop *, 4> &Worklist);
1988
1989void llvm::appendLoopsToWorklist(LoopInfo &LI,
1990 SmallPriorityWorklist<Loop *, 4> &Worklist) {
1991 appendReversedLoopsToWorklist(Loops&: LI, Worklist);
1992}
1993
1994Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
1995 LoopInfo *LI, LPPassManager *LPM) {
1996 Loop &New = *LI->AllocateLoop();
1997 if (PL)
1998 PL->addChildLoop(NewChild: &New);
1999 else
2000 LI->addTopLevelLoop(New: &New);
2001
2002 if (LPM)
2003 LPM->addLoop(L&: New);
2004
2005 // Add all of the blocks in L to the new loop.
2006 for (BasicBlock *BB : L->blocks())
2007 if (LI->getLoopFor(BB) == L)
2008 New.addBasicBlockToLoop(NewBB: cast<BasicBlock>(Val&: VM[BB]), LI&: *LI);
2009
2010 // Add all of the subloops to the new loop.
2011 for (Loop *I : *L)
2012 cloneLoop(L: I, PL: &New, VM, LI, LPM);
2013
2014 return &New;
2015}
2016
2017/// IR Values for the lower and upper bounds of a pointer evolution. We
2018/// need to use value-handles because SCEV expansion can invalidate previously
2019/// expanded values. Thus expansion of a pointer can invalidate the bounds for
2020/// a previous one.
2021struct PointerBounds {
2022 TrackingVH<Value> Start;
2023 TrackingVH<Value> End;
2024 Value *StrideToCheck;
2025};
2026
2027/// Expand code for the lower and upper bound of the pointer group \p CG
2028/// in \p TheLoop. \return the values for the bounds.
2029static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
2030 Loop *TheLoop, Instruction *Loc,
2031 SCEVExpander &Exp, bool HoistRuntimeChecks) {
2032 LLVMContext &Ctx = Loc->getContext();
2033 Type *PtrArithTy = PointerType::get(C&: Ctx, AddressSpace: CG->AddressSpace);
2034
2035 Value *Start = nullptr, *End = nullptr;
2036 LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
2037 const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr;
2038
2039 // If the Low and High values are themselves loop-variant, then we may want
2040 // to expand the range to include those covered by the outer loop as well.
2041 // There is a trade-off here with the advantage being that creating checks
2042 // using the expanded range permits the runtime memory checks to be hoisted
2043 // out of the outer loop. This reduces the cost of entering the inner loop,
2044 // which can be significant for low trip counts. The disadvantage is that
2045 // there is a chance we may now never enter the vectorized inner loop,
2046 // whereas using a restricted range check could have allowed us to enter at
2047 // least once. This is why the behaviour is not currently the default and is
2048 // controlled by the parameter 'HoistRuntimeChecks'.
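  // For example (illustrative): if the inner loop walks row i of a 2-D array
  // with Low = {A,+,N}<outer> and High = {A+N,+,N}<outer>, evaluating High at
  // the outer trip count TC and taking Low's start yields the single range
  // [A, A + (TC+1)*N), which no longer depends on the outer IV and so can be
  // checked before entering the outer loop.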
2049 if (HoistRuntimeChecks && TheLoop->getParentLoop() &&
2050 isa<SCEVAddRecExpr>(Val: High) && isa<SCEVAddRecExpr>(Val: Low)) {
2051 auto *HighAR = cast<SCEVAddRecExpr>(Val: High);
2052 auto *LowAR = cast<SCEVAddRecExpr>(Val: Low);
2053 const Loop *OuterLoop = TheLoop->getParentLoop();
2054 ScalarEvolution &SE = *Exp.getSE();
2055 const SCEV *Recur = LowAR->getStepRecurrence(SE);
2056 if (Recur == HighAR->getStepRecurrence(SE) &&
2057 HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
2058 BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
2059 const SCEV *OuterExitCount = SE.getExitCount(L: OuterLoop, ExitingBlock: OuterLoopLatch);
2060 if (!isa<SCEVCouldNotCompute>(Val: OuterExitCount) &&
2061 OuterExitCount->getType()->isIntegerTy()) {
2062 const SCEV *NewHigh =
2063 cast<SCEVAddRecExpr>(Val: High)->evaluateAtIteration(It: OuterExitCount, SE);
2064 if (!isa<SCEVCouldNotCompute>(Val: NewHigh)) {
2065 LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
2066 "outer loop in order to permit hoisting\n");
2067 High = NewHigh;
2068 Low = cast<SCEVAddRecExpr>(Val: Low)->getStart();
2069 // If there is a possibility that the stride is negative then we have
2070 // to generate extra checks to ensure the stride is positive.
2071 if (!SE.isKnownNonNegative(
2072 S: SE.applyLoopGuards(Expr: Recur, L: HighAR->getLoop()))) {
2073 Stride = Recur;
2074 LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
2075 "positive: "
2076 << *Stride << '\n');
2077 }
2078 }
2079 }
2080 }
2081 }
2082
2083 Start = Exp.expandCodeFor(SH: Low, Ty: PtrArithTy, I: Loc);
2084 End = Exp.expandCodeFor(SH: High, Ty: PtrArithTy, I: Loc);
2085 if (CG->NeedsFreeze) {
2086 IRBuilder<> Builder(Loc);
2087 Start = Builder.CreateFreeze(V: Start, Name: Start->getName() + ".fr");
2088 End = Builder.CreateFreeze(V: End, Name: End->getName() + ".fr");
2089 }
2090 Value *StrideVal =
2091 Stride ? Exp.expandCodeFor(SH: Stride, Ty: Stride->getType(), I: Loc) : nullptr;
2092 LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n");
2093 return {.Start: Start, .End: End, .StrideToCheck: StrideVal};
2094}
2095
2096/// Turns a collection of checks into a collection of expanded upper and
2097/// lower bounds for both pointers in the check.
2098static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
2099expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
2100 Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) {
2101 SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
2102
2103 // Here we're relying on the SCEV Expander's cache to only emit code for the
2104 // same bounds once.
2105 transform(Range: PointerChecks, d_first: std::back_inserter(x&: ChecksWithBounds),
2106 F: [&](const RuntimePointerCheck &Check) {
2107 PointerBounds First = expandBounds(CG: Check.first, TheLoop: L, Loc, Exp,
2108 HoistRuntimeChecks),
2109 Second = expandBounds(CG: Check.second, TheLoop: L, Loc, Exp,
2110 HoistRuntimeChecks);
2111 return std::make_pair(x&: First, y&: Second);
2112 });
2113
2114 return ChecksWithBounds;
2115}
2116
2117Value *llvm::addRuntimeChecks(
2118 Instruction *Loc, Loop *TheLoop,
2119 const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
2120 SCEVExpander &Exp, bool HoistRuntimeChecks) {
2121 // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
2122 // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
2123 auto ExpandedChecks =
2124 expandBounds(PointerChecks, L: TheLoop, Loc, Exp, HoistRuntimeChecks);
2125
2126 LLVMContext &Ctx = Loc->getContext();
2127 IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
2128 ChkBuilder.SetInsertPoint(Loc);
2129 // Our instructions might fold to a constant.
2130 Value *MemoryRuntimeCheck = nullptr;
2131
2132 for (const auto &[A, B] : ExpandedChecks) {
2133 // Check if two pointers (A and B) conflict where conflict is computed as:
2134 // start(A) <= end(B) && start(B) <= end(A)
2135
2136 assert((A.Start->getType()->getPointerAddressSpace() ==
2137 B.End->getType()->getPointerAddressSpace()) &&
2138 (B.Start->getType()->getPointerAddressSpace() ==
2139 A.End->getType()->getPointerAddressSpace()) &&
2140 "Trying to bounds check pointers with different address spaces");
2141
2142 // [A|B].Start points to the first accessed byte under base [A|B].
2143 // [A|B].End points to the last accessed byte, plus one.
2144 // There is no conflict when the intervals are disjoint:
2145 // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
2146 //
    //   bound0 = (A.Start < B.End)
    //   bound1 = (B.Start < A.End)
    //   IsConflict = bound0 & bound1
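    // E.g. the emitted IR per pair has the shape (names illustrative):
    //   %bound0 = icmp ult ptr %A.start, %B.end
    //   %bound1 = icmp ult ptr %B.start, %A.end
    //   %found.conflict = and i1 %bound0, %bound1
    // Per-pair results are OR'd together into %conflict.rdx; a true result
    // directs callers to their fallback (e.g. scalar) loop.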
2150 Value *Cmp0 = ChkBuilder.CreateICmpULT(LHS: A.Start, RHS: B.End, Name: "bound0");
2151 Value *Cmp1 = ChkBuilder.CreateICmpULT(LHS: B.Start, RHS: A.End, Name: "bound1");
2152 Value *IsConflict = ChkBuilder.CreateAnd(LHS: Cmp0, RHS: Cmp1, Name: "found.conflict");
2153 if (A.StrideToCheck) {
2154 Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
2155 LHS: A.StrideToCheck, RHS: ConstantInt::get(Ty: A.StrideToCheck->getType(), V: 0),
2156 Name: "stride.check");
2157 IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
2158 }
2159 if (B.StrideToCheck) {
2160 Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
2161 LHS: B.StrideToCheck, RHS: ConstantInt::get(Ty: B.StrideToCheck->getType(), V: 0),
2162 Name: "stride.check");
2163 IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
2164 }
2165 if (MemoryRuntimeCheck) {
2166 IsConflict =
2167 ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
2168 }
2169 MemoryRuntimeCheck = IsConflict;
2170 }
2171
2172 Exp.eraseDeadInstructions(Root: MemoryRuntimeCheck);
2173 return MemoryRuntimeCheck;
2174}
2175
2176Value *llvm::addDiffRuntimeChecks(
2177 Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
2178 function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) {
2179
2180 LLVMContext &Ctx = Loc->getContext();
2181 IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
2182 ChkBuilder.SetInsertPoint(Loc);
2183 // Our instructions might fold to a constant.
2184 Value *MemoryRuntimeCheck = nullptr;
2185
2186 auto &SE = *Expander.getSE();
  // Map to keep track of created compares. The key is the pair of operands of
  // the compare, to allow detecting and re-using redundant compares.
2189 DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
2190 for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
2191 Type *Ty = SinkStart->getType();
2192 // Compute VF * IC * AccessSize.
2193 auto *VFTimesICTimesSize =
2194 ChkBuilder.CreateMul(LHS: GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
2195 RHS: ConstantInt::get(Ty, V: IC * AccessSize));
2196 Value *Diff =
2197 Expander.expandCodeFor(SH: SE.getMinusSCEV(LHS: SinkStart, RHS: SrcStart), Ty, I: Loc);
2198
    // Check if the same compare has already been created earlier; in that
    // case, there is no need to emit it again.
2201 Value *IsConflict = SeenCompares.lookup(Val: {Diff, VFTimesICTimesSize});
2202 if (IsConflict)
2203 continue;
2204
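    // A conflict is assumed when the unsigned distance between sink and
    // source is smaller than the bytes accessed per vectorized iteration
    // (VF * IC * AccessSize). E.g. with VF = 4, IC = 2 and 4-byte accesses
    // the check below is 'icmp ult %diff, 32'.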
2205 IsConflict =
2206 ChkBuilder.CreateICmpULT(LHS: Diff, RHS: VFTimesICTimesSize, Name: "diff.check");
2207 SeenCompares.insert(KV: {{Diff, VFTimesICTimesSize}, IsConflict});
2208 if (NeedsFreeze)
2209 IsConflict =
2210 ChkBuilder.CreateFreeze(V: IsConflict, Name: IsConflict->getName() + ".fr");
2211 if (MemoryRuntimeCheck) {
2212 IsConflict =
2213 ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
2214 }
2215 MemoryRuntimeCheck = IsConflict;
2216 }
2217
2218 Expander.eraseDeadInstructions(Root: MemoryRuntimeCheck);
2219 return MemoryRuntimeCheck;
2220}
2221
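// A usage sketch: for a loop header that ends in
//   %v = load i32, ptr %p
//   %cond = icmp eq i32 %v, 0
//   br i1 %cond, label %then, label %else
// if no MemoryDef on the path from %then back to the header may clobber %p,
// the condition is invariant along that path. The returned IVConditionInfo
// carries the instructions to duplicate (%v, %cond) and KnownValue (the
// condition value selecting that path), enabling partial unswitching.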
2222std::optional<IVConditionInfo>
2223llvm::hasPartialIVCondition(const Loop &L, unsigned MSSAThreshold,
2224 const MemorySSA &MSSA, AAResults &AA) {
2225 auto *TI = dyn_cast<BranchInst>(Val: L.getHeader()->getTerminator());
2226 if (!TI || !TI->isConditional())
2227 return {};
2228
2229 auto *CondI = dyn_cast<Instruction>(Val: TI->getCondition());
2230 // The case with the condition outside the loop should already be handled
2231 // earlier.
  // Allow CmpInsts and TruncInsts as they may be users of load instructions
  // and have potential for partial unswitching.
2234 if (!CondI || !isa<CmpInst, TruncInst>(Val: CondI) || !L.contains(Inst: CondI))
2235 return {};
2236
2237 SmallVector<Instruction *> InstToDuplicate;
2238 InstToDuplicate.push_back(Elt: CondI);
2239
2240 SmallVector<Value *, 4> WorkList;
2241 WorkList.append(in_start: CondI->op_begin(), in_end: CondI->op_end());
2242
2243 SmallVector<MemoryAccess *, 4> AccessesToCheck;
2244 SmallVector<MemoryLocation, 4> AccessedLocs;
2245 while (!WorkList.empty()) {
2246 Instruction *I = dyn_cast<Instruction>(Val: WorkList.pop_back_val());
2247 if (!I || !L.contains(Inst: I))
2248 continue;
2249
2250 // TODO: support additional instructions.
2251 if (!isa<LoadInst>(Val: I) && !isa<GetElementPtrInst>(Val: I))
2252 return {};
2253
2254 // Do not duplicate volatile and atomic loads.
2255 if (auto *LI = dyn_cast<LoadInst>(Val: I))
2256 if (LI->isVolatile() || LI->isAtomic())
2257 return {};
2258
2259 InstToDuplicate.push_back(Elt: I);
2260 if (MemoryAccess *MA = MSSA.getMemoryAccess(I)) {
2261 if (auto *MemUse = dyn_cast_or_null<MemoryUse>(Val: MA)) {
        // Queue the defining access for the alias checks below.
2263 AccessesToCheck.push_back(Elt: MemUse->getDefiningAccess());
2264 AccessedLocs.push_back(Elt: MemoryLocation::get(Inst: I));
2265 } else {
2266 // MemoryDefs may clobber the location or may be atomic memory
2267 // operations. Bail out.
2268 return {};
2269 }
2270 }
2271 WorkList.append(in_start: I->op_begin(), in_end: I->op_end());
2272 }
2273
2274 if (InstToDuplicate.empty())
2275 return {};
2276
2277 SmallVector<BasicBlock *, 4> ExitingBlocks;
2278 L.getExitingBlocks(ExitingBlocks);
2279 auto HasNoClobbersOnPath =
2280 [&L, &AA, &AccessedLocs, &ExitingBlocks, &InstToDuplicate,
2281 MSSAThreshold](BasicBlock *Succ, BasicBlock *Header,
2282 SmallVector<MemoryAccess *, 4> AccessesToCheck)
2283 -> std::optional<IVConditionInfo> {
2284 IVConditionInfo Info;
    // First, collect all blocks in the loop that are on a path from Succ
2286 // to the header.
2287 SmallVector<BasicBlock *, 4> WorkList;
2288 WorkList.push_back(Elt: Succ);
2289 WorkList.push_back(Elt: Header);
2290 SmallPtrSet<BasicBlock *, 4> Seen;
2291 Seen.insert(Ptr: Header);
2292 Info.PathIsNoop &=
2293 all_of(Range&: *Header, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2294
2295 while (!WorkList.empty()) {
2296 BasicBlock *Current = WorkList.pop_back_val();
2297 if (!L.contains(BB: Current))
2298 continue;
2299 const auto &SeenIns = Seen.insert(Ptr: Current);
2300 if (!SeenIns.second)
2301 continue;
2302
2303 Info.PathIsNoop &= all_of(
2304 Range&: *Current, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2305 WorkList.append(in_start: succ_begin(BB: Current), in_end: succ_end(BB: Current));
2306 }
2307
2308 // Require at least 2 blocks on a path through the loop. This skips
2309 // paths that directly exit the loop.
2310 if (Seen.size() < 2)
2311 return {};
2312
2313 // Next, check if there are any MemoryDefs that are on the path through
2314 // the loop (in the Seen set) and they may-alias any of the locations in
2315 // AccessedLocs. If that is the case, they may modify the condition and
2316 // partial unswitching is not possible.
2317 SmallPtrSet<MemoryAccess *, 4> SeenAccesses;
2318 while (!AccessesToCheck.empty()) {
2319 MemoryAccess *Current = AccessesToCheck.pop_back_val();
2320 auto SeenI = SeenAccesses.insert(Ptr: Current);
2321 if (!SeenI.second || !Seen.contains(Ptr: Current->getBlock()))
2322 continue;
2323
      // Bail out if we exceeded the threshold.
2325 if (SeenAccesses.size() >= MSSAThreshold)
2326 return {};
2327
2328 // MemoryUse are read-only accesses.
2329 if (isa<MemoryUse>(Val: Current))
2330 continue;
2331
      // For a MemoryDef, check if it aliases any of the locations feeding
      // the original condition.
2334 if (auto *CurrentDef = dyn_cast<MemoryDef>(Val: Current)) {
2335 if (any_of(Range&: AccessedLocs, P: [&AA, CurrentDef](MemoryLocation &Loc) {
2336 return isModSet(
2337 MRI: AA.getModRefInfo(I: CurrentDef->getMemoryInst(), OptLoc: Loc));
2338 }))
2339 return {};
2340 }
2341
2342 for (Use &U : Current->uses())
2343 AccessesToCheck.push_back(Elt: cast<MemoryAccess>(Val: U.getUser()));
2344 }
2345
2346 // We could also allow loops with known trip counts without mustprogress,
2347 // but ScalarEvolution may not be available.
2348 Info.PathIsNoop &= isMustProgress(L: &L);
2349
2350 // If the path is considered a no-op so far, check if it reaches a
2351 // single exit block without any phis. This ensures no values from the
2352 // loop are used outside of the loop.
2353 if (Info.PathIsNoop) {
2354 for (auto *Exiting : ExitingBlocks) {
2355 if (!Seen.contains(Ptr: Exiting))
2356 continue;
2357 for (auto *Succ : successors(BB: Exiting)) {
2358 if (L.contains(BB: Succ))
2359 continue;
2360
2361 Info.PathIsNoop &= Succ->phis().empty() &&
2362 (!Info.ExitForPath || Info.ExitForPath == Succ);
2363 if (!Info.PathIsNoop)
2364 break;
2365 assert((!Info.ExitForPath || Info.ExitForPath == Succ) &&
2366 "cannot have multiple exit blocks");
2367 Info.ExitForPath = Succ;
2368 }
2369 }
2370 }
2371 if (!Info.ExitForPath)
2372 Info.PathIsNoop = false;
2373
2374 Info.InstToDuplicate = std::move(InstToDuplicate);
2375 return Info;
2376 };
2377
2378 // If we branch to the same successor, partial unswitching will not be
2379 // beneficial.
2380 if (TI->getSuccessor(i: 0) == TI->getSuccessor(i: 1))
2381 return {};
2382
2383 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 0), L.getHeader(),
2384 AccessesToCheck)) {
2385 Info->KnownValue = ConstantInt::getTrue(Context&: TI->getContext());
2386 return Info;
2387 }
2388 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 1), L.getHeader(),
2389 AccessesToCheck)) {
2390 Info->KnownValue = ConstantInt::getFalse(Context&: TI->getContext());
2391 return Info;
2392 }
2393
2394 return {};
2395}
2396