1//===-- LoopUtils.cpp - Loop Utility functions -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines common loop utility functions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Utils/LoopUtils.h"
14#include "llvm/ADT/DenseSet.h"
15#include "llvm/ADT/PriorityWorklist.h"
16#include "llvm/ADT/ScopeExit.h"
17#include "llvm/ADT/SetVector.h"
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/Analysis/AliasAnalysis.h"
21#include "llvm/Analysis/BasicAliasAnalysis.h"
22#include "llvm/Analysis/DomTreeUpdater.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/InstSimplifyFolder.h"
25#include "llvm/Analysis/LoopAccessAnalysis.h"
26#include "llvm/Analysis/LoopInfo.h"
27#include "llvm/Analysis/LoopPass.h"
28#include "llvm/Analysis/MemorySSA.h"
29#include "llvm/Analysis/MemorySSAUpdater.h"
30#include "llvm/Analysis/ScalarEvolution.h"
31#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
32#include "llvm/Analysis/ScalarEvolutionExpressions.h"
33#include "llvm/IR/DIBuilder.h"
34#include "llvm/IR/Dominators.h"
35#include "llvm/IR/Instructions.h"
36#include "llvm/IR/IntrinsicInst.h"
37#include "llvm/IR/MDBuilder.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/PatternMatch.h"
40#include "llvm/IR/ProfDataUtils.h"
41#include "llvm/IR/ValueHandle.h"
42#include "llvm/InitializePasses.h"
43#include "llvm/Pass.h"
44#include "llvm/Support/Compiler.h"
45#include "llvm/Support/Debug.h"
46#include "llvm/Transforms/Utils/BasicBlockUtils.h"
47#include "llvm/Transforms/Utils/Local.h"
48#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
49
using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "loop-utils"

// Names of the loop-metadata nodes recognized by the hasDisable*Hint helpers
// below.
static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced";
static const char *LLVMLoopDisableLICM = "llvm.licm.disable";
namespace llvm {
// Command-line flag defined elsewhere in the project; declared here so this
// file can consult it without a header dependency.
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
} // namespace llvm
60
bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI,
                                   MemorySSAUpdater *MSSAU,
                                   bool PreserveLCSSA) {
  bool Changed = false;

  // We re-use a vector for the in-loop predecessors.
  SmallVector<BasicBlock *, 4> InLoopPredecessors;

  // Rewrite a single exit block: if it is also reachable from outside the
  // loop, split off a new block reached only by the in-loop predecessors so
  // the loop has a dedicated exit. Returns true iff the CFG was changed.
  auto RewriteExit = [&](BasicBlock *BB) {
    assert(InLoopPredecessors.empty() &&
           "Must start with an empty predecessors list!");
    llvm::scope_exit Cleanup([&] { InLoopPredecessors.clear(); });

    // See if there are any non-loop predecessors of this exit block and
    // keep track of the in-loop predecessors.
    bool IsDedicatedExit = true;
    for (auto *PredBB : predecessors(BB))
      if (L->contains(PredBB)) {
        if (isa<IndirectBrInst>(PredBB->getTerminator()))
          // We cannot rewrite exiting edges from an indirectbr.
          return false;

        InLoopPredecessors.push_back(PredBB);
      } else {
        IsDedicatedExit = false;
      }

    assert(!InLoopPredecessors.empty() && "Must have *some* loop predecessor!");

    // Nothing to do if this is already a dedicated exit.
    if (IsDedicatedExit)
      return false;

    auto *NewExitBB = SplitBlockPredecessors(
        BB, InLoopPredecessors, ".loopexit", DT, LI, MSSAU, PreserveLCSSA);

    // Note: we still report "changed" below even if the split failed, since
    // SplitBlockPredecessors may have mutated analyses before bailing.
    if (!NewExitBB)
      LLVM_DEBUG(
          dbgs() << "WARNING: Can't create a dedicated exit block for loop: "
                 << *L << "\n");
    else
      LLVM_DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block "
                        << NewExitBB->getName() << "\n");
    return true;
  };

  // Walk the exit blocks directly rather than building up a data structure for
  // them, but only visit each one once.
  SmallPtrSet<BasicBlock *, 4> Visited;
  for (auto *BB : L->blocks())
    for (auto *SuccBB : successors(BB)) {
      // We're looking for exit blocks so skip in-loop successors.
      if (L->contains(SuccBB))
        continue;

      // Visit each exit block exactly once.
      if (!Visited.insert(SuccBB).second)
        continue;

      Changed |= RewriteExit(SuccBB);
    }

  return Changed;
}
125
126/// Returns the instructions that use values defined in the loop.
127SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) {
128 SmallVector<Instruction *, 8> UsedOutside;
129
130 for (auto *Block : L->getBlocks())
131 // FIXME: I believe that this could use copy_if if the Inst reference could
132 // be adapted into a pointer.
133 for (auto &Inst : *Block) {
134 auto Users = Inst.users();
135 if (any_of(Range&: Users, P: [&](User *U) {
136 auto *Use = cast<Instruction>(Val: U);
137 return !L->contains(BB: Use->getParent());
138 }))
139 UsedOutside.push_back(Elt: &Inst);
140 }
141
142 return UsedOutside;
143}
144
/// Declare the standard set of analyses required and preserved by a
/// legacy-PM loop pass. Keep in sync with initializeLoopPassPass below.
void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) {
  // By definition, all loop passes need the LoopInfo analysis and the
  // Dominator tree it depends on. Because they all participate in the loop
  // pass manager, they must also preserve these.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addPreserved<LoopInfoWrapperPass>();

  // We must also preserve LoopSimplify and LCSSA. We locally access their IDs
  // here because users shouldn't directly get them from this header.
  extern char &LoopSimplifyID;
  extern char &LCSSAID;
  AU.addRequiredID(LoopSimplifyID);
  AU.addPreservedID(LoopSimplifyID);
  AU.addRequiredID(LCSSAID);
  AU.addPreservedID(LCSSAID);
  // This is used in the LPPassManager to perform LCSSA verification on passes
  // which preserve lcssa form
  AU.addRequired<LCSSAVerificationPass>();
  AU.addPreserved<LCSSAVerificationPass>();

  // Loop passes are designed to run inside of a loop pass manager which means
  // that any function analyses they require must be required by the first loop
  // pass in the manager (so that it is computed before the loop pass manager
  // runs) and preserved by all loop passes in the manager. To make this
  // reasonably robust, the set needed for most loop passes is maintained here.
  // If your loop pass requires an analysis not listed here, you will need to
  // carefully audit the loop pass manager nesting structure that results.
  AU.addRequired<AAResultsWrapperPass>();
  AU.addPreserved<AAResultsWrapperPass>();
  AU.addPreserved<BasicAAWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
  AU.addPreserved<SCEVAAWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  // FIXME: When all loop passes preserve MemorySSA, it can be required and
  // preserved here instead of the individual handling in each pass.
}
184
/// Manually defined generic "LoopPass" dependency initialization. This is used
/// to initialize the exact set of passes from above in \c
/// getLoopAnalysisUsage. It can be used within a loop pass's initialization
/// with:
///
/// INITIALIZE_PASS_DEPENDENCY(LoopPass)
///
/// As-if "LoopPass" were a pass.
void llvm::initializeLoopPassPass(PassRegistry &Registry) {
  // Each macro expands to an initializer call against `Registry`; the list
  // mirrors the required/preserved sets in getLoopAnalysisUsage above.
  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
  INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
}
205
206/// Create MDNode for input string.
207static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
208 LLVMContext &Context = TheLoop->getHeader()->getContext();
209 Metadata *MDs[] = {
210 MDString::get(Context, Str: Name),
211 ConstantAsMetadata::get(C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Context), V))};
212 return MDNode::get(Context, MDs);
213}
214
215/// Set input string into loop metadata by keeping other values intact.
216/// If the string is already in loop metadata update value if it is
217/// different.
218void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
219 unsigned V) {
220 SmallVector<Metadata *, 4> MDs(1);
221 // If the loop already has metadata, retain it.
222 MDNode *LoopID = TheLoop->getLoopID();
223 if (LoopID) {
224 for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
225 MDNode *Node = cast<MDNode>(Val: LoopID->getOperand(I: i));
226 // If it is of form key = value, try to parse it.
227 if (Node->getNumOperands() == 2) {
228 MDString *S = dyn_cast<MDString>(Val: Node->getOperand(I: 0));
229 if (S && S->getString() == StringMD) {
230 ConstantInt *IntMD =
231 mdconst::extract_or_null<ConstantInt>(MD: Node->getOperand(I: 1));
232 if (IntMD && IntMD->getSExtValue() == V)
233 // It is already in place. Do nothing.
234 return;
235 // We need to update the value, so just skip it here and it will
236 // be added after copying other existed nodes.
237 continue;
238 }
239 }
240 MDs.push_back(Elt: Node);
241 }
242 }
243 // Add new metadata.
244 MDs.push_back(Elt: createStringMetadata(TheLoop, Name: StringMD, V));
245 // Replace current metadata node with new one.
246 LLVMContext &Context = TheLoop->getHeader()->getContext();
247 MDNode *NewLoopID = MDNode::get(Context, MDs);
248 // Set operand 0 to refer to the loop id itself.
249 NewLoopID->replaceOperandWith(I: 0, New: NewLoopID);
250 TheLoop->setLoopID(NewLoopID);
251}
252
253std::optional<ElementCount>
254llvm::getOptionalElementCountLoopAttribute(const Loop *TheLoop) {
255 std::optional<int> Width =
256 getOptionalIntLoopAttribute(TheLoop, Name: "llvm.loop.vectorize.width");
257
258 if (Width) {
259 std::optional<int> IsScalable = getOptionalIntLoopAttribute(
260 TheLoop, Name: "llvm.loop.vectorize.scalable.enable");
261 return ElementCount::get(MinVal: *Width, Scalable: IsScalable.value_or(u: false));
262 }
263
264 return std::nullopt;
265}
266
/// Build the loop-ID metadata for a followup loop of a transformation:
/// optionally inherit attributes from \p OrigLoopID (all, or all except those
/// starting with \p InheritOptionsExceptPrefix) and append the operands of
/// each followup option named in \p FollowupOptions.
/// Returns std::nullopt when the transformation should derive attributes
/// itself, nullptr when the followup loop needs no !llvm.loop metadata.
std::optional<MDNode *> llvm::makeFollowupLoopID(
    MDNode *OrigLoopID, ArrayRef<StringRef> FollowupOptions,
    const char *InheritOptionsExceptPrefix, bool AlwaysNew) {
  if (!OrigLoopID) {
    if (AlwaysNew)
      return nullptr;
    return std::nullopt;
  }

  assert(OrigLoopID->getOperand(0) == OrigLoopID);

  // nullptr prefix => inherit everything; "" => inherit nothing.
  bool InheritAllAttrs = !InheritOptionsExceptPrefix;
  bool InheritSomeAttrs =
      InheritOptionsExceptPrefix && InheritOptionsExceptPrefix[0] != '\0';
  SmallVector<Metadata *, 8> MDs;
  // Reserve slot 0 for the self-reference, filled in at the end.
  MDs.push_back(nullptr);

  bool Changed = false;
  if (InheritAllAttrs || InheritSomeAttrs) {
    for (const MDOperand &Existing : drop_begin(OrigLoopID->operands())) {
      MDNode *Op = cast<MDNode>(Existing.get());

      auto InheritThisAttribute = [InheritSomeAttrs,
                                   InheritOptionsExceptPrefix](MDNode *Op) {
        // InheritAllAttrs case: everything is kept (handled by caller's ||).
        if (!InheritSomeAttrs)
          return false;

        // Skip malformatted attribute metadata nodes.
        if (Op->getNumOperands() == 0)
          return true;
        Metadata *NameMD = Op->getOperand(0).get();
        if (!isa<MDString>(NameMD))
          return true;
        StringRef AttrName = cast<MDString>(NameMD)->getString();

        // Do not inherit excluded attributes.
        return !AttrName.starts_with(InheritOptionsExceptPrefix);
      };

      if (InheritThisAttribute(Op))
        MDs.push_back(Op);
      else
        Changed = true;
    }
  } else {
    // Modified if we dropped at least one attribute.
    Changed = OrigLoopID->getNumOperands() > 1;
  }

  // Splice in the operands of each requested followup option node.
  bool HasAnyFollowup = false;
  for (StringRef OptionName : FollowupOptions) {
    MDNode *FollowupNode = findOptionMDForLoopID(OrigLoopID, OptionName);
    if (!FollowupNode)
      continue;

    HasAnyFollowup = true;
    for (const MDOperand &Option : drop_begin(FollowupNode->operands())) {
      MDs.push_back(Option.get());
      Changed = true;
    }
  }

  // Attributes of the followup loop not specified explicitly, so signal to the
  // transformation pass to add suitable attributes.
  if (!AlwaysNew && !HasAnyFollowup)
    return std::nullopt;

  // If no attributes were added or removed, the previous loop Id can be
  // reused.
  if (!AlwaysNew && !Changed)
    return OrigLoopID;

  // No attributes is equivalent to having no !llvm.loop metadata at all.
  if (MDs.size() == 1)
    return nullptr;

  // Build the new loop ID.
  MDTuple *FollowupLoopID = MDNode::get(OrigLoopID->getContext(), MDs);
  FollowupLoopID->replaceOperandWith(0, FollowupLoopID);
  return FollowupLoopID;
}
347
348bool llvm::hasDisableAllTransformsHint(const Loop *L) {
349 return getBooleanLoopAttribute(TheLoop: L, Name: LLVMLoopDisableNonforced);
350}
351
352bool llvm::hasDisableLICMTransformsHint(const Loop *L) {
353 return getBooleanLoopAttribute(TheLoop: L, Name: LLVMLoopDisableLICM);
354}
355
356TransformationMode llvm::hasUnrollTransformation(const Loop *L) {
357 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.disable"))
358 return TM_SuppressedByUser;
359
360 std::optional<int> Count =
361 getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.count");
362 if (Count)
363 return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;
364
365 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.enable"))
366 return TM_ForcedByUser;
367
368 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.full"))
369 return TM_ForcedByUser;
370
371 if (hasDisableAllTransformsHint(L))
372 return TM_Disable;
373
374 return TM_Unspecified;
375}
376
377TransformationMode llvm::hasUnrollAndJamTransformation(const Loop *L) {
378 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.disable"))
379 return TM_SuppressedByUser;
380
381 std::optional<int> Count =
382 getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.count");
383 if (Count)
384 return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;
385
386 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.enable"))
387 return TM_ForcedByUser;
388
389 if (hasDisableAllTransformsHint(L))
390 return TM_Disable;
391
392 return TM_Unspecified;
393}
394
/// Classify the user's vectorization request from loop metadata. Note the
/// comparisons below are against std::optional: `Enable == false` holds only
/// when the attribute is present *and* false, and `InterleaveCount == 1` is
/// false when the attribute is absent.
TransformationMode llvm::hasVectorizeTransformation(const Loop *L) {
  std::optional<bool> Enable =
      getOptionalBoolLoopAttribute(L, "llvm.loop.vectorize.enable");

  if (Enable == false)
    return TM_SuppressedByUser;

  std::optional<ElementCount> VectorizeWidth =
      getOptionalElementCountLoopAttribute(L);
  std::optional<int> InterleaveCount =
      getOptionalIntLoopAttribute(L, "llvm.loop.interleave.count");

  // 'Forcing' vector width and interleave count to one effectively disables
  // this transformation.
  if (Enable == true && VectorizeWidth && VectorizeWidth->isScalar() &&
      InterleaveCount == 1)
    return TM_SuppressedByUser;

  // Already vectorized: don't do it again.
  if (getBooleanLoopAttribute(L, "llvm.loop.isvectorized"))
    return TM_Disable;

  if (Enable == true)
    return TM_ForcedByUser;

  // Width 1 and interleave 1 without an explicit enable: quietly disabled.
  if ((VectorizeWidth && VectorizeWidth->isScalar()) && InterleaveCount == 1)
    return TM_Disable;

  if ((VectorizeWidth && VectorizeWidth->isVector()) || InterleaveCount > 1)
    return TM_Enable;

  if (hasDisableAllTransformsHint(L))
    return TM_Disable;

  return TM_Unspecified;
}
430
431TransformationMode llvm::hasDistributeTransformation(const Loop *L) {
432 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.distribute.enable"))
433 return TM_ForcedByUser;
434
435 if (hasDisableAllTransformsHint(L))
436 return TM_Disable;
437
438 return TM_Unspecified;
439}
440
441TransformationMode llvm::hasLICMVersioningTransformation(const Loop *L) {
442 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.licm_versioning.disable"))
443 return TM_SuppressedByUser;
444
445 if (hasDisableAllTransformsHint(L))
446 return TM_Disable;
447
448 return TM_Unspecified;
449}
450
451/// Does a BFS from a given node to all of its children inside a given loop.
452/// The returned vector of basic blocks includes the starting point.
453SmallVector<BasicBlock *, 16> llvm::collectChildrenInLoop(DominatorTree *DT,
454 DomTreeNode *N,
455 const Loop *CurLoop) {
456 SmallVector<BasicBlock *, 16> Worklist;
457 auto AddRegionToWorklist = [&](DomTreeNode *DTN) {
458 // Only include subregions in the top level loop.
459 BasicBlock *BB = DTN->getBlock();
460 if (CurLoop->contains(BB))
461 Worklist.push_back(Elt: DTN->getBlock());
462 };
463
464 AddRegionToWorklist(N);
465
466 for (size_t I = 0; I < Worklist.size(); I++) {
467 for (DomTreeNode *Child : DT->getNode(BB: Worklist[I])->children())
468 AddRegionToWorklist(Child);
469 }
470
471 return Worklist;
472}
473
474bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) {
475 int LatchIdx = PN->getBasicBlockIndex(BB: LatchBlock);
476 assert(LatchIdx != -1 && "LatchBlock is not a case in this PHINode");
477 Value *IncV = PN->getIncomingValue(i: LatchIdx);
478
479 for (User *U : PN->users())
480 if (U != Cond && U != IncV) return false;
481
482 for (User *U : IncV->users())
483 if (U != Cond && U != PN) return false;
484 return true;
485}
486
487
/// Delete the dead loop \p L: reroute the preheader directly to the unique
/// exit block (or terminate it with unreachable when the loop has no exits),
/// rewrite exit PHIs, salvage debug records into the exit block, and finally
/// erase the loop's blocks and its LoopInfo entry. DT, SE, LI and MSSA are
/// kept up to date when non-null. Requires LCSSA form and a preheader.
void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
                          LoopInfo *LI, MemorySSA *MSSA) {
  assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
  auto *Preheader = L->getLoopPreheader();
  assert(Preheader && "Preheader should exist!");

  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);

  // Now that we know the removal is safe, remove the loop by changing the
  // branch from the preheader to go to the single exit block.
  //
  // Because we're deleting a large chunk of code at once, the sequence in which
  // we remove things is very important to avoid invalidation issues.

  // Tell ScalarEvolution that the loop is deleted. Do this before
  // deleting the loop so that ScalarEvolution can look at the loop
  // to determine what it needs to clean up.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetBlockAndLoopDispositions();
  }

  Instruction *OldTerm = Preheader->getTerminator();
  assert(!OldTerm->mayHaveSideEffects() &&
         "Preheader must end with a side-effect-free terminator");
  assert(OldTerm->getNumSuccessors() == 1 &&
         "Preheader must have a single successor");
  // Connect the preheader to the exit block. Keep the old edge to the header
  // around to perform the dominator tree update in two separate steps
  // -- #1 insertion of the edge preheader -> exit and #2 deletion of the edge
  // preheader -> header.
  //
  //
  // 0.  Preheader          1. Preheader           2. Preheader
  //        |                    |   |                   |
  //        V                    |   V                   |
  //      Header <--\            | Header <--\           | Header <--\
  //       |  |     |            |  |  |     |           |  |  |     |
  //       |  V     |            |  |  V     |           |  |  V     |
  //       | Body --/            |  | Body --/           |  | Body --/
  //       V                     V  V                    V  V
  //      Exit                   Exit                    Exit
  //
  // By doing this is two separate steps we can perform the dominator tree
  // update without using the batch update API.
  //
  // Even when the loop is never executed, we cannot remove the edge from the
  // source block to the exit block. Consider the case where the unexecuted loop
  // branches back to an outer loop. If we deleted the loop and removed the edge
  // coming to this inner loop, this will break the outer loop structure (by
  // deleting the backedge of the outer loop). If the outer loop is indeed a
  // non-loop, it will be deleted in a future iteration of loop deletion pass.
  IRBuilder<> Builder(OldTerm);

  auto *ExitBlock = L->getUniqueExitBlock();
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
  if (ExitBlock) {
    assert(ExitBlock && "Should have a unique exit block!");
    assert(L->hasDedicatedExits() && "Loop should have dedicated exits!");

    // The false branch to the header is never taken; it only preserves the
    // edge for the two-step DT update described above.
    Builder.CreateCondBr(Builder.getFalse(), L->getHeader(), ExitBlock);
    // Remove the old branch. The conditional branch becomes a new terminator.
    OldTerm->eraseFromParent();

    // Rewrite phis in the exit block to get their inputs from the Preheader
    // instead of the exiting block.
    for (PHINode &P : ExitBlock->phis()) {
      // Set the zero'th element of Phi to be from the preheader and remove all
      // other incoming values. Given the loop has dedicated exits, all other
      // incoming values must be from the exiting blocks.
      int PredIndex = 0;
      P.setIncomingBlock(PredIndex, Preheader);
      // Removes all incoming values from all other exiting blocks (including
      // duplicate values from an exiting block).
      // Nuke all entries except the zero'th entry which is the preheader entry.
      P.removeIncomingValueIf([](unsigned Idx) { return Idx != 0; },
                              /* DeletePHIIfEmpty */ false);

      assert((P.getNumIncomingValues() == 1 &&
              P.getIncomingBlock(PredIndex) == Preheader) &&
             "Should have exactly one value and that's from the preheader!");
    }

    // Step #1 of the DT update: insert preheader -> exit.
    if (DT) {
      DTU.applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}});
      if (MSSA) {
        MSSAU->applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}},
                            *DT);
        if (VerifyMemorySSA)
          MSSA->verifyMemorySSA();
      }
    }

    // Disconnect the loop body by branching directly to its exit.
    Builder.SetInsertPoint(Preheader->getTerminator());
    Builder.CreateBr(ExitBlock);
    // Remove the old branch.
    Preheader->getTerminator()->eraseFromParent();
  } else {
    assert(L->hasNoExitBlocks() &&
           "Loop should have either zero or one exit blocks.");

    // No exit at all: control entering the preheader can never leave, so
    // terminate it with unreachable.
    Builder.SetInsertPoint(OldTerm);
    Builder.CreateUnreachable();
    Preheader->getTerminator()->eraseFromParent();
  }

  // Step #2 of the DT update: delete preheader -> header, then drop the now
  // unreachable loop body from MemorySSA.
  if (DT) {
    DTU.applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}});
    if (MSSA) {
      MSSAU->applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}},
                          *DT);
      SmallSetVector<BasicBlock *, 8> DeadBlockSet(L->block_begin(),
                                                   L->block_end());
      MSSAU->removeBlocks(DeadBlockSet);
      if (VerifyMemorySSA)
        MSSA->verifyMemorySSA();
    }
  }

  // Use a map to unique and a vector to guarantee deterministic ordering.
  llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet;
  llvm::SmallVector<DbgVariableRecord *, 4> DeadDbgVariableRecords;

  // Given LCSSA form is satisfied, we should not have users of instructions
  // within the dead loop outside of the loop. However, LCSSA doesn't take
  // unreachable uses into account. We handle them here.
  // We could do it after drop all references (in this case all users in the
  // loop will be already eliminated and we have less work to do but according
  // to API doc of User::dropAllReferences only valid operation after dropping
  // references, is deletion. So let's substitute all usages of
  // instruction from the loop with poison value of corresponding type first.
  for (auto *Block : L->blocks())
    for (Instruction &I : *Block) {
      auto *Poison = PoisonValue::get(I.getType());
      for (Use &U : llvm::make_early_inc_range(I.uses())) {
        if (auto *Usr = dyn_cast<Instruction>(U.getUser()))
          if (L->contains(Usr->getParent()))
            continue;
        // If we have a DT then we can check that uses outside a loop only in
        // unreachable block.
        if (DT)
          assert(!DT->isReachableFromEntry(U) &&
                 "Unexpected user in reachable block");
        U.set(Poison);
      }

      if (ExitBlock) {
        // For one of each variable encountered, preserve a debug record (set
        // to Poison) and transfer it to the loop exit. This terminates any
        // variable locations that were set during the loop.
        for (DbgVariableRecord &DVR :
             llvm::make_early_inc_range(filterDbgVars(I.getDbgRecordRange()))) {
          DebugVariable Key(DVR.getVariable(), DVR.getExpression(),
                            DVR.getDebugLoc().get());
          if (!DeadDebugSet.insert(Key).second)
            continue;
          // Unlinks the DVR from it's container, for later insertion.
          DVR.removeFromParent();
          DeadDbgVariableRecords.push_back(&DVR);
        }
      }
    }

  if (ExitBlock) {
    // After the loop has been deleted all the values defined and modified
    // inside the loop are going to be unavailable. Values computed in the
    // loop will have been deleted, automatically causing their debug uses
    // be replaced with undef. Loop invariant values will still be available.
    // Move dbg.values out the loop so that earlier location ranges are still
    // terminated and loop invariant assignments are preserved.
    DIBuilder DIB(*ExitBlock->getModule());
    BasicBlock::iterator InsertDbgValueBefore =
        ExitBlock->getFirstInsertionPt();
    assert(InsertDbgValueBefore != ExitBlock->end() &&
           "There should be a non-PHI instruction in exit block, else these "
           "instructions will have no parent.");

    // Due to the "head" bit in BasicBlock::iterator, we're going to insert
    // each DbgVariableRecord right at the start of the block, whereas
    // dbg.values would be repeatedly inserted before the first instruction.
    // To replicate this behaviour, do it backwards.
    for (DbgVariableRecord *DVR : llvm::reverse(DeadDbgVariableRecords))
      ExitBlock->insertDbgRecordBefore(DVR, InsertDbgValueBefore);
  }

  // Remove the block from the reference counting scheme, so that we can
  // delete it freely later.
  for (auto *Block : L->blocks())
    Block->dropAllReferences();

  if (MSSA && VerifyMemorySSA)
    MSSA->verifyMemorySSA();

  if (LI) {
    // Erase the instructions and the blocks without having to worry
    // about ordering because we already dropped the references.
    // NOTE: This iteration is safe because erasing the block does not remove
    // its entry from the loop's block list. We do that in the next section.
    for (BasicBlock *BB : L->blocks())
      BB->eraseFromParent();

    // Finally, the blocks from loopinfo. This has to happen late because
    // otherwise our loop iterators won't work.

    SmallPtrSet<BasicBlock *, 8> blocks(llvm::from_range, L->blocks());
    for (BasicBlock *BB : blocks)
      LI->removeBlock(BB);

    // The last step is to update LoopInfo now that we've eliminated this loop.
    // Note: LoopInfo::erase remove the given loop and relink its subloops with
    // its parent. While removeLoop/removeChildLoop remove the given loop but
    // not relink its subloops, which is what we want.
    if (Loop *ParentLoop = L->getParentLoop()) {
      Loop::iterator I = find(*ParentLoop, L);
      assert(I != ParentLoop->end() && "Couldn't find loop");
      ParentLoop->removeChildLoop(I);
    } else {
      Loop::iterator I = find(*LI, L);
      assert(I != LI->end() && "Couldn't find loop");
      LI->removeLoop(I);
    }
    LI->destroy(L);
  }
}
715
/// Remove the backedge of \p L (which must have a single latch), turning the
/// loop body into straight-line code, and erase the loop from LoopInfo.
/// DT, SE, LI and (optionally) MSSA are kept up to date.
void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
                             LoopInfo &LI, MemorySSA *MSSA) {
  auto *Latch = L->getLoopLatch();
  assert(Latch && "multiple latches not yet supported");
  auto *Header = L->getHeader();
  // Remembered before LI.erase(L) below invalidates L's loop structure.
  Loop *OutermostLoop = L->getOutermostLoop();

  SE.forgetLoop(L);
  SE.forgetBlockAndLoopDispositions();

  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);

  // Update the CFG and domtree. We chose to special case a couple of
  // of common cases for code quality and test readability reasons.
  [&]() -> void {
    // Case 1: unconditional latch branch -- the latch itself becomes
    // unreachable.
    if (auto *BI = dyn_cast<UncondBrInst>(Latch->getTerminator())) {
      DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
      (void)changeToUnreachable(BI, /*PreserveLCSSA*/ true, &DTU, MSSAU.get());
      return;
    }
    // Case 2: conditional latch branch that exits the loop -- fold it into an
    // unconditional branch to the exit.
    if (auto *BI = dyn_cast<CondBrInst>(Latch->getTerminator())) {
      // Conditional latch/exit - note that latch can be shared by inner
      // and outer loop so the other target doesn't need to an exit
      if (L->isLoopExiting(Latch)) {
        // TODO: Generalize ConstantFoldTerminator so that it can be used
        // here without invalidating LCSSA or MemorySSA. (Tricky case for
        // LCSSA: header is an exit block of a preceding sibling loop w/o
        // dedicated exits.)
        const unsigned ExitIdx = L->contains(BI->getSuccessor(0)) ? 1 : 0;
        BasicBlock *ExitBB = BI->getSuccessor(ExitIdx);

        DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
        Header->removePredecessor(Latch, /*KeepOneInputPHIs=*/true);

        IRBuilder<> Builder(BI);
        auto *NewBI = Builder.CreateBr(ExitBB);
        // Transfer the metadata to the new branch instruction (minus the
        // loop info since this is no longer a loop)
        NewBI->copyMetadata(*BI, {LLVMContext::MD_dbg,
                                  LLVMContext::MD_annotation});

        BI->eraseFromParent();
        DTU.applyUpdates({{DominatorTree::Delete, Latch, Header}});
        if (MSSA)
          MSSAU->applyUpdates({{DominatorTree::Delete, Latch, Header}}, DT);
        return;
      }
    }

    // General case. By splitting the backedge, and then explicitly making it
    // unreachable we gracefully handle corner cases such as switch and invoke
    // terminators.
    auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get());

    DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
    (void)changeToUnreachable(BackedgeBB->getTerminator(),
                              /*PreserveLCSSA*/ true, &DTU, MSSAU.get());
  }();

  // Erase (and destroy) this loop instance. Handles relinking sub-loops
  // and blocks within the loop as needed.
  LI.erase(L);

  // If the loop we broke had a parent, then changeToUnreachable might have
  // caused a block to be removed from the parent loop (see loop_nest_lcssa
  // test case in zero-btc.ll for an example), thus changing the parent's
  // exit blocks. If that happened, we need to rebuild LCSSA on the outermost
  // loop which might have a had a block removed.
  if (OutermostLoop != L)
    formLCSSARecursively(*OutermostLoop, DT, &LI, &SE);
}
789
790
791/// Checks if \p L has an exiting latch branch. There may also be other
792/// exiting blocks. Returns branch instruction terminating the loop
793/// latch if above check is successful, nullptr otherwise.
794static CondBrInst *getExpectedExitLoopLatchBranch(Loop *L) {
795 BasicBlock *Latch = L->getLoopLatch();
796 if (!Latch)
797 return nullptr;
798
799 CondBrInst *LatchBR = dyn_cast<CondBrInst>(Val: Latch->getTerminator());
800 if (!LatchBR || !L->isLoopExiting(BB: Latch))
801 return nullptr;
802
803 assert((LatchBR->getSuccessor(0) == L->getHeader() ||
804 LatchBR->getSuccessor(1) == L->getHeader()) &&
805 "At least one edge out of the latch must go to the header");
806
807 return LatchBR;
808}
809
/// Small wrapper so debug output can print a loop together with the name of
/// its containing function (see the NDEBUG-only operator<< in this file).
struct DbgLoop {
  const Loop *L;
  explicit DbgLoop(const Loop *L) : L(L) {}
};
814
#ifndef NDEBUG
/// Prints "function <fn> <loop>" for a DbgLoop wrapper; debug builds only.
static inline raw_ostream &operator<<(raw_ostream &OS, DbgLoop D) {
  OS << "function ";
  D.L->getHeader()->getParent()->printAsOperand(OS, /*PrintType=*/false);
  return OS << " " << *D.L;
}
#endif // NDEBUG
822
/// Estimate the trip count of \p L from the branch weights on its exiting
/// latch branch, or std::nullopt if the latch branch or its weights are
/// missing. The result saturates at UINT_MAX.
static std::optional<unsigned> estimateLoopTripCount(Loop *L) {
  // Currently we take the estimate exit count only from the loop latch,
  // ignoring other exiting blocks. This can overestimate the trip count
  // if we exit through another exit, but can never underestimate it.
  // TODO: incorporate information from other exits
  CondBrInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
  if (!ExitingBranch) {
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to find exiting "
                      << "latch branch of required form in " << DbgLoop(L)
                      << "\n");
    return std::nullopt;
  }

  // To estimate the number of times the loop body was executed, we want to
  // know the number of times the backedge was taken, vs. the number of times
  // we exited the loop.
  uint64_t LoopWeight, ExitWeight;
  if (!extractBranchWeights(*ExitingBranch, LoopWeight, ExitWeight)) {
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to extract branch "
                      << "weights for " << DbgLoop(L) << "\n");
    return std::nullopt;
  }

  // extractBranchWeights gives (true-successor, false-successor) weights; if
  // the false successor is the one staying in the loop, swap so LoopWeight is
  // the stay-in-loop weight and ExitWeight the leave-loop weight.
  if (L->contains(ExitingBranch->getSuccessor(1)))
    std::swap(LoopWeight, ExitWeight);

  if (!ExitWeight) {
    // Don't have a way to return predicated infinite
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed because of zero exit "
                      << "probability for " << DbgLoop(L) << "\n");
    return std::nullopt;
  }

  // Estimated exit count is a ratio of the loop weight by the weight of the
  // edge exiting the loop, rounded to nearest.
  uint64_t ExitCount = llvm::divideNearest(LoopWeight, ExitWeight);

  // When ExitCount + 1 would wrap in unsigned, saturate at UINT_MAX.
  if (ExitCount >= std::numeric_limits<unsigned>::max())
    return std::numeric_limits<unsigned>::max();

  // Estimated trip count is one plus estimated exit count.
  uint64_t TC = ExitCount + 1;
  LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Estimated trip count of " << TC
                    << " for " << DbgLoop(L) << "\n");
  return TC;
}
870
std::optional<unsigned>
llvm::getLoopEstimatedTripCount(Loop *L,
                                unsigned *EstimatedLoopInvocationWeight) {
  // If EstimatedLoopInvocationWeight, we do not support this loop if
  // getExpectedExitLoopLatchBranch returns nullptr.
  //
  // FIXME: Also, this is a stop-gap solution for nested loops. It avoids
  // mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
  // it was created for an inner loop. The problem is that loop metadata is
  // attached to the branch instruction in the loop latch block, but that can be
  // shared by the loops. A solution is to attach loop metadata to loop headers
  // instead, but that would be a large change to LLVM.
  //
  // Until that happens, we work around the problem as follows.
  // getExpectedExitLoopLatchBranch (which also guards
  // setLoopEstimatedTripCount) returns nullptr for a loop unless the loop has
  // one latch and that latch has exactly two successors one of which is an exit
  // from the loop. If the latch is shared by nested loops, then that condition
  // might hold for the inner loop but cannot hold for the outer loop:
  // - Because the latch is shared, it must have at least two successors: the
  //   inner loop header and the outer loop header, which is also an exit for
  //   the inner loop. That satisifies the condition for the inner loop.
  // - To satsify the condition for the outer loop, the latch must have a third
  //   successor that is an exit for the outer loop. But that violates the
  //   condition for both loops.
  CondBrInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
  if (!ExitingBranch)
    return std::nullopt;

  // If requested, either compute *EstimatedLoopInvocationWeight or return
  // nullopt if cannot.
  //
  // TODO: Eventually, once all passes have migrated away from setting branch
  // weights to indicate estimated trip counts, this function will drop the
  // EstimatedLoopInvocationWeight parameter.
  if (EstimatedLoopInvocationWeight) {
    uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
    if (!extractBranchWeights(I: *ExitingBranch, TrueVal&: LoopWeight, FalseVal&: ExitWeight))
      return std::nullopt;
    // Ensure ExitWeight corresponds to the edge leaving the loop (the
    // true-successor weight was extracted into LoopWeight).
    if (L->contains(BB: ExitingBranch->getSuccessor(i: 1)))
      std::swap(a&: LoopWeight, b&: ExitWeight);
    // A zero exit weight gives no usable invocation weight.
    if (!ExitWeight)
      return std::nullopt;
    *EstimatedLoopInvocationWeight = ExitWeight;
  }

  // Return the estimated trip count from metadata unless the metadata is
  // missing or has no value.
  //
  // Some passes set llvm.loop.estimated_trip_count to 0. For example, after
  // peeling 10 or more iterations from a loop with an estimated trip count of
  // 10, llvm.loop.estimated_trip_count becomes 0 on the remaining loop. It
  // indicates that, each time execution reaches the peeled iterations,
  // execution is estimated to exit them without reaching the remaining loop's
  // header.
  //
  // Even if the probability of reaching a loop's header is low, if it is
  // reached, it is the start of an iteration. Consequently, some passes
  // historically assume that llvm::getLoopEstimatedTripCount always returns a
  // positive count or std::nullopt. Thus, return std::nullopt when
  // llvm.loop.estimated_trip_count is 0.
  if (auto TC = getOptionalIntLoopAttribute(TheLoop: L, Name: LLVMLoopEstimatedTripCount)) {
    LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: "
                      << LLVMLoopEstimatedTripCount << " metadata has trip "
                      << "count of " << *TC
                      << (*TC == 0 ? " (returning std::nullopt)" : "")
                      << " for " << DbgLoop(L) << "\n");
    return *TC == 0 ? std::nullopt : std::optional(*TC);
  }

  // Estimate the trip count from latch branch weights.
  return estimateLoopTripCount(L);
}
944
bool llvm::setLoopEstimatedTripCount(
    Loop *L, unsigned EstimatedTripCount,
    std::optional<unsigned> EstimatedloopInvocationWeight) {
  // If EstimatedLoopInvocationWeight, we do not support this loop if
  // getExpectedExitLoopLatchBranch returns nullptr.
  //
  // FIXME: See comments in getLoopEstimatedTripCount for why this is required
  // here regardless of EstimatedLoopInvocationWeight.
  CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
  if (!LatchBranch)
    return false;

  // Set the metadata.
  addStringMetadataToLoop(TheLoop: L, StringMD: LLVMLoopEstimatedTripCount, V: EstimatedTripCount);

  // At the moment, we currently support changing the estimated trip count in
  // the latch branch's branch weights only. We could extend this API to
  // manipulate estimated trip counts for any exit.
  //
  // TODO: Eventually, once all passes have migrated away from setting branch
  // weights to indicate estimated trip counts, we will not set branch weights
  // here at all.
  if (!EstimatedloopInvocationWeight)
    return true;

  // Calculate taken and exit weights. Defaults cover the
  // EstimatedTripCount == 0 case: the backedge is never taken.
  // NOTE(review): with ProfcheckDisableMetadataFixes set and a zero trip
  // count, both weights end up 0 — presumably intentional to leave the
  // weights unnormalized; confirm against the flag's definition.
  unsigned LatchExitWeight = ProfcheckDisableMetadataFixes ? 0 : 1;
  unsigned BackedgeTakenWeight = 0;

  if (EstimatedTripCount != 0) {
    // Per loop entry the backedge is taken EstimatedTripCount - 1 times for
    // each time the exit edge is taken LatchExitWeight times.
    LatchExitWeight = *EstimatedloopInvocationWeight;
    BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight;
  }

  // Make a swap if back edge is taken when condition is "false".
  if (LatchBranch->getSuccessor(i: 0) != L->getHeader())
    std::swap(a&: BackedgeTakenWeight, b&: LatchExitWeight);

  // Set/Update profile metadata.
  setBranchWeights(I&: *LatchBranch, Weights: {BackedgeTakenWeight, LatchExitWeight},
                   /*IsExpected=*/false);

  return true;
}
989
990BranchProbability llvm::getLoopProbability(Loop *L) {
991 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
992 if (!LatchBranch)
993 return BranchProbability::getUnknown();
994 bool FirstTargetIsLoop = LatchBranch->getSuccessor(i: 0) == L->getHeader();
995 return getBranchProbability(B: LatchBranch, ForFirstTarget: FirstTargetIsLoop);
996}
997
998bool llvm::setLoopProbability(Loop *L, BranchProbability P) {
999 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
1000 if (!LatchBranch)
1001 return false;
1002 bool FirstTargetIsLoop = LatchBranch->getSuccessor(i: 0) == L->getHeader();
1003 setBranchProbability(B: LatchBranch, P, ForFirstTarget: FirstTargetIsLoop);
1004 return true;
1005}
1006
1007BranchProbability llvm::getBranchProbability(CondBrInst *B,
1008 bool ForFirstTarget) {
1009 uint64_t Weight0, Weight1;
1010 if (!extractBranchWeights(I: *B, TrueVal&: Weight0, FalseVal&: Weight1))
1011 return BranchProbability::getUnknown();
1012 uint64_t Denominator = Weight0 + Weight1;
1013 if (Denominator == 0)
1014 return BranchProbability::getUnknown();
1015 if (!ForFirstTarget)
1016 std::swap(a&: Weight0, b&: Weight1);
1017 return BranchProbability::getBranchProbability(Numerator: Weight0, Denominator);
1018}
1019
1020BranchProbability llvm::getBranchProbability(BasicBlock *Src, BasicBlock *Dst) {
1021 assert(Src != Dst && "Passed in same source as destination");
1022
1023 Instruction *TI = Src->getTerminator();
1024 if (!TI || TI->getNumSuccessors() == 0)
1025 return BranchProbability::getZero();
1026
1027 SmallVector<uint32_t, 4> Weights;
1028
1029 if (!extractBranchWeights(I: *TI, Weights)) {
1030 // No metadata
1031 return BranchProbability::getUnknown();
1032 }
1033 assert(TI->getNumSuccessors() == Weights.size() &&
1034 "Missing weights in branch_weights");
1035
1036 uint64_t Total = 0;
1037 uint32_t Numerator = 0;
1038 for (auto [i, Weight] : llvm::enumerate(First&: Weights)) {
1039 if (TI->getSuccessor(Idx: i) == Dst)
1040 Numerator += Weight;
1041 Total += Weight;
1042 }
1043
1044 // Total of edges might be 0 if the metadata is incorrect/set by hand
1045 // or missing. In such case return here to avoid division by 0 later on.
1046 // There might also be a case where the value of Total cannot fit into
1047 // uint32_t, in such case, just bail out.
1048 if (Total == 0 || Total > std::numeric_limits<uint32_t>::max())
1049 return BranchProbability::getUnknown();
1050
1051 return BranchProbability(Numerator, Total);
1052}
1053
1054void llvm::setBranchProbability(CondBrInst *B, BranchProbability P,
1055 bool ForFirstTarget) {
1056 BranchProbability Prob0 = P;
1057 BranchProbability Prob1 = P.getCompl();
1058 if (!ForFirstTarget)
1059 std::swap(a&: Prob0, b&: Prob1);
1060 setBranchWeights(I&: *B, Weights: {Prob0.getNumerator(), Prob1.getNumerator()},
1061 /*IsExpected=*/false);
1062}
1063
1064bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
1065 ScalarEvolution &SE) {
1066 Loop *OuterL = InnerLoop->getParentLoop();
1067 if (!OuterL)
1068 return true;
1069
1070 // Get the backedge taken count for the inner loop
1071 BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch();
1072 const SCEV *InnerLoopBECountSC = SE.getExitCount(L: InnerLoop, ExitingBlock: InnerLoopLatch);
1073 if (isa<SCEVCouldNotCompute>(Val: InnerLoopBECountSC) ||
1074 !InnerLoopBECountSC->getType()->isIntegerTy())
1075 return false;
1076
1077 // Get whether count is invariant to the outer loop
1078 ScalarEvolution::LoopDisposition LD =
1079 SE.getLoopDisposition(S: InnerLoopBECountSC, L: OuterL);
1080 if (LD != ScalarEvolution::LoopInvariant)
1081 return false;
1082
1083 return true;
1084}
1085
// Maps a recurrence kind to the vector.reduce.* intrinsic that implements it.
constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unexpected recurrence kind");
  // Sub-based add chains are finalized with an add reduction.
  case RecurKind::AddChainWithSubs:
  case RecurKind::Sub:
  case RecurKind::Add:
    return Intrinsic::vector_reduce_add;
  case RecurKind::Mul:
    return Intrinsic::vector_reduce_mul;
  case RecurKind::And:
    return Intrinsic::vector_reduce_and;
  case RecurKind::Or:
    return Intrinsic::vector_reduce_or;
  case RecurKind::Xor:
    return Intrinsic::vector_reduce_xor;
  // FMulAdd reductions are summed with an fadd reduction.
  case RecurKind::FMulAdd:
  case RecurKind::FAdd:
    return Intrinsic::vector_reduce_fadd;
  case RecurKind::FMul:
    return Intrinsic::vector_reduce_fmul;
  case RecurKind::SMax:
    return Intrinsic::vector_reduce_smax;
  case RecurKind::SMin:
    return Intrinsic::vector_reduce_smin;
  case RecurKind::UMax:
    return Intrinsic::vector_reduce_umax;
  case RecurKind::UMin:
    return Intrinsic::vector_reduce_umin;
  case RecurKind::FMax:
  case RecurKind::FMaxNum:
    return Intrinsic::vector_reduce_fmax;
  case RecurKind::FMin:
  case RecurKind::FMinNum:
    return Intrinsic::vector_reduce_fmin;
  case RecurKind::FMaximum:
    return Intrinsic::vector_reduce_fmaximum;
  case RecurKind::FMinimum:
    return Intrinsic::vector_reduce_fminimum;
  // NOTE(review): FMaximumNum/FMinimumNum map onto fmax/fmin reductions
  // rather than fmaximum/fminimum — presumably the NaN-handling difference
  // is resolved by the caller; confirm.
  case RecurKind::FMaximumNum:
    return Intrinsic::vector_reduce_fmax;
  case RecurKind::FMinimumNum:
    return Intrinsic::vector_reduce_fmin;
  }
}
1131
1132Intrinsic::ID llvm::getMinMaxReductionIntrinsicID(Intrinsic::ID IID) {
1133 switch (IID) {
1134 default:
1135 llvm_unreachable("Unexpected intrinsic id");
1136 case Intrinsic::umin:
1137 return Intrinsic::vector_reduce_umin;
1138 case Intrinsic::umax:
1139 return Intrinsic::vector_reduce_umax;
1140 case Intrinsic::smin:
1141 return Intrinsic::vector_reduce_smin;
1142 case Intrinsic::smax:
1143 return Intrinsic::vector_reduce_smax;
1144 }
1145}
1146
1147// This is the inverse to getReductionForBinop
1148unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
1149 switch (RdxID) {
1150 case Intrinsic::vector_reduce_fadd:
1151 return Instruction::FAdd;
1152 case Intrinsic::vector_reduce_fmul:
1153 return Instruction::FMul;
1154 case Intrinsic::vector_reduce_add:
1155 return Instruction::Add;
1156 case Intrinsic::vector_reduce_mul:
1157 return Instruction::Mul;
1158 case Intrinsic::vector_reduce_and:
1159 return Instruction::And;
1160 case Intrinsic::vector_reduce_or:
1161 return Instruction::Or;
1162 case Intrinsic::vector_reduce_xor:
1163 return Instruction::Xor;
1164 case Intrinsic::vector_reduce_smax:
1165 case Intrinsic::vector_reduce_smin:
1166 case Intrinsic::vector_reduce_umax:
1167 case Intrinsic::vector_reduce_umin:
1168 return Instruction::ICmp;
1169 case Intrinsic::vector_reduce_fmax:
1170 case Intrinsic::vector_reduce_fmin:
1171 return Instruction::FCmp;
1172 default:
1173 llvm_unreachable("Unexpected ID");
1174 }
1175}
1176
1177// This is the inverse to getArithmeticReductionInstruction
1178Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
1179 switch (Opc) {
1180 default:
1181 break;
1182 case Instruction::Add:
1183 return Intrinsic::vector_reduce_add;
1184 case Instruction::Mul:
1185 return Intrinsic::vector_reduce_mul;
1186 case Instruction::And:
1187 return Intrinsic::vector_reduce_and;
1188 case Instruction::Or:
1189 return Intrinsic::vector_reduce_or;
1190 case Instruction::Xor:
1191 return Intrinsic::vector_reduce_xor;
1192 }
1193 return Intrinsic::not_intrinsic;
1194}
1195
1196Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
1197 switch (RdxID) {
1198 default:
1199 llvm_unreachable("Unknown min/max recurrence kind");
1200 case Intrinsic::vector_reduce_umin:
1201 return Intrinsic::umin;
1202 case Intrinsic::vector_reduce_umax:
1203 return Intrinsic::umax;
1204 case Intrinsic::vector_reduce_smin:
1205 return Intrinsic::smin;
1206 case Intrinsic::vector_reduce_smax:
1207 return Intrinsic::smax;
1208 case Intrinsic::vector_reduce_fmin:
1209 return Intrinsic::minnum;
1210 case Intrinsic::vector_reduce_fmax:
1211 return Intrinsic::maxnum;
1212 case Intrinsic::vector_reduce_fminimum:
1213 return Intrinsic::minimum;
1214 case Intrinsic::vector_reduce_fmaximum:
1215 return Intrinsic::maximum;
1216 }
1217}
1218
// Maps a min/max recurrence kind to the scalar intrinsic that combines two
// elements of the reduction.
Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case RecurKind::UMin:
    return Intrinsic::umin;
  case RecurKind::UMax:
    return Intrinsic::umax;
  case RecurKind::SMin:
    return Intrinsic::smin;
  case RecurKind::SMax:
    return Intrinsic::smax;
  // Both the plain FMin/FMax kinds and the explicit *Num kinds use the
  // minnum/maxnum intrinsics here.
  case RecurKind::FMin:
  case RecurKind::FMinNum:
    return Intrinsic::minnum;
  case RecurKind::FMax:
  case RecurKind::FMaxNum:
    return Intrinsic::maxnum;
  case RecurKind::FMinimum:
    return Intrinsic::minimum;
  case RecurKind::FMaximum:
    return Intrinsic::maximum;
  case RecurKind::FMinimumNum:
    return Intrinsic::minimumnum;
  case RecurKind::FMaximumNum:
    return Intrinsic::maximumnum;
  }
}
1247
1248RecurKind llvm::getMinMaxReductionRecurKind(Intrinsic::ID RdxID) {
1249 switch (RdxID) {
1250 case Intrinsic::vector_reduce_smax:
1251 return RecurKind::SMax;
1252 case Intrinsic::vector_reduce_smin:
1253 return RecurKind::SMin;
1254 case Intrinsic::vector_reduce_umax:
1255 return RecurKind::UMax;
1256 case Intrinsic::vector_reduce_umin:
1257 return RecurKind::UMin;
1258 case Intrinsic::vector_reduce_fmax:
1259 return RecurKind::FMax;
1260 case Intrinsic::vector_reduce_fmin:
1261 return RecurKind::FMin;
1262 default:
1263 return RecurKind::None;
1264 }
1265}
1266
// Returns the compare predicate whose compare+select pair implements the
// given min/max recurrence kind.
CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case RecurKind::UMin:
    return CmpInst::ICMP_ULT;
  case RecurKind::UMax:
    return CmpInst::ICMP_UGT;
  case RecurKind::SMin:
    return CmpInst::ICMP_SLT;
  case RecurKind::SMax:
    return CmpInst::ICMP_SGT;
  // Ordered FP predicates: NaN operands make the compare false, selecting
  // the right-hand value.
  case RecurKind::FMin:
    return CmpInst::FCMP_OLT;
  case RecurKind::FMax:
    return CmpInst::FCMP_OGT;
  // We do not add FMinimum/FMaximum recurrence kind here since there is no
  // equivalent predicate which compares signed zeroes according to the
  // semantics of the intrinsics (llvm.minimum/maximum).
  }
}
1288
1289Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
1290 Value *Right) {
1291 Type *Ty = Left->getType();
1292 if (Ty->isIntOrIntVectorTy() ||
1293 (RK == RecurKind::FMinNum || RK == RecurKind::FMaxNum ||
1294 RK == RecurKind::FMinimum || RK == RecurKind::FMaximum ||
1295 RK == RecurKind::FMinimumNum || RK == RecurKind::FMaximumNum)) {
1296 Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RK);
1297 return Builder.CreateIntrinsic(RetTy: Ty, ID: Id, Args: {Left, Right}, FMFSource: nullptr,
1298 Name: "rdx.minmax");
1299 }
1300 CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
1301 Value *Cmp = Builder.CreateCmp(Pred, LHS: Left, RHS: Right, Name: "rdx.minmax.cmp");
1302 Value *Select = Builder.CreateSelect(C: Cmp, True: Left, False: Right, Name: "rdx.minmax.select");
1303 return Select;
1304}
1305
1306// Helper to generate an ordered reduction.
1307Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
1308 unsigned Op, RecurKind RdxKind) {
1309 unsigned VF = cast<FixedVectorType>(Val: Src->getType())->getNumElements();
1310
1311 // Extract and apply reduction ops in ascending order:
1312 // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ) ... + Scl[VF-1]
1313 Value *Result = Acc;
1314 for (unsigned ExtractIdx = 0; ExtractIdx != VF; ++ExtractIdx) {
1315 Value *Ext =
1316 Builder.CreateExtractElement(Vec: Src, Idx: Builder.getInt32(C: ExtractIdx));
1317
1318 if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
1319 Result = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op, LHS: Result, RHS: Ext,
1320 Name: "bin.rdx");
1321 } else {
1322 assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
1323 "Invalid min/max");
1324 Result = createMinMaxOp(Builder, RK: RdxKind, Left: Result, Right: Ext);
1325 }
1326 }
1327
1328 return Result;
1329}
1330
1331// Helper to generate a log2 shuffle reduction.
1332Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
1333 unsigned Op,
1334 TargetTransformInfo::ReductionShuffle RS,
1335 RecurKind RdxKind) {
1336 unsigned VF = cast<FixedVectorType>(Val: Src->getType())->getNumElements();
1337 // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
1338 // and vector ops, reducing the set of values being computed by half each
1339 // round.
1340 assert(isPowerOf2_32(VF) &&
1341 "Reduction emission only supported for pow2 vectors!");
1342 // Note: fast-math-flags flags are controlled by the builder configuration
1343 // and are assumed to apply to all generated arithmetic instructions. Other
1344 // poison generating flags (nsw/nuw/inbounds/inrange/exact) are not part
1345 // of the builder configuration, and since they're not passed explicitly,
1346 // will never be relevant here. Note that it would be generally unsound to
1347 // propagate these from an intrinsic call to the expansion anyways as we/
1348 // change the order of operations.
1349 auto BuildShuffledOp = [&Builder, &Op,
1350 &RdxKind](SmallVectorImpl<int> &ShuffleMask,
1351 Value *&TmpVec) -> void {
1352 Value *Shuf = Builder.CreateShuffleVector(V: TmpVec, Mask: ShuffleMask, Name: "rdx.shuf");
1353 if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
1354 TmpVec = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op, LHS: TmpVec, RHS: Shuf,
1355 Name: "bin.rdx");
1356 } else {
1357 assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
1358 "Invalid min/max");
1359 TmpVec = createMinMaxOp(Builder, RK: RdxKind, Left: TmpVec, Right: Shuf);
1360 }
1361 };
1362
1363 Value *TmpVec = Src;
1364 if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) {
1365 SmallVector<int, 32> ShuffleMask(VF);
1366 for (unsigned stride = 1; stride < VF; stride <<= 1) {
1367 // Initialise the mask with undef.
1368 llvm::fill(Range&: ShuffleMask, Value: -1);
1369 for (unsigned j = 0; j < VF; j += stride << 1) {
1370 ShuffleMask[j] = j + stride;
1371 }
1372 BuildShuffledOp(ShuffleMask, TmpVec);
1373 }
1374 } else {
1375 SmallVector<int, 32> ShuffleMask(VF);
1376 for (unsigned i = VF; i != 1; i >>= 1) {
1377 // Move the upper half of the vector to the lower half.
1378 for (unsigned j = 0; j != i / 2; ++j)
1379 ShuffleMask[j] = i / 2 + j;
1380
1381 // Fill the rest of the mask with undef.
1382 std::fill(first: &ShuffleMask[i / 2], last: ShuffleMask.end(), value: -1);
1383 BuildShuffledOp(ShuffleMask, TmpVec);
1384 }
1385 }
1386 // The result is in the first element of the vector.
1387 return Builder.CreateExtractElement(Vec: TmpVec, Idx: Builder.getInt32(C: 0));
1388}
1389
1390Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
1391 Value *InitVal, PHINode *OrigPhi) {
1392 Value *NewVal = nullptr;
1393
1394 // First use the original phi to determine the new value we're trying to
1395 // select from in the loop.
1396 SelectInst *SI = nullptr;
1397 for (auto *U : OrigPhi->users()) {
1398 if ((SI = dyn_cast<SelectInst>(Val: U)))
1399 break;
1400 }
1401 assert(SI && "One user of the original phi should be a select");
1402
1403 if (SI->getTrueValue() == OrigPhi)
1404 NewVal = SI->getFalseValue();
1405 else {
1406 assert(SI->getFalseValue() == OrigPhi &&
1407 "At least one input to the select should be the original Phi");
1408 NewVal = SI->getTrueValue();
1409 }
1410
1411 // If any predicate is true it means that we want to select the new value.
1412 Value *AnyOf =
1413 Src->getType()->isVectorTy() ? Builder.CreateOrReduce(Src) : Src;
1414 // The compares in the loop may yield poison, which propagates through the
1415 // bitwise ORs. Freeze it here before the condition is used.
1416 AnyOf = Builder.CreateFreeze(V: AnyOf);
1417 return Builder.CreateSelect(C: AnyOf, True: NewVal, False: InitVal, Name: "rdx.select");
1418}
1419
1420Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
1421 FastMathFlags Flags) {
1422 bool Negative = false;
1423 switch (RdxID) {
1424 default:
1425 llvm_unreachable("Expecting a reduction intrinsic");
1426 case Intrinsic::vector_reduce_add:
1427 case Intrinsic::vector_reduce_mul:
1428 case Intrinsic::vector_reduce_or:
1429 case Intrinsic::vector_reduce_xor:
1430 case Intrinsic::vector_reduce_and:
1431 case Intrinsic::vector_reduce_fadd:
1432 case Intrinsic::vector_reduce_fmul: {
1433 unsigned Opc = getArithmeticReductionInstruction(RdxID);
1434 return ConstantExpr::getBinOpIdentity(Opcode: Opc, Ty, AllowRHSConstant: false,
1435 NSZ: Flags.noSignedZeros());
1436 }
1437 case Intrinsic::vector_reduce_umax:
1438 case Intrinsic::vector_reduce_umin:
1439 case Intrinsic::vector_reduce_smin:
1440 case Intrinsic::vector_reduce_smax: {
1441 Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID);
1442 return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty);
1443 }
1444 case Intrinsic::vector_reduce_fmax:
1445 case Intrinsic::vector_reduce_fmaximum:
1446 Negative = true;
1447 [[fallthrough]];
1448 case Intrinsic::vector_reduce_fmin:
1449 case Intrinsic::vector_reduce_fminimum: {
1450 bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum ||
1451 RdxID == Intrinsic::vector_reduce_fmaximum;
1452 const fltSemantics &Semantics = Ty->getFltSemantics();
1453 return (!Flags.noNaNs() && !PropagatesNaN)
1454 ? ConstantFP::getQNaN(Ty, Negative)
1455 : !Flags.noInfs()
1456 ? ConstantFP::getInfinity(Ty, Negative)
1457 : ConstantFP::get(Ty, V: APFloat::getLargest(Sem: Semantics, Negative));
1458 }
1459 }
1460}
1461
1462Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) {
1463 assert((!(K == RecurKind::FMin || K == RecurKind::FMax) ||
1464 (FMF.noNaNs() && FMF.noSignedZeros())) &&
1465 "nnan, nsz is expected to be set for FP min/max reduction.");
1466 Intrinsic::ID RdxID = getReductionIntrinsicID(RK: K);
1467 return getReductionIdentity(RdxID, Ty: Tp, Flags: FMF);
1468}
1469
// Emits an unordered (fast) reduction of \p Src according to \p RdxKind.
// Integer and min/max kinds use the unary vector.reduce.* intrinsics;
// fadd/fmul kinds pass the recurrence identity as the start value.
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
                                   RecurKind RdxKind) {
  auto *SrcVecEltTy = cast<VectorType>(Val: Src->getType())->getElementType();
  // Lazily computes the identity element for the FP reductions below.
  auto getIdentity = [&]() {
    return getRecurrenceIdentity(K: RdxKind, Tp: SrcVecEltTy,
                                 FMF: Builder.getFastMathFlags());
  };
  switch (RdxKind) {
  case RecurKind::AddChainWithSubs:
  case RecurKind::Sub:
  case RecurKind::Add:
  case RecurKind::Mul:
  case RecurKind::And:
  case RecurKind::Or:
  case RecurKind::Xor:
  case RecurKind::SMax:
  case RecurKind::SMin:
  case RecurKind::UMax:
  case RecurKind::UMin:
  case RecurKind::FMax:
  case RecurKind::FMin:
  case RecurKind::FMinNum:
  case RecurKind::FMaxNum:
  case RecurKind::FMinimum:
  case RecurKind::FMaximum:
  case RecurKind::FMinimumNum:
  case RecurKind::FMaximumNum:
    return Builder.CreateUnaryIntrinsic(ID: getReductionIntrinsicID(RK: RdxKind), V: Src);
  case RecurKind::FMulAdd:
  case RecurKind::FAdd:
    return Builder.CreateFAddReduce(Acc: getIdentity(), Src);
  case RecurKind::FMul:
    return Builder.CreateFMulReduce(Acc: getIdentity(), Src);
  default:
    llvm_unreachable("Unhandled opcode");
  }
}
1507
1508Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
1509 RecurKind Kind, Value *Mask, Value *EVL) {
1510 assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
1511 !RecurrenceDescriptor::isFindRecurrenceKind(Kind) &&
1512 "AnyOf and FindIV reductions are not supported.");
1513 Intrinsic::ID Id = getReductionIntrinsicID(RK: Kind);
1514 auto VPID = VPIntrinsic::getForIntrinsic(Id);
1515 assert(VPReductionIntrinsic::isVPReduction(VPID) &&
1516 "No VPIntrinsic for this reduction");
1517 auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
1518 Value *Iden = getRecurrenceIdentity(K: Kind, Tp: EltTy, FMF: Builder.getFastMathFlags());
1519 Value *Ops[] = {Iden, Src, Mask, EVL};
1520 return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
1521}
1522
// Emits an ordered (strict, in-lane-order) FP add reduction of vector \p Src
// seeded with the scalar \p Start value.
Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
                                    Value *Src, Value *Start) {
  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
         "Unexpected reduction kind");
  assert(Src->getType()->isVectorTy() && "Expected a vector type");
  assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

  // FMulAdd reductions are also finalized with an fadd reduction here.
  return B.CreateFAddReduce(Acc: Start, Src);
}
1532
// Vector-predicated variant of the ordered FP add reduction: reduces \p Src
// under \p Mask and explicit vector length \p EVL, seeded with \p Start.
Value *llvm::createOrderedReduction(IRBuilderBase &Builder, RecurKind Kind,
                                    Value *Src, Value *Start, Value *Mask,
                                    Value *EVL) {
  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
         "Unexpected reduction kind");
  assert(Src->getType()->isVectorTy() && "Expected a vector type");
  assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

  // Both FAdd and FMulAdd lower to the VP fadd reduction.
  Intrinsic::ID Id = getReductionIntrinsicID(RK: RecurKind::FAdd);
  auto VPID = VPIntrinsic::getForIntrinsic(Id);
  assert(VPReductionIntrinsic::isVPReduction(VPID) &&
         "No VPIntrinsic for this reduction");
  auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
  Value *Ops[] = {Start, Src, Mask, EVL};
  return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
}
1549
1550void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
1551 bool IncludeWrapFlags) {
1552 auto *VecOp = dyn_cast<Instruction>(Val: I);
1553 if (!VecOp)
1554 return;
1555 auto *Intersection = (OpValue == nullptr) ? dyn_cast<Instruction>(Val: VL[0])
1556 : dyn_cast<Instruction>(Val: OpValue);
1557 if (!Intersection)
1558 return;
1559 const unsigned Opcode = Intersection->getOpcode();
1560 VecOp->copyIRFlags(V: Intersection, IncludeWrapFlags);
1561 for (auto *V : VL) {
1562 auto *Instr = dyn_cast<Instruction>(Val: V);
1563 if (!Instr)
1564 continue;
1565 if (OpValue == nullptr || Opcode == Instr->getOpcode())
1566 VecOp->andIRFlags(V);
1567 }
1568}
1569
1570bool llvm::isKnownNegativeInLoop(const SCEV *S, const Loop *L,
1571 ScalarEvolution &SE) {
1572 const SCEV *Zero = SE.getZero(Ty: S->getType());
1573 return SE.isAvailableAtLoopEntry(S, L) &&
1574 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLT, LHS: S, RHS: Zero);
1575}
1576
1577bool llvm::isKnownNonNegativeInLoop(const SCEV *S, const Loop *L,
1578 ScalarEvolution &SE) {
1579 const SCEV *Zero = SE.getZero(Ty: S->getType());
1580 return SE.isAvailableAtLoopEntry(S, L) &&
1581 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGE, LHS: S, RHS: Zero);
1582}
1583
1584bool llvm::isKnownPositiveInLoop(const SCEV *S, const Loop *L,
1585 ScalarEvolution &SE) {
1586 const SCEV *Zero = SE.getZero(Ty: S->getType());
1587 return SE.isAvailableAtLoopEntry(S, L) &&
1588 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGT, LHS: S, RHS: Zero);
1589}
1590
1591bool llvm::isKnownNonPositiveInLoop(const SCEV *S, const Loop *L,
1592 ScalarEvolution &SE) {
1593 const SCEV *Zero = SE.getZero(Ty: S->getType());
1594 return SE.isAvailableAtLoopEntry(S, L) &&
1595 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLE, LHS: S, RHS: Zero);
1596}
1597
1598bool llvm::cannotBeMinInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1599 bool Signed) {
1600 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1601 APInt Min = Signed ? APInt::getSignedMinValue(numBits: BitWidth) :
1602 APInt::getMinValue(numBits: BitWidth);
1603 auto Predicate = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
1604 return SE.isAvailableAtLoopEntry(S, L) &&
1605 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1606 RHS: SE.getConstant(Val: Min));
1607}
1608
1609bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1610 bool Signed) {
1611 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1612 APInt Max = Signed ? APInt::getSignedMaxValue(numBits: BitWidth) :
1613 APInt::getMaxValue(numBits: BitWidth);
1614 auto Predicate = Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
1615 return SE.isAvailableAtLoopEntry(S, L) &&
1616 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1617 RHS: SE.getConstant(Val: Max));
1618}
1619
1620//===----------------------------------------------------------------------===//
1621// rewriteLoopExitValues - Optimize IV users outside the loop.
1622// As a side effect, reduces the amount of IV processing within the loop.
1623//===----------------------------------------------------------------------===//
1624
1625static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
1626 SmallPtrSet<const Instruction *, 8> Visited;
1627 SmallVector<const Instruction *, 8> WorkList;
1628 Visited.insert(Ptr: I);
1629 WorkList.push_back(Elt: I);
1630 while (!WorkList.empty()) {
1631 const Instruction *Curr = WorkList.pop_back_val();
1632 // This use is outside the loop, nothing to do.
1633 if (!L->contains(Inst: Curr))
1634 continue;
1635 // Do we assume it is a "hard" use which will not be eliminated easily?
1636 if (Curr->mayHaveSideEffects())
1637 return true;
1638 // Otherwise, add all its users to worklist.
1639 for (const auto *U : Curr->users()) {
1640 auto *UI = cast<Instruction>(Val: U);
1641 if (Visited.insert(Ptr: UI).second)
1642 WorkList.push_back(Elt: UI);
1643 }
1644 }
1645 return false;
1646}
1647
// Collect information about PHI nodes which can be transformed in
// rewriteLoopExitValues. Each record describes one incoming value of one
// exit-block PHI, together with the SCEV expression it will be replaced by
// and where that expression should be expanded.
struct RewritePhi {
  PHINode *PN; // For which PHI node is this replacement?
  unsigned Ith; // For which incoming value?
  const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting.
  Instruction *ExpansionPoint; // Where we'd like to expand that SCEV?
  bool HighCost; // Is this expansion a high-cost?

  RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
             bool H)
      : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
        HighCost(H) {}
};
1662
// Check whether it is possible to delete the loop after rewriting exit
// value. If it is possible, ignore ReplaceExitValue and do rewriting
// aggressively.
//
// The loop is considered deletable when (a) it has a preheader and a single
// exiting/exit block, (b) every value flowing out through the exit PHIs is
// either scheduled to be rewritten (present in \p RewritePhiSet) or already
// computed from loop-invariant operands, and (c) no instruction in the loop
// body has side effects.
static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) {
  BasicBlock *Preheader = L->getLoopPreheader();
  // If there is no preheader, the loop will not be deleted.
  if (!Preheader)
    return false;

  // In LoopDeletion pass Loop can be deleted when ExitingBlocks.size() > 1.
  // We obviate multiple ExitingBlocks case for simplicity.
  // TODO: If we see testcase with multiple ExitingBlocks can be deleted
  // after exit value rewriting, we can enhance the logic here.
  SmallVector<BasicBlock *, 4> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);
  SmallVector<BasicBlock *, 8> ExitBlocks;
  L->getUniqueExitBlocks(ExitBlocks);
  if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1)
    return false;

  // Walk the PHIs of the single exit block and verify every outgoing value.
  BasicBlock *ExitBlock = ExitBlocks[0];
  BasicBlock::iterator BI = ExitBlock->begin();
  while (PHINode *P = dyn_cast<PHINode>(Val&: BI)) {
    Value *Incoming = P->getIncomingValueForBlock(BB: ExitingBlocks[0]);

    // If the Incoming value of P is found in RewritePhiSet, we know it
    // could be rewritten to use a loop invariant value in transformation
    // phase later. Skip it in the loop invariant check below.
    bool found = false;
    for (const RewritePhi &Phi : RewritePhiSet) {
      unsigned i = Phi.Ith;
      if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
        found = true;
        break;
      }
    }

    // Not scheduled for rewriting: the value must already be computable from
    // loop-invariant operands, otherwise deleting the loop would lose it.
    Instruction *I;
    if (!found && (I = dyn_cast<Instruction>(Val: Incoming)))
      if (!L->hasLoopInvariantOperands(I))
        return false;

    ++BI;
  }

  // Any side-effecting instruction in the body keeps the loop alive even if
  // all of its computed results are replaced.
  for (auto *BB : L->blocks())
    if (llvm::any_of(Range&: *BB, P: [](Instruction &I) {
          return I.mayHaveSideEffects();
        }))
      return false;

  return true;
}
1716
1717/// Checks if it is safe to call InductionDescriptor::isInductionPHI for \p Phi,
1718/// and returns true if this Phi is an induction phi in the loop. When
1719/// isInductionPHI returns true, \p ID will be also be set by isInductionPHI.
1720static bool checkIsIndPhi(PHINode *Phi, Loop *L, ScalarEvolution *SE,
1721 InductionDescriptor &ID) {
1722 if (!Phi)
1723 return false;
1724 if (!L->getLoopPreheader())
1725 return false;
1726 if (Phi->getParent() != L->getHeader())
1727 return false;
1728 return InductionDescriptor::isInductionPHI(Phi, L, SE, D&: ID);
1729}
1730
/// Rewrite uses of loop-varying values outside \p L in terms of their final
/// ("exit") values when ScalarEvolution can compute those values and the
/// expansion is deemed worthwhile according to \p ReplaceExitValue. Because
/// the loop is in LCSSA form, all such uses are incoming values of PHIs in
/// the loop's exit blocks. Returns the number of incoming values rewritten;
/// instructions rendered trivially dead are appended to \p DeadInsts for the
/// caller to delete.
int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
                                ScalarEvolution *SE,
                                const TargetTransformInfo *TTI,
                                SCEVExpander &Rewriter, DominatorTree *DT,
                                ReplaceExitVal ReplaceExitValue,
                                SmallVector<WeakTrackingVH, 16> &DeadInsts) {
  // Check a pre-condition.
  assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
         "Indvars did not preserve LCSSA!");

  SmallVector<BasicBlock*, 8> ExitBlocks;
  L->getUniqueExitBlocks(ExitBlocks);

  SmallVector<RewritePhi, 8> RewritePhiSet;
  // Find all values that are computed inside the loop, but used outside of it.
  // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan
  // the exit blocks of the loop to find them.
  for (BasicBlock *ExitBB : ExitBlocks) {
    // If there are no PHI nodes in this exit block, then no values defined
    // inside the loop are used on this path, skip it.
    PHINode *PN = dyn_cast<PHINode>(Val: ExitBB->begin());
    if (!PN) continue;

    // Every PHI in a block has one incoming value per predecessor, so the
    // count taken from the first PHI applies to all PHIs scanned below.
    unsigned NumPreds = PN->getNumIncomingValues();

    // Iterate over all of the PHI nodes.
    BasicBlock::iterator BBI = ExitBB->begin();
    while ((PN = dyn_cast<PHINode>(Val: BBI++))) {
      if (PN->use_empty())
        continue; // dead use, don't replace it

      if (!SE->isSCEVable(Ty: PN->getType()))
        continue;

      // Iterate over all of the values in all the PHI nodes.
      for (unsigned i = 0; i != NumPreds; ++i) {
        // If the value being merged in is not integer or is not defined
        // in the loop, skip it.
        Value *InVal = PN->getIncomingValue(i);
        if (!isa<Instruction>(Val: InVal))
          continue;

        // If this pred is for a subloop, not L itself, skip it.
        if (LI->getLoopFor(BB: PN->getIncomingBlock(i)) != L)
          continue; // The Block is in a subloop, skip it.

        // Check that InVal is defined in the loop.
        Instruction *Inst = cast<Instruction>(Val: InVal);
        if (!L->contains(Inst))
          continue;

        // Find exit values which are induction variables in the loop, and are
        // unused in the loop, with the only use being the exit block PhiNode,
        // and the induction variable update binary operator.
        // The exit value can be replaced with the final value when it is cheap
        // to do so.
        if (ReplaceExitValue == UnusedIndVarInLoop) {
          InductionDescriptor ID;
          PHINode *IndPhi = dyn_cast<PHINode>(Val: Inst);
          if (IndPhi) {
            if (!checkIsIndPhi(Phi: IndPhi, L, SE, ID))
              continue;
            // This is an induction PHI. Check that the only users are PHI
            // nodes, and induction variable update binary operators.
            if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
                  if (!isa<PHINode>(Val: U) && !isa<BinaryOperator>(Val: U))
                    return true;
                  BinaryOperator *B = dyn_cast<BinaryOperator>(Val: U);
                  if (B && B != ID.getInductionBinOp())
                    return true;
                  return false;
                }))
              continue;
          } else {
            // If it is not an induction phi, it must be an induction update
            // binary operator with an induction phi user.
            BinaryOperator *B = dyn_cast<BinaryOperator>(Val: Inst);
            if (!B)
              continue;
            if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
                  PHINode *Phi = dyn_cast<PHINode>(Val: U);
                  if (Phi != PN && !checkIsIndPhi(Phi, L, SE, ID))
                    return true;
                  return false;
                }))
              continue;
            if (B != ID.getInductionBinOp())
              continue;
          }
        }

        // Okay, this instruction has a user outside of the current loop
        // and varies predictably *inside* the loop. Evaluate the value it
        // contains when the loop exits, if possible. We prefer to start with
        // expressions which are true for all exits (so as to maximize
        // expression reuse by the SCEVExpander), but resort to per-exit
        // evaluation if that fails.
        const SCEV *ExitValue = SE->getSCEVAtScope(V: Inst, L: L->getParentLoop());
        if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
            !SE->isLoopInvariant(S: ExitValue, L) ||
            !Rewriter.isSafeToExpand(S: ExitValue)) {
          // TODO: This should probably be sunk into SCEV in some way; maybe a
          // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for
          // most SCEV expressions and other recurrence types (e.g. shift
          // recurrences). Is there existing code we can reuse?
          const SCEV *ExitCount = SE->getExitCount(L, ExitingBlock: PN->getIncomingBlock(i));
          if (isa<SCEVCouldNotCompute>(Val: ExitCount))
            continue;
          if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: Inst)))
            if (AddRec->getLoop() == L)
              ExitValue = AddRec->evaluateAtIteration(It: ExitCount, SE&: *SE);
          if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
              !SE->isLoopInvariant(S: ExitValue, L) ||
              !Rewriter.isSafeToExpand(S: ExitValue))
            continue;
        }

        // Computing the value outside of the loop brings no benefit if it is
        // definitely used inside the loop in a way which can not be optimized
        // away. Avoid doing so unless we know we have a value which computes
        // the ExitValue already. TODO: This should be merged into SCEV
        // expander to leverage its knowledge of existing expressions.
        if (ReplaceExitValue != AlwaysRepl && !isa<SCEVConstant>(Val: ExitValue) &&
            !isa<SCEVUnknown>(Val: ExitValue) && hasHardUserWithinLoop(L, I: Inst))
          continue;

        // Check if expansions of this SCEV would count as being high cost.
        bool HighCost = Rewriter.isHighCostExpansion(
            Exprs: ExitValue, L, Budget: SCEVCheapExpansionBudget, TTI, At: Inst);

        // Note that we must not perform expansions until after
        // we query *all* the costs, because if we perform temporary expansion
        // inbetween, one that we might not intend to keep, said expansion
        // *may* affect cost calculation of the next SCEV's we'll query,
        // and next SCEV may errneously get smaller cost.

        // Collect all the candidate PHINodes to be rewritten.
        // PHIs and landing pads must stay at the top of their block, so the
        // expansion point is moved past them in those cases.
        Instruction *InsertPt =
          (isa<PHINode>(Val: Inst) || isa<LandingPadInst>(Val: Inst)) ?
          &*Inst->getParent()->getFirstInsertionPt() : Inst;
        RewritePhiSet.emplace_back(Args&: PN, Args&: i, Args&: ExitValue, Args&: InsertPt, Args&: HighCost);
      }
    }
  }

  // TODO: evaluate whether it is beneficial to change how we calculate
  // high-cost: if we have SCEV 'A' which we know we will expand, should we
  // calculate the cost of other SCEV's after expanding SCEV 'A', thus
  // potentially giving cost bonus to those other SCEV's?

  bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
  int NumReplaced = 0;

  // Transformation.
  for (const RewritePhi &Phi : RewritePhiSet) {
    PHINode *PN = Phi.PN;

    // Only do the rewrite when the ExitValue can be expanded cheaply.
    // If LoopCanBeDel is true, rewrite exit value aggressively.
    if ((ReplaceExitValue == OnlyCheapRepl ||
         ReplaceExitValue == UnusedIndVarInLoop) &&
        !LoopCanBeDel && Phi.HighCost)
      continue;

    Value *ExitVal = Rewriter.expandCodeFor(
        SH: Phi.ExpansionSCEV, Ty: Phi.PN->getType(), I: Phi.ExpansionPoint);

    LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " << *ExitVal
                      << '\n'
                      << "  LoopVal = " << *(Phi.ExpansionPoint) << "\n");

#ifndef NDEBUG
    // If we reuse an instruction from a loop which is neither L nor one of
    // its containing loops, we end up breaking LCSSA form for this loop by
    // creating a new use of its instruction.
    if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
      if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
        if (EVL != L)
          assert(EVL->contains(L) && "LCSSA breach detected!");
#endif

    NumReplaced++;
    Instruction *Inst = cast<Instruction>(Val: PN->getIncomingValue(i: Phi.Ith));
    PN->setIncomingValue(i: Phi.Ith, V: ExitVal);
    // It's necessary to tell ScalarEvolution about this explicitly so that
    // it can walk the def-use list and forget all SCEVs, as it may not be
    // watching the PHI itself. Once the new exit value is in place, there
    // may not be a def-use connection between the loop and every instruction
    // which got a SCEVAddRecExpr for that loop.
    SE->forgetValue(V: PN);

    // If this instruction is dead now, delete it. Don't do it now to avoid
    // invalidating iterators.
    if (isInstructionTriviallyDead(I: Inst, TLI))
      DeadInsts.push_back(Elt: Inst);

    // Replace PN with ExitVal if that is legal and does not break LCSSA.
    if (PN->getNumIncomingValues() == 1 &&
        LI->replacementPreservesLCSSAForm(From: PN, To: ExitVal)) {
      PN->replaceAllUsesWith(V: ExitVal);
      PN->eraseFromParent();
    }
  }

  // The insertion point instruction may have been deleted; clear it out
  // so that the rewriter doesn't trip over it later.
  Rewriter.clearInsertPoint();
  return NumReplaced;
}
1940
/// Utility that implements appending of loops onto a worklist.
/// Loops are added in preorder (analogous for reverse postorder for trees),
/// and the worklist is processed LIFO.
template <typename RangeT>
void llvm::appendReversedLoopsToWorklist(
    RangeT &&Loops, SmallPriorityWorklist<Loop *, 4> &Worklist) {
  // We use an internal worklist to build up the preorder traversal without
  // recursion.
  SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist;

  // We walk the initial sequence of loops in reverse because we generally want
  // to visit defs before uses and the worklist is LIFO.
  for (Loop *RootL : Loops) {
    assert(PreOrderLoops.empty() && "Must start with an empty preorder walk.");
    assert(PreOrderWorklist.empty() &&
           "Must start with an empty preorder walk worklist.");
    // Iterative preorder DFS of RootL's loop nest: pop a loop, queue its
    // subloops, record the loop itself.
    PreOrderWorklist.push_back(Elt: RootL);
    do {
      Loop *L = PreOrderWorklist.pop_back_val();
      PreOrderWorklist.append(in_start: L->begin(), in_end: L->end());
      PreOrderLoops.push_back(Elt: L);
    } while (!PreOrderWorklist.empty());

    // Because the worklist is processed LIFO, inserting the preorder sequence
    // (outer loops first) means nested loops are popped before their parents.
    Worklist.insert(Input: std::move(PreOrderLoops));
    // The move leaves PreOrderLoops in an unspecified state; clear() makes it
    // safely reusable for the next root loop.
    PreOrderLoops.clear();
  }
}
1968
1969template <typename RangeT>
1970void llvm::appendLoopsToWorklist(RangeT &&Loops,
1971 SmallPriorityWorklist<Loop *, 4> &Worklist) {
1972 appendReversedLoopsToWorklist(reverse(Loops), Worklist);
1973}
1974
// Explicit instantiations of the worklist-append template for the range
// types required by users outside this translation unit.
template LLVM_EXPORT_TEMPLATE void
llvm::appendLoopsToWorklist<ArrayRef<Loop *> &>(
    ArrayRef<Loop *> &Loops, SmallPriorityWorklist<Loop *, 4> &Worklist);

template LLVM_EXPORT_TEMPLATE void
llvm::appendLoopsToWorklist<Loop &>(Loop &L,
                                    SmallPriorityWorklist<Loop *, 4> &Worklist);
1982
1983void llvm::appendLoopsToWorklist(LoopInfo &LI,
1984 SmallPriorityWorklist<Loop *, 4> &Worklist) {
1985 appendReversedLoopsToWorklist(Loops&: LI, Worklist);
1986}
1987
1988Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
1989 LoopInfo *LI, LPPassManager *LPM) {
1990 Loop &New = *LI->AllocateLoop();
1991 if (PL)
1992 PL->addChildLoop(NewChild: &New);
1993 else
1994 LI->addTopLevelLoop(New: &New);
1995
1996 if (LPM)
1997 LPM->addLoop(L&: New);
1998
1999 // Add all of the blocks in L to the new loop.
2000 for (BasicBlock *BB : L->blocks())
2001 if (LI->getLoopFor(BB) == L)
2002 New.addBasicBlockToLoop(NewBB: cast<BasicBlock>(Val&: VM[BB]), LI&: *LI);
2003
2004 // Add all of the subloops to the new loop.
2005 for (Loop *I : *L)
2006 cloneLoop(L: I, PL: &New, VM, LI, LPM);
2007
2008 return &New;
2009}
2010
/// IR Values for the lower and upper bounds of a pointer evolution. We
/// need to use value-handles because SCEV expansion can invalidate previously
/// expanded values. Thus expansion of a pointer can invalidate the bounds for
/// a previous one.
struct PointerBounds {
  TrackingVH<Value> Start; // First byte accessed.
  TrackingVH<Value> End;   // One past the last byte accessed.
  // If non-null, a stride value that must additionally be checked to be
  // non-negative at runtime for the expanded bounds to be valid.
  Value *StrideToCheck;
};
2020
/// Expand code for the lower and upper bound of the pointer group \p CG
/// in \p TheLoop, inserting the IR before \p Loc. When
/// \p HoistRuntimeChecks is set and the bounds vary in an outer loop, the
/// range may be widened so the emitted check can be hoisted out of that
/// outer loop. \return the values for the bounds (plus an optional stride
/// to verify at runtime).
static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
                                  Loop *TheLoop, Instruction *Loc,
                                  SCEVExpander &Exp, bool HoistRuntimeChecks) {
  LLVMContext &Ctx = Loc->getContext();
  Type *PtrArithTy = PointerType::get(C&: Ctx, AddressSpace: CG->AddressSpace);

  Value *Start = nullptr, *End = nullptr;
  LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
  const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr;

  // If the Low and High values are themselves loop-variant, then we may want
  // to expand the range to include those covered by the outer loop as well.
  // There is a trade-off here with the advantage being that creating checks
  // using the expanded range permits the runtime memory checks to be hoisted
  // out of the outer loop. This reduces the cost of entering the inner loop,
  // which can be significant for low trip counts. The disadvantage is that
  // there is a chance we may now never enter the vectorized inner loop,
  // whereas using a restricted range check could have allowed us to enter at
  // least once. This is why the behaviour is not currently the default and is
  // controlled by the parameter 'HoistRuntimeChecks'.
  if (HoistRuntimeChecks && TheLoop->getParentLoop() &&
      isa<SCEVAddRecExpr>(Val: High) && isa<SCEVAddRecExpr>(Val: Low)) {
    auto *HighAR = cast<SCEVAddRecExpr>(Val: High);
    auto *LowAR = cast<SCEVAddRecExpr>(Val: Low);
    const Loop *OuterLoop = TheLoop->getParentLoop();
    ScalarEvolution &SE = *Exp.getSE();
    const SCEV *Recur = LowAR->getStepRecurrence(SE);
    // Widening is only done when both bounds advance by the same step in the
    // immediate outer loop and its trip count is a computable integer.
    if (Recur == HighAR->getStepRecurrence(SE) &&
        HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
      BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
      const SCEV *OuterExitCount = SE.getExitCount(L: OuterLoop, ExitingBlock: OuterLoopLatch);
      if (!isa<SCEVCouldNotCompute>(Val: OuterExitCount) &&
          OuterExitCount->getType()->isIntegerTy()) {
        const SCEV *NewHigh =
            cast<SCEVAddRecExpr>(Val: High)->evaluateAtIteration(It: OuterExitCount, SE);
        if (!isa<SCEVCouldNotCompute>(Val: NewHigh)) {
          LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
                               "outer loop in order to permit hoisting\n");
          High = NewHigh;
          Low = cast<SCEVAddRecExpr>(Val: Low)->getStart();
          // If there is a possibility that the stride is negative then we have
          // to generate extra checks to ensure the stride is positive.
          if (!SE.isKnownNonNegative(
                  S: SE.applyLoopGuards(Expr: Recur, L: HighAR->getLoop()))) {
            Stride = Recur;
            LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
                                 "positive: "
                              << *Stride << '\n');
          }
        }
      }
    }
  }

  Start = Exp.expandCodeFor(SH: Low, Ty: PtrArithTy, I: Loc);
  End = Exp.expandCodeFor(SH: High, Ty: PtrArithTy, I: Loc);
  if (CG->NeedsFreeze) {
    // Freeze the expanded bounds so the runtime comparisons operate on fixed
    // values rather than potentially poison/undef ones.
    IRBuilder<> Builder(Loc);
    Start = Builder.CreateFreeze(V: Start, Name: Start->getName() + ".fr");
    End = Builder.CreateFreeze(V: End, Name: End->getName() + ".fr");
  }
  Value *StrideVal =
      Stride ? Exp.expandCodeFor(SH: Stride, Ty: Stride->getType(), I: Loc) : nullptr;
  LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n");
  return {.Start: Start, .End: End, .StrideToCheck: StrideVal};
}
2089
2090/// Turns a collection of checks into a collection of expanded upper and
2091/// lower bounds for both pointers in the check.
2092static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
2093expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
2094 Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) {
2095 SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
2096
2097 // Here we're relying on the SCEV Expander's cache to only emit code for the
2098 // same bounds once.
2099 transform(Range: PointerChecks, d_first: std::back_inserter(x&: ChecksWithBounds),
2100 F: [&](const RuntimePointerCheck &Check) {
2101 PointerBounds First = expandBounds(CG: Check.first, TheLoop: L, Loc, Exp,
2102 HoistRuntimeChecks),
2103 Second = expandBounds(CG: Check.second, TheLoop: L, Loc, Exp,
2104 HoistRuntimeChecks);
2105 return std::make_pair(x&: First, y&: Second);
2106 });
2107
2108 return ChecksWithBounds;
2109}
2110
/// Emit IR before \p Loc computing an i1 that is true when any pair of
/// pointer groups in \p PointerChecks may conflict, i.e. their accessed byte
/// ranges overlap (with extra stride-sign checks when bounds were widened
/// for hoisting). Returns nullptr when no check was emitted.
Value *llvm::addRuntimeChecks(
    Instruction *Loc, Loop *TheLoop,
    const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
    SCEVExpander &Exp, bool HoistRuntimeChecks) {
  // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
  // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
  auto ExpandedChecks =
      expandBounds(PointerChecks, L: TheLoop, Loc, Exp, HoistRuntimeChecks);

  LLVMContext &Ctx = Loc->getContext();
  IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
  ChkBuilder.SetInsertPoint(Loc);
  // Our instructions might fold to a constant.
  Value *MemoryRuntimeCheck = nullptr;

  for (const auto &[A, B] : ExpandedChecks) {
    // Check if two pointers (A and B) conflict where conflict is computed as:
    // start(A) <= end(B) && start(B) <= end(A)

    assert((A.Start->getType()->getPointerAddressSpace() ==
            B.End->getType()->getPointerAddressSpace()) &&
           (B.Start->getType()->getPointerAddressSpace() ==
            A.End->getType()->getPointerAddressSpace()) &&
           "Trying to bounds check pointers with different address spaces");

    // [A|B].Start points to the first accessed byte under base [A|B].
    // [A|B].End points to the last accessed byte, plus one.
    // There is no conflict when the intervals are disjoint:
    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
    //
    // bound0 = (B.Start < A.End)
    // bound1 = (A.Start < B.End)
    //  IsConflict = bound0 & bound1
    Value *Cmp0 = ChkBuilder.CreateICmpULT(LHS: A.Start, RHS: B.End, Name: "bound0");
    Value *Cmp1 = ChkBuilder.CreateICmpULT(LHS: B.Start, RHS: A.End, Name: "bound1");
    Value *IsConflict = ChkBuilder.CreateAnd(LHS: Cmp0, RHS: Cmp1, Name: "found.conflict");
    // If either group carries a stride to verify (set when the bounds were
    // widened to cover an outer loop), a negative stride also counts as a
    // conflict because the widened range would then be invalid.
    if (A.StrideToCheck) {
      Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
          LHS: A.StrideToCheck, RHS: ConstantInt::get(Ty: A.StrideToCheck->getType(), V: 0),
          Name: "stride.check");
      IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
    }
    if (B.StrideToCheck) {
      Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
          LHS: B.StrideToCheck, RHS: ConstantInt::get(Ty: B.StrideToCheck->getType(), V: 0),
          Name: "stride.check");
      IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
    }
    // OR this pair's result into the running reduction over all pairs.
    if (MemoryRuntimeCheck) {
      IsConflict =
          ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
    }
    MemoryRuntimeCheck = IsConflict;
  }

  Exp.eraseDeadInstructions(Root: MemoryRuntimeCheck);
  return MemoryRuntimeCheck;
}
2169
namespace {
/// Rewriter to replace SCEVPtrToIntExpr with SCEVPtrToAddrExpr when the result
/// type matches the pointer address type. This allows expressions mixing
/// ptrtoint and ptrtoaddr to simplify properly.
struct SCEVPtrToAddrRewriter : SCEVRewriteVisitor<SCEVPtrToAddrRewriter> {
  const DataLayout &DL; // Used to query the address type of pointer operands.
  SCEVPtrToAddrRewriter(ScalarEvolution &SE, const DataLayout &DL)
      : SCEVRewriteVisitor(SE), DL(DL) {}

  const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *E) {
    // Rewrite the operand first, then decide how to wrap it.
    const SCEV *Op = visit(S: E->getOperand());
    if (E->getType() == DL.getAddressType(PtrTy: E->getOperand()->getType()))
      return SE.getPtrToAddrExpr(Op);
    // Types differ: keep ptrtoint, reusing E itself if nothing changed.
    return Op == E->getOperand() ? E : SE.getPtrToIntExpr(Op, Ty: E->getType());
  }
};
} // namespace
2187
/// Emit IR before \p Loc computing an i1 conflict indicator for the
/// pointer-difference style runtime checks in \p Checks: for each
/// (SrcStart, SinkStart) pair, the distance Sink - Src is compared (unsigned)
/// against VF * IC * AccessSize; a smaller distance means the accesses may
/// overlap. \p GetVF produces the (possibly scalable) VF value for a given
/// element bit width. Returns nullptr if no check was emitted.
Value *llvm::addDiffRuntimeChecks(
    Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
    function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) {

  LLVMContext &Ctx = Loc->getContext();
  IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
  ChkBuilder.SetInsertPoint(Loc);
  // Our instructions might fold to a constant.
  Value *MemoryRuntimeCheck = nullptr;

  auto &SE = *Expander.getSE();
  const DataLayout &DL = Loc->getDataLayout();
  // Normalize ptrtoint to ptrtoaddr (where types allow) so that the
  // Sink - Src subtraction below simplifies.
  SCEVPtrToAddrRewriter Rewriter(SE, DL);
  // Map to keep track of created compares, The key is the pair of operands for
  // the compare, to allow detecting and re-using redundant compares.
  DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
  for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
    Type *Ty = SinkStart->getType();
    // Compute VF * IC * AccessSize.
    auto *VFTimesICTimesSize =
        ChkBuilder.CreateMul(LHS: GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
                             RHS: ConstantInt::get(Ty, V: IC * AccessSize));
    const SCEV *SinkStartRewritten = Rewriter.visit(S: SinkStart);
    const SCEV *SrcStartRewritten = Rewriter.visit(S: SrcStart);
    Value *Diff = Expander.expandCodeFor(
        SH: SE.getMinusSCEV(LHS: SinkStartRewritten, RHS: SrcStartRewritten), Ty, I: Loc);

    // Check if the same compare has already been created earlier. In that case,
    // there is no need to check it again.
    Value *IsConflict = SeenCompares.lookup(Val: {Diff, VFTimesICTimesSize});
    if (IsConflict)
      continue;

    IsConflict =
        ChkBuilder.CreateICmpULT(LHS: Diff, RHS: VFTimesICTimesSize, Name: "diff.check");
    SeenCompares.insert(KV: {{Diff, VFTimesICTimesSize}, IsConflict});
    if (NeedsFreeze)
      IsConflict =
          ChkBuilder.CreateFreeze(V: IsConflict, Name: IsConflict->getName() + ".fr");
    // OR this check into the running reduction over all checks.
    if (MemoryRuntimeCheck) {
      IsConflict =
          ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
    }
    MemoryRuntimeCheck = IsConflict;
  }

  Expander.eraseDeadInstructions(Root: MemoryRuntimeCheck);
  return MemoryRuntimeCheck;
}
2237
2238std::optional<IVConditionInfo>
2239llvm::hasPartialIVCondition(const Loop &L, unsigned MSSAThreshold,
2240 const MemorySSA &MSSA, AAResults &AA) {
2241 auto *TI = dyn_cast<CondBrInst>(Val: L.getHeader()->getTerminator());
2242 if (!TI)
2243 return {};
2244
2245 auto *CondI = dyn_cast<Instruction>(Val: TI->getCondition());
2246 // The case with the condition outside the loop should already be handled
2247 // earlier.
2248 // Allow CmpInst and TruncInsts as they may be users of load instructions
2249 // and have potential for partial unswitching
2250 if (!CondI || !isa<CmpInst, TruncInst>(Val: CondI) || !L.contains(Inst: CondI))
2251 return {};
2252
2253 SmallVector<Instruction *> InstToDuplicate;
2254 InstToDuplicate.push_back(Elt: CondI);
2255
2256 SmallVector<Value *, 4> WorkList;
2257 WorkList.append(in_start: CondI->op_begin(), in_end: CondI->op_end());
2258
2259 SmallVector<MemoryAccess *, 4> AccessesToCheck;
2260 SmallVector<MemoryLocation, 4> AccessedLocs;
2261 while (!WorkList.empty()) {
2262 Instruction *I = dyn_cast<Instruction>(Val: WorkList.pop_back_val());
2263 if (!I || !L.contains(Inst: I))
2264 continue;
2265
2266 // TODO: support additional instructions.
2267 if (!isa<LoadInst>(Val: I) && !isa<GetElementPtrInst>(Val: I))
2268 return {};
2269
2270 // Do not duplicate volatile and atomic loads.
2271 if (auto *LI = dyn_cast<LoadInst>(Val: I))
2272 if (LI->isVolatile() || LI->isAtomic())
2273 return {};
2274
2275 InstToDuplicate.push_back(Elt: I);
2276 if (MemoryAccess *MA = MSSA.getMemoryAccess(I)) {
2277 if (auto *MemUse = dyn_cast_or_null<MemoryUse>(Val: MA)) {
2278 // Queue the defining access to check for alias checks.
2279 AccessesToCheck.push_back(Elt: MemUse->getDefiningAccess());
2280 AccessedLocs.push_back(Elt: MemoryLocation::get(Inst: I));
2281 } else {
2282 // MemoryDefs may clobber the location or may be atomic memory
2283 // operations. Bail out.
2284 return {};
2285 }
2286 }
2287 WorkList.append(in_start: I->op_begin(), in_end: I->op_end());
2288 }
2289
2290 if (InstToDuplicate.empty())
2291 return {};
2292
2293 SmallVector<BasicBlock *, 4> ExitingBlocks;
2294 L.getExitingBlocks(ExitingBlocks);
2295 auto HasNoClobbersOnPath =
2296 [&L, &AA, &AccessedLocs, &ExitingBlocks, &InstToDuplicate,
2297 MSSAThreshold](BasicBlock *Succ, BasicBlock *Header,
2298 SmallVector<MemoryAccess *, 4> AccessesToCheck)
2299 -> std::optional<IVConditionInfo> {
2300 IVConditionInfo Info;
2301 // First, collect all blocks in the loop that are on a patch from Succ
2302 // to the header.
2303 SmallVector<BasicBlock *, 4> WorkList;
2304 WorkList.push_back(Elt: Succ);
2305 WorkList.push_back(Elt: Header);
2306 SmallPtrSet<BasicBlock *, 4> Seen;
2307 Seen.insert(Ptr: Header);
2308 Info.PathIsNoop &=
2309 all_of(Range&: *Header, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2310
2311 while (!WorkList.empty()) {
2312 BasicBlock *Current = WorkList.pop_back_val();
2313 if (!L.contains(BB: Current))
2314 continue;
2315 const auto &SeenIns = Seen.insert(Ptr: Current);
2316 if (!SeenIns.second)
2317 continue;
2318
2319 Info.PathIsNoop &= all_of(
2320 Range&: *Current, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2321 WorkList.append(in_start: succ_begin(BB: Current), in_end: succ_end(BB: Current));
2322 }
2323
2324 // Require at least 2 blocks on a path through the loop. This skips
2325 // paths that directly exit the loop.
2326 if (Seen.size() < 2)
2327 return {};
2328
2329 // Next, check if there are any MemoryDefs that are on the path through
2330 // the loop (in the Seen set) and they may-alias any of the locations in
2331 // AccessedLocs. If that is the case, they may modify the condition and
2332 // partial unswitching is not possible.
2333 SmallPtrSet<MemoryAccess *, 4> SeenAccesses;
2334 while (!AccessesToCheck.empty()) {
2335 MemoryAccess *Current = AccessesToCheck.pop_back_val();
2336 auto SeenI = SeenAccesses.insert(Ptr: Current);
2337 if (!SeenI.second || !Seen.contains(Ptr: Current->getBlock()))
2338 continue;
2339
2340 // Bail out if exceeded the threshold.
2341 if (SeenAccesses.size() >= MSSAThreshold)
2342 return {};
2343
2344 // MemoryUse are read-only accesses.
2345 if (isa<MemoryUse>(Val: Current))
2346 continue;
2347
2348 // For a MemoryDef, check if is aliases any of the location feeding
2349 // the original condition.
2350 if (auto *CurrentDef = dyn_cast<MemoryDef>(Val: Current)) {
2351 if (any_of(Range&: AccessedLocs, P: [&AA, CurrentDef](MemoryLocation &Loc) {
2352 return isModSet(
2353 MRI: AA.getModRefInfo(I: CurrentDef->getMemoryInst(), OptLoc: Loc));
2354 }))
2355 return {};
2356 }
2357
2358 for (Use &U : Current->uses())
2359 AccessesToCheck.push_back(Elt: cast<MemoryAccess>(Val: U.getUser()));
2360 }
2361
2362 // We could also allow loops with known trip counts without mustprogress,
2363 // but ScalarEvolution may not be available.
2364 Info.PathIsNoop &= isMustProgress(L: &L);
2365
2366 // If the path is considered a no-op so far, check if it reaches a
2367 // single exit block without any phis. This ensures no values from the
2368 // loop are used outside of the loop.
2369 if (Info.PathIsNoop) {
2370 for (auto *Exiting : ExitingBlocks) {
2371 if (!Seen.contains(Ptr: Exiting))
2372 continue;
2373 for (auto *Succ : successors(BB: Exiting)) {
2374 if (L.contains(BB: Succ))
2375 continue;
2376
2377 Info.PathIsNoop &= Succ->phis().empty() &&
2378 (!Info.ExitForPath || Info.ExitForPath == Succ);
2379 if (!Info.PathIsNoop)
2380 break;
2381 assert((!Info.ExitForPath || Info.ExitForPath == Succ) &&
2382 "cannot have multiple exit blocks");
2383 Info.ExitForPath = Succ;
2384 }
2385 }
2386 }
2387 if (!Info.ExitForPath)
2388 Info.PathIsNoop = false;
2389
2390 Info.InstToDuplicate = std::move(InstToDuplicate);
2391 return Info;
2392 };
2393
2394 // If we branch to the same successor, partial unswitching will not be
2395 // beneficial.
2396 if (TI->getSuccessor(i: 0) == TI->getSuccessor(i: 1))
2397 return {};
2398
2399 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 0), L.getHeader(),
2400 AccessesToCheck)) {
2401 Info->KnownValue = ConstantInt::getTrue(Context&: TI->getContext());
2402 return Info;
2403 }
2404 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 1), L.getHeader(),
2405 AccessesToCheck)) {
2406 Info->KnownValue = ConstantInt::getFalse(Context&: TI->getContext());
2407 return Info;
2408 }
2409
2410 return {};
2411}
2412