1//===-- LoopUtils.cpp - Loop Utility functions -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines common loop utility functions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Utils/LoopUtils.h"
14#include "llvm/ADT/DenseSet.h"
15#include "llvm/ADT/PriorityWorklist.h"
16#include "llvm/ADT/ScopeExit.h"
17#include "llvm/ADT/SetVector.h"
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/Analysis/AliasAnalysis.h"
21#include "llvm/Analysis/BasicAliasAnalysis.h"
22#include "llvm/Analysis/DomTreeUpdater.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/InstSimplifyFolder.h"
25#include "llvm/Analysis/LoopAccessAnalysis.h"
26#include "llvm/Analysis/LoopInfo.h"
27#include "llvm/Analysis/LoopPass.h"
28#include "llvm/Analysis/MemorySSA.h"
29#include "llvm/Analysis/MemorySSAUpdater.h"
30#include "llvm/Analysis/ScalarEvolution.h"
31#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
32#include "llvm/Analysis/ScalarEvolutionExpressions.h"
33#include "llvm/IR/DIBuilder.h"
34#include "llvm/IR/Dominators.h"
35#include "llvm/IR/Instructions.h"
36#include "llvm/IR/IntrinsicInst.h"
37#include "llvm/IR/MDBuilder.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/PatternMatch.h"
40#include "llvm/IR/ProfDataUtils.h"
41#include "llvm/IR/ValueHandle.h"
42#include "llvm/InitializePasses.h"
43#include "llvm/Pass.h"
44#include "llvm/Support/Compiler.h"
45#include "llvm/Support/Debug.h"
46#include "llvm/Transforms/Utils/BasicBlockUtils.h"
47#include "llvm/Transforms/Utils/Local.h"
48#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
49
50using namespace llvm;
51using namespace llvm::PatternMatch;
52
53#define DEBUG_TYPE "loop-utils"
54
55static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced";
56static const char *LLVMLoopDisableLICM = "llvm.licm.disable";
57namespace llvm {
58extern cl::opt<bool> ProfcheckDisableMetadataFixes;
59} // namespace llvm
60
/// Ensure every exit block of \p L is "dedicated", i.e. has no predecessors
/// outside the loop. For each non-dedicated exit, the in-loop predecessor
/// edges are split into a fresh ".loopexit" block. DT/LI/MemorySSA (and
/// LCSSA form, if requested) are kept up to date by SplitBlockPredecessors.
/// Returns true if the CFG was changed.
bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI,
                                   MemorySSAUpdater *MSSAU,
                                   bool PreserveLCSSA) {
  bool Changed = false;

  // We re-use a vector for the in-loop predecessors.
  SmallVector<BasicBlock *, 4> InLoopPredecessors;

  // Rewrites one exit block; returns true if a dedicated exit was created
  // (or creation was attempted), false if nothing needed to change or the
  // rewrite is impossible.
  auto RewriteExit = [&](BasicBlock *BB) {
    assert(InLoopPredecessors.empty() &&
           "Must start with an empty predecessors list!");
    // Guarantee the shared scratch vector is cleared on every exit path.
    llvm::scope_exit Cleanup([&] { InLoopPredecessors.clear(); });

    // See if there are any non-loop predecessors of this exit block and
    // keep track of the in-loop predecessors.
    bool IsDedicatedExit = true;
    for (auto *PredBB : predecessors(BB))
      if (L->contains(BB: PredBB)) {
        if (isa<IndirectBrInst>(Val: PredBB->getTerminator()))
          // We cannot rewrite exiting edges from an indirectbr.
          return false;

        InLoopPredecessors.push_back(Elt: PredBB);
      } else {
        IsDedicatedExit = false;
      }

    assert(!InLoopPredecessors.empty() && "Must have *some* loop predecessor!");

    // Nothing to do if this is already a dedicated exit.
    if (IsDedicatedExit)
      return false;

    auto *NewExitBB = SplitBlockPredecessors(
        BB, Preds: InLoopPredecessors, Suffix: ".loopexit", DT, LI, MSSAU, PreserveLCSSA);

    if (!NewExitBB)
      LLVM_DEBUG(
          dbgs() << "WARNING: Can't create a dedicated exit block for loop: "
                 << *L << "\n");
    else
      LLVM_DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block "
                        << NewExitBB->getName() << "\n");
    return true;
  };

  // Walk the exit blocks directly rather than building up a data structure for
  // them, but only visit each one once.
  SmallPtrSet<BasicBlock *, 4> Visited;
  for (auto *BB : L->blocks())
    for (auto *SuccBB : successors(BB)) {
      // We're looking for exit blocks so skip in-loop successors.
      if (L->contains(BB: SuccBB))
        continue;

      // Visit each exit block exactly once.
      if (!Visited.insert(Ptr: SuccBB).second)
        continue;

      Changed |= RewriteExit(SuccBB);
    }

  return Changed;
}
125
126/// Returns the instructions that use values defined in the loop.
127SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) {
128 SmallVector<Instruction *, 8> UsedOutside;
129
130 for (auto *Block : L->getBlocks())
131 // FIXME: I believe that this could use copy_if if the Inst reference could
132 // be adapted into a pointer.
133 for (auto &Inst : *Block) {
134 auto Users = Inst.users();
135 if (any_of(Range&: Users, P: [&](User *U) {
136 auto *Use = cast<Instruction>(Val: U);
137 return !L->contains(BB: Use->getParent());
138 }))
139 UsedOutside.push_back(Elt: &Inst);
140 }
141
142 return UsedOutside;
143}
144
/// Populate \p AU with the analyses that every legacy-PM loop pass must
/// require and preserve so it can participate in a LoopPassManager.
void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) {
  // By definition, all loop passes need the LoopInfo analysis and the
  // Dominator tree it depends on. Because they all participate in the loop
  // pass manager, they must also preserve these.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addPreserved<LoopInfoWrapperPass>();

  // We must also preserve LoopSimplify and LCSSA. We locally access their IDs
  // here because users shouldn't directly get them from this header.
  extern char &LoopSimplifyID;
  extern char &LCSSAID;
  AU.addRequiredID(ID&: LoopSimplifyID);
  AU.addPreservedID(ID&: LoopSimplifyID);
  AU.addRequiredID(ID&: LCSSAID);
  AU.addPreservedID(ID&: LCSSAID);
  // This is used in the LPPassManager to perform LCSSA verification on passes
  // which preserve lcssa form
  AU.addRequired<LCSSAVerificationPass>();
  AU.addPreserved<LCSSAVerificationPass>();

  // Loop passes are designed to run inside of a loop pass manager which means
  // that any function analyses they require must be required by the first loop
  // pass in the manager (so that it is computed before the loop pass manager
  // runs) and preserved by all loop passes in the manager. To make this
  // reasonably robust, the set needed for most loop passes is maintained here.
  // If your loop pass requires an analysis not listed here, you will need to
  // carefully audit the loop pass manager nesting structure that results.
  AU.addRequired<AAResultsWrapperPass>();
  AU.addPreserved<AAResultsWrapperPass>();
  AU.addPreserved<BasicAAWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
  AU.addPreserved<SCEVAAWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  // FIXME: When all loop passes preserve MemorySSA, it can be required and
  // preserved here instead of the individual handling in each pass.
}
184
/// Manually defined generic "LoopPass" dependency initialization. This is used
/// to initialize the exact set of passes from above in \c
/// getLoopAnalysisUsage. It can be used within a loop pass's initialization
/// with:
///
///   INITIALIZE_PASS_DEPENDENCY(LoopPass)
///
/// As-if "LoopPass" were a pass.
void llvm::initializeLoopPassPass(PassRegistry &Registry) {
  // Keep this list in sync with getLoopAnalysisUsage above: each macro
  // registers the corresponding wrapper pass's initializer with Registry.
  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
  INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
}
205
/// Create an MDNode of the form !{!"Name", i32 V} in the context of
/// \p TheLoop's header, suitable for use as a loop-metadata operand.
static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
  LLVMContext &Context = TheLoop->getHeader()->getContext();
  Metadata *MDs[] = {
      MDString::get(Context, Str: Name),
      ConstantAsMetadata::get(C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Context), V))};
  return MDNode::get(Context, MDs);
}
214
/// Set input string into loop metadata by keeping other values intact.
/// If the string is already in loop metadata update value if it is
/// different.
void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
                                   unsigned V) {
  // Reserve slot 0 for the self-referential loop-id operand (filled below).
  SmallVector<Metadata *, 4> MDs(1);
  // If the loop already has metadata, retain it.
  MDNode *LoopID = TheLoop->getLoopID();
  if (LoopID) {
    // Operand 0 is the loop id itself, so start scanning at 1.
    for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
      MDNode *Node = cast<MDNode>(Val: LoopID->getOperand(I: i));
      // If it is of form key = value, try to parse it.
      if (Node->getNumOperands() == 2) {
        MDString *S = dyn_cast<MDString>(Val: Node->getOperand(I: 0));
        if (S && S->getString() == StringMD) {
          ConstantInt *IntMD =
              mdconst::extract_or_null<ConstantInt>(MD: Node->getOperand(I: 1));
          if (IntMD && IntMD->getSExtValue() == V)
            // It is already in place. Do nothing.
            return;
          // We need to update the value, so just skip it here and it will
          // be added after copying other existed nodes.
          continue;
        }
      }
      MDs.push_back(Elt: Node);
    }
  }
  // Add new metadata.
  MDs.push_back(Elt: createStringMetadata(TheLoop, Name: StringMD, V));
  // Replace current metadata node with new one.
  LLVMContext &Context = TheLoop->getHeader()->getContext();
  MDNode *NewLoopID = MDNode::get(Context, MDs);
  // Set operand 0 to refer to the loop id itself.
  NewLoopID->replaceOperandWith(I: 0, New: NewLoopID);
  TheLoop->setLoopID(NewLoopID);
}
252
253std::optional<ElementCount>
254llvm::getOptionalElementCountLoopAttribute(const Loop *TheLoop) {
255 std::optional<int> Width =
256 getOptionalIntLoopAttribute(TheLoop, Name: "llvm.loop.vectorize.width");
257
258 if (Width) {
259 std::optional<int> IsScalable = getOptionalIntLoopAttribute(
260 TheLoop, Name: "llvm.loop.vectorize.scalable.enable");
261 return ElementCount::get(MinVal: *Width, Scalable: IsScalable.value_or(u: false));
262 }
263
264 return std::nullopt;
265}
266
/// Build the loop-id metadata for a followup (transformed) loop.
/// Returns:
///   std::nullopt - the transformation should derive attributes itself
///                  (no original id and !AlwaysNew, or no followup options),
///   nullptr      - the followup loop should carry no !llvm.loop metadata,
///   otherwise    - the MDNode to attach to the followup loop.
std::optional<MDNode *> llvm::makeFollowupLoopID(
    MDNode *OrigLoopID, ArrayRef<StringRef> FollowupOptions,
    const char *InheritOptionsExceptPrefix, bool AlwaysNew) {
  if (!OrigLoopID) {
    if (AlwaysNew)
      return nullptr;
    return std::nullopt;
  }

  assert(OrigLoopID->getOperand(0) == OrigLoopID);

  // nullptr prefix => inherit everything; "" => inherit nothing;
  // non-empty prefix => inherit all attributes except those starting with it.
  bool InheritAllAttrs = !InheritOptionsExceptPrefix;
  bool InheritSomeAttrs =
      InheritOptionsExceptPrefix && InheritOptionsExceptPrefix[0] != '\0';
  SmallVector<Metadata *, 8> MDs;
  // Slot 0 is the self-reference, patched in at the end.
  MDs.push_back(Elt: nullptr);

  bool Changed = false;
  if (InheritAllAttrs || InheritSomeAttrs) {
    for (const MDOperand &Existing : drop_begin(RangeOrContainer: OrigLoopID->operands())) {
      MDNode *Op = cast<MDNode>(Val: Existing.get());

      auto InheritThisAttribute = [InheritSomeAttrs,
                                   InheritOptionsExceptPrefix](MDNode *Op) {
        if (!InheritSomeAttrs)
          return false;

        // Skip malformatted attribute metadata nodes.
        if (Op->getNumOperands() == 0)
          return true;
        Metadata *NameMD = Op->getOperand(I: 0).get();
        if (!isa<MDString>(Val: NameMD))
          return true;
        StringRef AttrName = cast<MDString>(Val: NameMD)->getString();

        // Do not inherit excluded attributes.
        return !AttrName.starts_with(Prefix: InheritOptionsExceptPrefix);
      };

      if (InheritThisAttribute(Op))
        MDs.push_back(Elt: Op);
      else
        Changed = true;
    }
  } else {
    // Modified if we dropped at least one attribute.
    Changed = OrigLoopID->getNumOperands() > 1;
  }

  // Splice in the attribute lists of any followup option nodes present on the
  // original loop (e.g. "llvm.loop.unroll.followup_unrolled").
  bool HasAnyFollowup = false;
  for (StringRef OptionName : FollowupOptions) {
    MDNode *FollowupNode = findOptionMDForLoopID(LoopID: OrigLoopID, Name: OptionName);
    if (!FollowupNode)
      continue;

    HasAnyFollowup = true;
    for (const MDOperand &Option : drop_begin(RangeOrContainer: FollowupNode->operands())) {
      MDs.push_back(Elt: Option.get());
      Changed = true;
    }
  }

  // Attributes of the followup loop not specified explicitly, so signal to the
  // transformation pass to add suitable attributes.
  if (!AlwaysNew && !HasAnyFollowup)
    return std::nullopt;

  // If no attributes were added or removed, the previous loop Id can be reused.
  if (!AlwaysNew && !Changed)
    return OrigLoopID;

  // No attributes is equivalent to having no !llvm.loop metadata at all.
  if (MDs.size() == 1)
    return nullptr;

  // Build the new loop ID.
  MDTuple *FollowupLoopID = MDNode::get(Context&: OrigLoopID->getContext(), MDs);
  FollowupLoopID->replaceOperandWith(I: 0, New: FollowupLoopID);
  return FollowupLoopID;
}
347
/// True if \p L carries the "llvm.loop.disable_nonforced" attribute, which
/// disables all non-user-forced loop transformations.
bool llvm::hasDisableAllTransformsHint(const Loop *L) {
  return getBooleanLoopAttribute(TheLoop: L, Name: LLVMLoopDisableNonforced);
}
351
/// True if \p L carries the "llvm.licm.disable" attribute, which disables
/// LICM for this loop.
bool llvm::hasDisableLICMTransformsHint(const Loop *L) {
  return getBooleanLoopAttribute(TheLoop: L, Name: LLVMLoopDisableLICM);
}
355
356TransformationMode llvm::hasUnrollTransformation(const Loop *L) {
357 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.disable"))
358 return TM_SuppressedByUser;
359
360 std::optional<int> Count =
361 getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.count");
362 if (Count)
363 return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;
364
365 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.enable"))
366 return TM_ForcedByUser;
367
368 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.full"))
369 return TM_ForcedByUser;
370
371 if (hasDisableAllTransformsHint(L))
372 return TM_Disable;
373
374 return TM_Unspecified;
375}
376
377TransformationMode llvm::hasUnrollAndJamTransformation(const Loop *L) {
378 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.disable"))
379 return TM_SuppressedByUser;
380
381 std::optional<int> Count =
382 getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.count");
383 if (Count)
384 return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;
385
386 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.enable"))
387 return TM_ForcedByUser;
388
389 if (hasDisableAllTransformsHint(L))
390 return TM_Disable;
391
392 return TM_Unspecified;
393}
394
/// Determine the user-requested state of loop vectorization for \p L from its
/// "llvm.loop.vectorize.*" / "llvm.loop.interleave.*" metadata. Note the
/// std::optional comparisons: e.g. `Enable == false` is true only when the
/// attribute is present AND set to false, not when it is absent.
TransformationMode llvm::hasVectorizeTransformation(const Loop *L) {
  std::optional<bool> Enable =
      getOptionalBoolLoopAttribute(TheLoop: L, Name: "llvm.loop.vectorize.enable");

  if (Enable == false)
    return TM_SuppressedByUser;

  std::optional<ElementCount> VectorizeWidth =
      getOptionalElementCountLoopAttribute(TheLoop: L);
  std::optional<int> InterleaveCount =
      getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.interleave.count");

  // 'Forcing' vector width and interleave count to one effectively disables
  // this transformation.
  if (Enable == true && VectorizeWidth && VectorizeWidth->isScalar() &&
      InterleaveCount == 1)
    return TM_SuppressedByUser;

  // Already vectorized: do not vectorize again.
  if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.isvectorized"))
    return TM_Disable;

  if (Enable == true)
    return TM_ForcedByUser;

  // Scalar width plus interleave count of one: nothing to do.
  if ((VectorizeWidth && VectorizeWidth->isScalar()) && InterleaveCount == 1)
    return TM_Disable;

  // A vector width or interleave count > 1 enables (but does not force) it.
  if ((VectorizeWidth && VectorizeWidth->isVector()) || InterleaveCount > 1)
    return TM_Enable;

  if (hasDisableAllTransformsHint(L))
    return TM_Disable;

  return TM_Unspecified;
}
430
431TransformationMode llvm::hasDistributeTransformation(const Loop *L) {
432 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.distribute.enable"))
433 return TM_ForcedByUser;
434
435 if (hasDisableAllTransformsHint(L))
436 return TM_Disable;
437
438 return TM_Unspecified;
439}
440
441TransformationMode llvm::hasLICMVersioningTransformation(const Loop *L) {
442 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.licm_versioning.disable"))
443 return TM_SuppressedByUser;
444
445 if (hasDisableAllTransformsHint(L))
446 return TM_Disable;
447
448 return TM_Unspecified;
449}
450
451/// Does a BFS from a given node to all of its children inside a given loop.
452/// The returned vector of basic blocks includes the starting point.
453SmallVector<BasicBlock *, 16> llvm::collectChildrenInLoop(DominatorTree *DT,
454 DomTreeNode *N,
455 const Loop *CurLoop) {
456 SmallVector<BasicBlock *, 16> Worklist;
457 auto AddRegionToWorklist = [&](DomTreeNode *DTN) {
458 // Only include subregions in the top level loop.
459 BasicBlock *BB = DTN->getBlock();
460 if (CurLoop->contains(BB))
461 Worklist.push_back(Elt: DTN->getBlock());
462 };
463
464 AddRegionToWorklist(N);
465
466 for (size_t I = 0; I < Worklist.size(); I++) {
467 for (DomTreeNode *Child : DT->getNode(BB: Worklist[I])->children())
468 AddRegionToWorklist(Child);
469 }
470
471 return Worklist;
472}
473
474bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) {
475 int LatchIdx = PN->getBasicBlockIndex(BB: LatchBlock);
476 assert(LatchIdx != -1 && "LatchBlock is not a case in this PHINode");
477 Value *IncV = PN->getIncomingValue(i: LatchIdx);
478
479 for (User *U : PN->users())
480 if (U != Cond && U != IncV) return false;
481
482 for (User *U : IncV->users())
483 if (U != Cond && U != PN) return false;
484 return true;
485}
486
487
/// Delete the dead loop \p L, updating (when provided) the dominator tree,
/// ScalarEvolution, LoopInfo and MemorySSA. Requires LCSSA form and a
/// preheader; the preheader is rewired to branch directly to the unique exit
/// block (or to unreachable when the loop has no exits). The order of
/// operations below is deliberate to avoid analysis-invalidation issues.
void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
                          LoopInfo *LI, MemorySSA *MSSA) {
  assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
  auto *Preheader = L->getLoopPreheader();
  assert(Preheader && "Preheader should exist!");

  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(args&: MSSA);

  // Now that we know the removal is safe, remove the loop by changing the
  // branch from the preheader to go to the single exit block.
  //
  // Because we're deleting a large chunk of code at once, the sequence in which
  // we remove things is very important to avoid invalidation issues.

  // Tell ScalarEvolution that the loop is deleted. Do this before
  // deleting the loop so that ScalarEvolution can look at the loop
  // to determine what it needs to clean up.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetBlockAndLoopDispositions();
  }

  Instruction *OldTerm = Preheader->getTerminator();
  assert(!OldTerm->mayHaveSideEffects() &&
         "Preheader must end with a side-effect-free terminator");
  assert(OldTerm->getNumSuccessors() == 1 &&
         "Preheader must have a single successor");
  // Connect the preheader to the exit block. Keep the old edge to the header
  // around to perform the dominator tree update in two separate steps
  // -- #1 insertion of the edge preheader -> exit and #2 deletion of the edge
  // preheader -> header.
  //
  //
  // 0. Preheader          1. Preheader           2. Preheader
  //        |                    |   |                   |
  //        V                    |   V                   |
  //      Header <--\            | Header <--\           | Header <--\
  //       |  |     |            |  |  |     |           |  |  |     |
  //       |  V     |            |  |  V     |           |  |  V     |
  //       | Body --/            |  | Body --/           |  | Body --/
  //       V                     V  V                    V  V
  //      Exit                   Exit                    Exit
  //
  // By doing this is two separate steps we can perform the dominator tree
  // update without using the batch update API.
  //
  // Even when the loop is never executed, we cannot remove the edge from the
  // source block to the exit block. Consider the case where the unexecuted loop
  // branches back to an outer loop. If we deleted the loop and removed the edge
  // coming to this inner loop, this will break the outer loop structure (by
  // deleting the backedge of the outer loop). If the outer loop is indeed a
  // non-loop, it will be deleted in a future iteration of loop deletion pass.
  IRBuilder<> Builder(OldTerm);

  auto *ExitBlock = L->getUniqueExitBlock();
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
  if (ExitBlock) {
    assert(ExitBlock && "Should have a unique exit block!");
    assert(L->hasDedicatedExits() && "Loop should have dedicated exits!");

    // Step #1 above: add the preheader->exit edge via a never-taken branch.
    Builder.CreateCondBr(Cond: Builder.getFalse(), True: L->getHeader(), False: ExitBlock);
    // Remove the old branch. The conditional branch becomes a new terminator.
    OldTerm->eraseFromParent();

    // Rewrite phis in the exit block to get their inputs from the Preheader
    // instead of the exiting block.
    for (PHINode &P : ExitBlock->phis()) {
      // Set the zero'th element of Phi to be from the preheader and remove all
      // other incoming values. Given the loop has dedicated exits, all other
      // incoming values must be from the exiting blocks.
      int PredIndex = 0;
      P.setIncomingBlock(i: PredIndex, BB: Preheader);
      // Removes all incoming values from all other exiting blocks (including
      // duplicate values from an exiting block).
      // Nuke all entries except the zero'th entry which is the preheader entry.
      P.removeIncomingValueIf(Predicate: [](unsigned Idx) { return Idx != 0; },
                              /* DeletePHIIfEmpty */ false);

      assert((P.getNumIncomingValues() == 1 &&
              P.getIncomingBlock(PredIndex) == Preheader) &&
             "Should have exactly one value and that's from the preheader!");
    }

    if (DT) {
      DTU.applyUpdates(Updates: {{DominatorTree::Insert, Preheader, ExitBlock}});
      if (MSSA) {
        MSSAU->applyUpdates(Updates: {{DominatorTree::Insert, Preheader, ExitBlock}},
                            DT&: *DT);
        if (VerifyMemorySSA)
          MSSA->verifyMemorySSA();
      }
    }

    // Disconnect the loop body by branching directly to its exit.
    Builder.SetInsertPoint(Preheader->getTerminator());
    Builder.CreateBr(Dest: ExitBlock);
    // Remove the old branch.
    Preheader->getTerminator()->eraseFromParent();
  } else {
    assert(L->hasNoExitBlocks() &&
           "Loop should have either zero or one exit blocks.");

    // No exits: the preheader simply becomes unreachable.
    Builder.SetInsertPoint(OldTerm);
    Builder.CreateUnreachable();
    Preheader->getTerminator()->eraseFromParent();
  }

  // Step #2 above: remove the now-dead preheader->header edge.
  if (DT) {
    DTU.applyUpdates(Updates: {{DominatorTree::Delete, Preheader, L->getHeader()}});
    if (MSSA) {
      MSSAU->applyUpdates(Updates: {{DominatorTree::Delete, Preheader, L->getHeader()}},
                          DT&: *DT);
      SmallSetVector<BasicBlock *, 8> DeadBlockSet(L->block_begin(),
                                                   L->block_end());
      MSSAU->removeBlocks(DeadBlocks: DeadBlockSet);
      if (VerifyMemorySSA)
        MSSA->verifyMemorySSA();
    }
  }

  // Use a map to unique and a vector to guarantee deterministic ordering.
  llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet;
  llvm::SmallVector<DbgVariableRecord *, 4> DeadDbgVariableRecords;

  // Given LCSSA form is satisfied, we should not have users of instructions
  // within the dead loop outside of the loop. However, LCSSA doesn't take
  // unreachable uses into account. We handle them here.
  // We could do it after drop all references (in this case all users in the
  // loop will be already eliminated and we have less work to do but according
  // to API doc of User::dropAllReferences only valid operation after dropping
  // references, is deletion. So let's substitute all usages of
  // instruction from the loop with poison value of corresponding type first.
  for (auto *Block : L->blocks())
    for (Instruction &I : *Block) {
      auto *Poison = PoisonValue::get(T: I.getType());
      for (Use &U : llvm::make_early_inc_range(Range: I.uses())) {
        if (auto *Usr = dyn_cast<Instruction>(Val: U.getUser()))
          if (L->contains(BB: Usr->getParent()))
            continue;
        // If we have a DT then we can check that uses outside a loop only in
        // unreachable block.
        if (DT)
          assert(!DT->isReachableFromEntry(U) &&
                 "Unexpected user in reachable block");
        U.set(Poison);
      }

      if (ExitBlock) {
        // For one of each variable encountered, preserve a debug record (set
        // to Poison) and transfer it to the loop exit. This terminates any
        // variable locations that were set during the loop.
        for (DbgVariableRecord &DVR :
             llvm::make_early_inc_range(Range: filterDbgVars(R: I.getDbgRecordRange()))) {
          DebugVariable Key(DVR.getVariable(), DVR.getExpression(),
                            DVR.getDebugLoc().get());
          if (!DeadDebugSet.insert(V: Key).second)
            continue;
          // Unlinks the DVR from its container, for later insertion.
          DVR.removeFromParent();
          DeadDbgVariableRecords.push_back(Elt: &DVR);
        }
      }
    }

  if (ExitBlock) {
    // After the loop has been deleted all the values defined and modified
    // inside the loop are going to be unavailable. Values computed in the
    // loop will have been deleted, automatically causing their debug uses
    // be replaced with undef. Loop invariant values will still be available.
    // Move dbg.values out the loop so that earlier location ranges are still
    // terminated and loop invariant assignments are preserved.
    DIBuilder DIB(*ExitBlock->getModule());
    BasicBlock::iterator InsertDbgValueBefore =
        ExitBlock->getFirstInsertionPt();
    assert(InsertDbgValueBefore != ExitBlock->end() &&
           "There should be a non-PHI instruction in exit block, else these "
           "instructions will have no parent.");

    // Due to the "head" bit in BasicBlock::iterator, we're going to insert
    // each DbgVariableRecord right at the start of the block, whereas
    // dbg.values would be repeatedly inserted before the first instruction.
    // To replicate this behaviour, do it backwards.
    for (DbgVariableRecord *DVR : llvm::reverse(C&: DeadDbgVariableRecords))
      ExitBlock->insertDbgRecordBefore(DR: DVR, Here: InsertDbgValueBefore);
  }

  // Remove the block from the reference counting scheme, so that we can
  // delete it freely later.
  for (auto *Block : L->blocks())
    Block->dropAllReferences();

  if (MSSA && VerifyMemorySSA)
    MSSA->verifyMemorySSA();

  if (LI) {
    SmallPtrSet<BasicBlock *, 8> Blocks(llvm::from_range, L->blocks());

    // Erase the instructions and the blocks without having to worry
    // about ordering because we already dropped the references.
    // Remove blocks from loopinfo before erasing them, otherwise the loopinfo
    // cannot find the loop using block numbers.
    for (BasicBlock *BB : Blocks) {
      LI->removeBlock(BB);
      BB->eraseFromParent();
    }

    // The last step is to update LoopInfo now that we've eliminated this loop.
    // Note: LoopInfo::erase remove the given loop and relink its subloops with
    // its parent. While removeLoop/removeChildLoop remove the given loop but
    // not relink its subloops, which is what we want.
    if (Loop *ParentLoop = L->getParentLoop()) {
      Loop::iterator I = find(Range&: *ParentLoop, Val: L);
      assert(I != ParentLoop->end() && "Couldn't find loop");
      ParentLoop->removeChildLoop(I);
    } else {
      Loop::iterator I = find(Range&: *LI, Val: L);
      assert(I != LI->end() && "Couldn't find loop");
      LI->removeLoop(I);
    }
    LI->destroy(L);
  }
}
712
/// Remove the backedge of \p L (which must have a single latch), turning it
/// into a non-loop, and erase it from LoopInfo. DT, SE, LI and (optionally)
/// MemorySSA are kept up to date. Two common latch-terminator shapes are
/// special-cased for better code quality; everything else falls back to
/// splitting the backedge and marking it unreachable.
void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
                             LoopInfo &LI, MemorySSA *MSSA) {
  auto *Latch = L->getLoopLatch();
  assert(Latch && "multiple latches not yet supported");
  auto *Header = L->getHeader();
  // Remember the outermost enclosing loop before LI.erase invalidates L.
  Loop *OutermostLoop = L->getOutermostLoop();

  SE.forgetLoop(L);
  SE.forgetBlockAndLoopDispositions();

  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(args&: MSSA);

  // Update the CFG and domtree. We chose to special case a couple of
  // of common cases for code quality and test readability reasons.
  [&]() -> void {
    // Case 1: unconditional latch branch - the latch itself becomes
    // unreachable.
    if (auto *BI = dyn_cast<UncondBrInst>(Val: Latch->getTerminator())) {
      DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
      (void)changeToUnreachable(I: BI, /*PreserveLCSSA*/ true, DTU: &DTU, MSSAU: MSSAU.get());
      return;
    }
    // Case 2: conditional latch that also exits - fold it into an
    // unconditional branch to the exit.
    if (auto *BI = dyn_cast<CondBrInst>(Val: Latch->getTerminator())) {
      // Conditional latch/exit - note that latch can be shared by inner
      // and outer loop so the other target doesn't need to an exit
      if (L->isLoopExiting(BB: Latch)) {
        // TODO: Generalize ConstantFoldTerminator so that it can be used
        // here without invalidating LCSSA or MemorySSA. (Tricky case for
        // LCSSA: header is an exit block of a preceding sibling loop w/o
        // dedicated exits.)
        const unsigned ExitIdx = L->contains(BB: BI->getSuccessor(i: 0)) ? 1 : 0;
        BasicBlock *ExitBB = BI->getSuccessor(i: ExitIdx);

        DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
        Header->removePredecessor(Pred: Latch, KeepOneInputPHIs: true);

        IRBuilder<> Builder(BI);
        auto *NewBI = Builder.CreateBr(Dest: ExitBB);
        // Transfer the metadata to the new branch instruction (minus the
        // loop info since this is no longer a loop)
        NewBI->copyMetadata(SrcInst: *BI, WL: {LLVMContext::MD_dbg,
                                        LLVMContext::MD_annotation});

        BI->eraseFromParent();
        DTU.applyUpdates(Updates: {{DominatorTree::Delete, Latch, Header}});
        if (MSSA)
          MSSAU->applyUpdates(Updates: {{DominatorTree::Delete, Latch, Header}}, DT);
        return;
      }
    }

    // General case. By splitting the backedge, and then explicitly making it
    // unreachable we gracefully handle corner cases such as switch and invoke
    // terminators.
    auto *BackedgeBB = SplitEdge(From: Latch, To: Header, DT: &DT, LI: &LI, MSSAU: MSSAU.get());

    DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
    (void)changeToUnreachable(I: BackedgeBB->getTerminator(),
                              /*PreserveLCSSA*/ true, DTU: &DTU, MSSAU: MSSAU.get());
  }();

  // Erase (and destroy) this loop instance. Handles relinking sub-loops
  // and blocks within the loop as needed.
  LI.erase(L);

  // If the loop we broke had a parent, then changeToUnreachable might have
  // caused a block to be removed from the parent loop (see loop_nest_lcssa
  // test case in zero-btc.ll for an example), thus changing the parent's
  // exit blocks. If that happened, we need to rebuild LCSSA on the outermost
  // loop which might have a had a block removed.
  if (OutermostLoop != L)
    formLCSSARecursively(L&: *OutermostLoop, DT, LI: &LI, SE: &SE);
}
786
787
/// Checks if \p L has an exiting latch branch. There may also be other
/// exiting blocks. Returns branch instruction terminating the loop
/// latch if above check is successful, nullptr otherwise.
static CondBrInst *getExpectedExitLoopLatchBranch(Loop *L) {
  // A unique latch is required; multi-latch loops are not handled.
  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return nullptr;

  // The latch must end in a conditional branch and must itself exit the loop.
  CondBrInst *LatchBR = dyn_cast<CondBrInst>(Val: Latch->getTerminator());
  if (!LatchBR || !L->isLoopExiting(BB: Latch))
    return nullptr;

  assert((LatchBR->getSuccessor(0) == L->getHeader() ||
          LatchBR->getSuccessor(1) == L->getHeader()) &&
         "At least one edge out of the latch must go to the header");

  return LatchBR;
}
806
/// Small wrapper that pairs a Loop with a custom operator<< (below) so debug
/// output can print the enclosing function along with the loop.
struct DbgLoop {
  const Loop *L;
  explicit DbgLoop(const Loop *L) : L(L) {}
};
811
#ifndef NDEBUG
/// Print "function @<name> <loop>" for LLVM_DEBUG diagnostics (debug builds
/// only).
static inline raw_ostream &operator<<(raw_ostream &OS, DbgLoop D) {
  OS << "function ";
  D.L->getHeader()->getParent()->printAsOperand(OS, /*PrintType=*/false);
  return OS << " " << *D.L;
}
#endif // NDEBUG
819
/// Estimate the trip count of \p L from the branch weights on its exiting
/// latch branch. Returns std::nullopt when no suitable latch branch or
/// weights exist, or when the exit weight is zero (infinite estimate).
static std::optional<unsigned> estimateLoopTripCount(Loop *L) {
  // Currently we take the estimate exit count only from the loop latch,
  // ignoring other exiting blocks. This can overestimate the trip count
  // if we exit through another exit, but can never underestimate it.
  // TODO: incorporate information from other exits
  CondBrInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
  if (!ExitingBranch) {
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to find exiting "
                      << "latch branch of required form in " << DbgLoop(L)
                      << "\n");
    return std::nullopt;
  }

  // To estimate the number of times the loop body was executed, we want to
  // know the number of times the backedge was taken, vs. the number of times
  // we exited the loop.
  uint64_t LoopWeight, ExitWeight;
  if (!extractBranchWeights(I: *ExitingBranch, TrueVal&: LoopWeight, FalseVal&: ExitWeight)) {
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to extract branch "
                      << "weights for " << DbgLoop(L) << "\n");
    return std::nullopt;
  }

  // The weights were extracted assuming successor 0 stays in the loop; swap
  // them if the taken edge is actually the exit.
  if (L->contains(BB: ExitingBranch->getSuccessor(i: 1)))
    std::swap(a&: LoopWeight, b&: ExitWeight);

  if (!ExitWeight) {
    // Don't have a way to return predicated infinite
    LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed because of zero exit "
                      << "probability for " << DbgLoop(L) << "\n");
    return std::nullopt;
  }

  // Estimated exit count is a ratio of the loop weight by the weight of the
  // edge exiting the loop, rounded to nearest.
  uint64_t ExitCount = llvm::divideNearest(Numerator: LoopWeight, Denominator: ExitWeight);

  // When ExitCount + 1 would wrap in unsigned, saturate at UINT_MAX.
  if (ExitCount >= std::numeric_limits<unsigned>::max())
    return std::numeric_limits<unsigned>::max();

  // Estimated trip count is one plus estimated exit count.
  uint64_t TC = ExitCount + 1;
  LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Estimated trip count of " << TC
                    << " for " << DbgLoop(L) << "\n");
  return TC;
}
867
/// Return the estimated trip count of \p L, preferring the
/// llvm.loop.estimated_trip_count metadata and falling back to latch branch
/// weights. If \p EstimatedLoopInvocationWeight is non-null, additionally
/// require extractable latch branch weights and store the latch's exit-edge
/// weight through the pointer.
std::optional<unsigned>
llvm::getLoopEstimatedTripCount(Loop *L,
                                unsigned *EstimatedLoopInvocationWeight) {
  // If EstimatedLoopInvocationWeight, we do not support this loop if
  // getExpectedExitLoopLatchBranch returns nullptr.
  //
  // FIXME: Also, this is a stop-gap solution for nested loops. It avoids
  // mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
  // it was created for an inner loop. The problem is that loop metadata is
  // attached to the branch instruction in the loop latch block, but that can be
  // shared by the loops. A solution is to attach loop metadata to loop headers
  // instead, but that would be a large change to LLVM.
  //
  // Until that happens, we work around the problem as follows.
  // getExpectedExitLoopLatchBranch (which also guards
  // setLoopEstimatedTripCount) returns nullptr for a loop unless the loop has
  // one latch and that latch has exactly two successors one of which is an exit
  // from the loop. If the latch is shared by nested loops, then that condition
  // might hold for the inner loop but cannot hold for the outer loop:
  // - Because the latch is shared, it must have at least two successors: the
  //   inner loop header and the outer loop header, which is also an exit for
  //   the inner loop. That satisifies the condition for the inner loop.
  // - To satsify the condition for the outer loop, the latch must have a third
  //   successor that is an exit for the outer loop. But that violates the
  //   condition for both loops.
  CondBrInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
  if (!ExitingBranch)
    return std::nullopt;

  // If requested, either compute *EstimatedLoopInvocationWeight or return
  // nullopt if cannot.
  //
  // TODO: Eventually, once all passes have migrated away from setting branch
  // weights to indicate estimated trip counts, this function will drop the
  // EstimatedLoopInvocationWeight parameter.
  if (EstimatedLoopInvocationWeight) {
    uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
    if (!extractBranchWeights(I: *ExitingBranch, TrueVal&: LoopWeight, FalseVal&: ExitWeight))
      return std::nullopt;
    // Ensure ExitWeight holds the weight of the edge leaving the loop.
    if (L->contains(BB: ExitingBranch->getSuccessor(i: 1)))
      std::swap(a&: LoopWeight, b&: ExitWeight);
    if (!ExitWeight)
      return std::nullopt;
    *EstimatedLoopInvocationWeight = ExitWeight;
  }

  // Return the estimated trip count from metadata unless the metadata is
  // missing or has no value.
  //
  // Some passes set llvm.loop.estimated_trip_count to 0. For example, after
  // peeling 10 or more iterations from a loop with an estimated trip count of
  // 10, llvm.loop.estimated_trip_count becomes 0 on the remaining loop. It
  // indicates that, each time execution reaches the peeled iterations,
  // execution is estimated to exit them without reaching the remaining loop's
  // header.
  //
  // Even if the probability of reaching a loop's header is low, if it is
  // reached, it is the start of an iteration. Consequently, some passes
  // historically assume that llvm::getLoopEstimatedTripCount always returns a
  // positive count or std::nullopt. Thus, return std::nullopt when
  // llvm.loop.estimated_trip_count is 0.
  if (auto TC = getOptionalIntLoopAttribute(TheLoop: L, Name: LLVMLoopEstimatedTripCount)) {
    LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: "
                      << LLVMLoopEstimatedTripCount << " metadata has trip "
                      << "count of " << *TC
                      << (*TC == 0 ? " (returning std::nullopt)" : "")
                      << " for " << DbgLoop(L) << "\n");
    return *TC == 0 ? std::nullopt : std::optional(*TC);
  }

  // Estimate the trip count from latch branch weights.
  return estimateLoopTripCount(L);
}
941
942bool llvm::setLoopEstimatedTripCount(
943 Loop *L, unsigned EstimatedTripCount,
944 std::optional<unsigned> EstimatedloopInvocationWeight) {
945 // If EstimatedLoopInvocationWeight, we do not support this loop if
946 // getExpectedExitLoopLatchBranch returns nullptr.
947 //
948 // FIXME: See comments in getLoopEstimatedTripCount for why this is required
949 // here regardless of EstimatedLoopInvocationWeight.
950 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
951 if (!LatchBranch)
952 return false;
953
954 // Set the metadata.
955 addStringMetadataToLoop(TheLoop: L, StringMD: LLVMLoopEstimatedTripCount, V: EstimatedTripCount);
956
957 // At the moment, we currently support changing the estimated trip count in
958 // the latch branch's branch weights only. We could extend this API to
959 // manipulate estimated trip counts for any exit.
960 //
961 // TODO: Eventually, once all passes have migrated away from setting branch
962 // weights to indicate estimated trip counts, we will not set branch weights
963 // here at all.
964 if (!EstimatedloopInvocationWeight)
965 return true;
966
967 // Calculate taken and exit weights.
968 unsigned LatchExitWeight = ProfcheckDisableMetadataFixes ? 0 : 1;
969 unsigned BackedgeTakenWeight = 0;
970
971 if (EstimatedTripCount != 0) {
972 LatchExitWeight = *EstimatedloopInvocationWeight;
973 BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight;
974 }
975
976 // Make a swap if back edge is taken when condition is "false".
977 if (LatchBranch->getSuccessor(i: 0) != L->getHeader())
978 std::swap(a&: BackedgeTakenWeight, b&: LatchExitWeight);
979
980 // Set/Update profile metadata.
981 setBranchWeights(I&: *LatchBranch, Weights: {BackedgeTakenWeight, LatchExitWeight},
982 /*IsExpected=*/false);
983
984 return true;
985}
986
987BranchProbability llvm::getLoopProbability(Loop *L) {
988 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
989 if (!LatchBranch)
990 return BranchProbability::getUnknown();
991 bool FirstTargetIsLoop = LatchBranch->getSuccessor(i: 0) == L->getHeader();
992 return getBranchProbability(B: LatchBranch, ForFirstTarget: FirstTargetIsLoop);
993}
994
995bool llvm::setLoopProbability(Loop *L, BranchProbability P) {
996 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
997 if (!LatchBranch)
998 return false;
999 bool FirstTargetIsLoop = LatchBranch->getSuccessor(i: 0) == L->getHeader();
1000 setBranchProbability(B: LatchBranch, P, ForFirstTarget: FirstTargetIsLoop);
1001 return true;
1002}
1003
1004BranchProbability llvm::getBranchProbability(CondBrInst *B,
1005 bool ForFirstTarget) {
1006 uint64_t Weight0, Weight1;
1007 if (!extractBranchWeights(I: *B, TrueVal&: Weight0, FalseVal&: Weight1))
1008 return BranchProbability::getUnknown();
1009 uint64_t Denominator = Weight0 + Weight1;
1010 if (Denominator == 0)
1011 return BranchProbability::getUnknown();
1012 if (!ForFirstTarget)
1013 std::swap(a&: Weight0, b&: Weight1);
1014 return BranchProbability::getBranchProbability(Numerator: Weight0, Denominator);
1015}
1016
/// Compute the probability of the edge from \p Src to \p Dst using the
/// branch_weights metadata on \p Src's terminator. Weights of parallel edges
/// into \p Dst are summed. Returns getZero() for a terminator with no
/// successors and getUnknown() when the metadata is absent or unusable.
BranchProbability llvm::getBranchProbability(BasicBlock *Src, BasicBlock *Dst) {
  assert(Src != Dst && "Passed in same source as destination");

  Instruction *TI = Src->getTerminator();
  if (!TI || TI->getNumSuccessors() == 0)
    return BranchProbability::getZero();

  SmallVector<uint32_t, 4> Weights;

  if (!extractBranchWeights(I: *TI, Weights)) {
    // No metadata
    return BranchProbability::getUnknown();
  }
  assert(TI->getNumSuccessors() == Weights.size() &&
         "Missing weights in branch_weights");

  // Sum all weights into Total and those of edges targeting Dst into
  // Numerator; Total is 64-bit to avoid overflow while summing.
  uint64_t Total = 0;
  uint32_t Numerator = 0;
  for (auto [i, Weight] : llvm::enumerate(First&: Weights)) {
    if (TI->getSuccessor(Idx: i) == Dst)
      Numerator += Weight;
    Total += Weight;
  }

  // Total of edges might be 0 if the metadata is incorrect/set by hand
  // or missing. In such case return here to avoid division by 0 later on.
  // There might also be a case where the value of Total cannot fit into
  // uint32_t, in such case, just bail out.
  if (Total == 0 || Total > std::numeric_limits<uint32_t>::max())
    return BranchProbability::getUnknown();

  return BranchProbability(Numerator, Total);
}
1050
1051void llvm::setBranchProbability(CondBrInst *B, BranchProbability P,
1052 bool ForFirstTarget) {
1053 BranchProbability Prob0 = P;
1054 BranchProbability Prob1 = P.getCompl();
1055 if (!ForFirstTarget)
1056 std::swap(a&: Prob0, b&: Prob1);
1057 setBranchWeights(I&: *B, Weights: {Prob0.getNumerator(), Prob1.getNumerator()},
1058 /*IsExpected=*/false);
1059}
1060
1061bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
1062 ScalarEvolution &SE) {
1063 Loop *OuterL = InnerLoop->getParentLoop();
1064 if (!OuterL)
1065 return true;
1066
1067 // Get the backedge taken count for the inner loop
1068 BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch();
1069 const SCEV *InnerLoopBECountSC = SE.getExitCount(L: InnerLoop, ExitingBlock: InnerLoopLatch);
1070 if (isa<SCEVCouldNotCompute>(Val: InnerLoopBECountSC) ||
1071 !InnerLoopBECountSC->getType()->isIntegerTy())
1072 return false;
1073
1074 // Get whether count is invariant to the outer loop
1075 ScalarEvolution::LoopDisposition LD =
1076 SE.getLoopDisposition(S: InnerLoopBECountSC, L: OuterL);
1077 if (LD != ScalarEvolution::LoopInvariant)
1078 return false;
1079
1080 return true;
1081}
1082
/// Map a recurrence kind to the vector_reduce_* intrinsic that implements it.
/// Several kinds share one intrinsic (e.g. Sub-containing add chains use the
/// add reduction; the *Num/*imumNum FP kinds map onto fmax/fmin).
constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unexpected recurrence kind");
  // Sub and add-chains-with-subs reduce via the integer add reduction.
  case RecurKind::AddChainWithSubs:
  case RecurKind::Sub:
  case RecurKind::Add:
    return Intrinsic::vector_reduce_add;
  case RecurKind::Mul:
    return Intrinsic::vector_reduce_mul;
  case RecurKind::And:
    return Intrinsic::vector_reduce_and;
  case RecurKind::Or:
    return Intrinsic::vector_reduce_or;
  case RecurKind::Xor:
    return Intrinsic::vector_reduce_xor;
  // FMulAdd accumulates with an fadd reduction.
  case RecurKind::FMulAdd:
  case RecurKind::FAdd:
    return Intrinsic::vector_reduce_fadd;
  case RecurKind::FMul:
    return Intrinsic::vector_reduce_fmul;
  case RecurKind::SMax:
    return Intrinsic::vector_reduce_smax;
  case RecurKind::SMin:
    return Intrinsic::vector_reduce_smin;
  case RecurKind::UMax:
    return Intrinsic::vector_reduce_umax;
  case RecurKind::UMin:
    return Intrinsic::vector_reduce_umin;
  case RecurKind::FMax:
  case RecurKind::FMaxNum:
    return Intrinsic::vector_reduce_fmax;
  case RecurKind::FMin:
  case RecurKind::FMinNum:
    return Intrinsic::vector_reduce_fmin;
  case RecurKind::FMaximum:
    return Intrinsic::vector_reduce_fmaximum;
  case RecurKind::FMinimum:
    return Intrinsic::vector_reduce_fminimum;
  // The *imumNum kinds also lower to the fmax/fmin reductions.
  case RecurKind::FMaximumNum:
    return Intrinsic::vector_reduce_fmax;
  case RecurKind::FMinimumNum:
    return Intrinsic::vector_reduce_fmin;
  }
}
1128
1129Intrinsic::ID llvm::getMinMaxReductionIntrinsicID(Intrinsic::ID IID) {
1130 switch (IID) {
1131 default:
1132 llvm_unreachable("Unexpected intrinsic id");
1133 case Intrinsic::umin:
1134 return Intrinsic::vector_reduce_umin;
1135 case Intrinsic::umax:
1136 return Intrinsic::vector_reduce_umax;
1137 case Intrinsic::smin:
1138 return Intrinsic::vector_reduce_smin;
1139 case Intrinsic::smax:
1140 return Intrinsic::vector_reduce_smax;
1141 }
1142}
1143
// This is the inverse to getReductionForBinop
/// Return the scalar instruction opcode that performs one step of the given
/// vector reduction; min/max reductions map to their compare opcode
/// (ICmp/FCmp) rather than a binary operator.
unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
  switch (RdxID) {
  case Intrinsic::vector_reduce_fadd:
    return Instruction::FAdd;
  case Intrinsic::vector_reduce_fmul:
    return Instruction::FMul;
  case Intrinsic::vector_reduce_add:
    return Instruction::Add;
  case Intrinsic::vector_reduce_mul:
    return Instruction::Mul;
  case Intrinsic::vector_reduce_and:
    return Instruction::And;
  case Intrinsic::vector_reduce_or:
    return Instruction::Or;
  case Intrinsic::vector_reduce_xor:
    return Instruction::Xor;
  // Integer min/max are realized as an integer compare (+ select).
  case Intrinsic::vector_reduce_smax:
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_umin:
    return Instruction::ICmp;
  // FP min/max are realized as an FP compare (+ select).
  case Intrinsic::vector_reduce_fmax:
  case Intrinsic::vector_reduce_fmin:
    return Instruction::FCmp;
  default:
    llvm_unreachable("Unexpected ID");
  }
}
1173
1174// This is the inverse to getArithmeticReductionInstruction
1175Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
1176 switch (Opc) {
1177 default:
1178 break;
1179 case Instruction::Add:
1180 return Intrinsic::vector_reduce_add;
1181 case Instruction::Mul:
1182 return Intrinsic::vector_reduce_mul;
1183 case Instruction::And:
1184 return Intrinsic::vector_reduce_and;
1185 case Instruction::Or:
1186 return Intrinsic::vector_reduce_or;
1187 case Instruction::Xor:
1188 return Intrinsic::vector_reduce_xor;
1189 }
1190 return Intrinsic::not_intrinsic;
1191}
1192
/// Return the scalar min/max intrinsic corresponding to a vector min/max
/// reduction intrinsic. Note that fmin/fmax map to minnum/maxnum, while
/// fminimum/fmaximum map to the NaN-propagating minimum/maximum.
Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
  switch (RdxID) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case Intrinsic::vector_reduce_umin:
    return Intrinsic::umin;
  case Intrinsic::vector_reduce_umax:
    return Intrinsic::umax;
  case Intrinsic::vector_reduce_smin:
    return Intrinsic::smin;
  case Intrinsic::vector_reduce_smax:
    return Intrinsic::smax;
  case Intrinsic::vector_reduce_fmin:
    return Intrinsic::minnum;
  case Intrinsic::vector_reduce_fmax:
    return Intrinsic::maxnum;
  case Intrinsic::vector_reduce_fminimum:
    return Intrinsic::minimum;
  case Intrinsic::vector_reduce_fmaximum:
    return Intrinsic::maximum;
  }
}
1215
/// Return the scalar min/max intrinsic for a min/max recurrence kind.
/// FMin/FMax and FMinNum/FMaxNum map to minnum/maxnum; FMinimum/FMaximum to
/// minimum/maximum; FMinimumNum/FMaximumNum to minimumnum/maximumnum.
Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case RecurKind::UMin:
    return Intrinsic::umin;
  case RecurKind::UMax:
    return Intrinsic::umax;
  case RecurKind::SMin:
    return Intrinsic::smin;
  case RecurKind::SMax:
    return Intrinsic::smax;
  case RecurKind::FMin:
  case RecurKind::FMinNum:
    return Intrinsic::minnum;
  case RecurKind::FMax:
  case RecurKind::FMaxNum:
    return Intrinsic::maxnum;
  case RecurKind::FMinimum:
    return Intrinsic::minimum;
  case RecurKind::FMaximum:
    return Intrinsic::maximum;
  case RecurKind::FMinimumNum:
    return Intrinsic::minimumnum;
  case RecurKind::FMaximumNum:
    return Intrinsic::maximumnum;
  }
}
1244
1245RecurKind llvm::getMinMaxReductionRecurKind(Intrinsic::ID RdxID) {
1246 switch (RdxID) {
1247 case Intrinsic::vector_reduce_smax:
1248 return RecurKind::SMax;
1249 case Intrinsic::vector_reduce_smin:
1250 return RecurKind::SMin;
1251 case Intrinsic::vector_reduce_umax:
1252 return RecurKind::UMax;
1253 case Intrinsic::vector_reduce_umin:
1254 return RecurKind::UMin;
1255 case Intrinsic::vector_reduce_fmax:
1256 return RecurKind::FMax;
1257 case Intrinsic::vector_reduce_fmin:
1258 return RecurKind::FMin;
1259 default:
1260 return RecurKind::None;
1261 }
1262}
1263
/// Return the compare predicate whose (cmp + select) realizes the given
/// min/max recurrence. FP kinds use ordered predicates (OLT/OGT).
CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
  switch (RK) {
  default:
    llvm_unreachable("Unknown min/max recurrence kind");
  case RecurKind::UMin:
    return CmpInst::ICMP_ULT;
  case RecurKind::UMax:
    return CmpInst::ICMP_UGT;
  case RecurKind::SMin:
    return CmpInst::ICMP_SLT;
  case RecurKind::SMax:
    return CmpInst::ICMP_SGT;
  case RecurKind::FMin:
    return CmpInst::FCMP_OLT;
  case RecurKind::FMax:
    return CmpInst::FCMP_OGT;
  // We do not add FMinimum/FMaximum recurrence kind here since there is no
  // equivalent predicate which compares signed zeroes according to the
  // semantics of the intrinsics (llvm.minimum/maximum).
  }
}
1285
/// Emit a single min/max step for recurrence kind \p RK combining \p Left and
/// \p Right. Integer types, and the FP kinds that have a dedicated intrinsic
/// with the required semantics, emit an intrinsic call; the remaining FP
/// kinds are emitted as compare + select.
Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
                            Value *Right) {
  Type *Ty = Left->getType();
  if (Ty->isIntOrIntVectorTy() ||
      (RK == RecurKind::FMinNum || RK == RecurKind::FMaxNum ||
       RK == RecurKind::FMinimum || RK == RecurKind::FMaximum ||
       RK == RecurKind::FMinimumNum || RK == RecurKind::FMaximumNum)) {
    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RK);
    return Builder.CreateIntrinsic(RetTy: Ty, ID: Id, Args: {Left, Right}, FMFSource: nullptr,
                                   Name: "rdx.minmax");
  }
  // Fallback: ordered compare + select.
  CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
  Value *Cmp = Builder.CreateCmp(Pred, LHS: Left, RHS: Right, Name: "rdx.minmax.cmp");
  Value *Select = Builder.CreateSelect(C: Cmp, True: Left, False: Right, Name: "rdx.minmax.select");
  return Select;
}
1302
1303// Helper to generate an ordered reduction.
1304Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
1305 unsigned Op, RecurKind RdxKind) {
1306 unsigned VF = cast<FixedVectorType>(Val: Src->getType())->getNumElements();
1307
1308 // Extract and apply reduction ops in ascending order:
1309 // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ) ... + Scl[VF-1]
1310 Value *Result = Acc;
1311 for (unsigned ExtractIdx = 0; ExtractIdx != VF; ++ExtractIdx) {
1312 Value *Ext =
1313 Builder.CreateExtractElement(Vec: Src, Idx: Builder.getInt32(C: ExtractIdx));
1314
1315 if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
1316 Result = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op, LHS: Result, RHS: Ext,
1317 Name: "bin.rdx");
1318 } else {
1319 assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
1320 "Invalid min/max");
1321 Result = createMinMaxOp(Builder, RK: RdxKind, Left: Result, Right: Ext);
1322 }
1323 }
1324
1325 return Result;
1326}
1327
// Helper to generate a log2 shuffle reduction.
/// Reduce the power-of-two-width vector \p Src to a scalar using log2(VF)
/// shuffle + op rounds, either pairwise or split-in-half depending on \p RS,
/// and return the scalar extracted from lane 0.
Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
                                 unsigned Op,
                                 TargetTransformInfo::ReductionShuffle RS,
                                 RecurKind RdxKind) {
  unsigned VF = cast<FixedVectorType>(Val: Src->getType())->getNumElements();
  // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
  // and vector ops, reducing the set of values being computed by half each
  // round.
  assert(isPowerOf2_32(VF) &&
         "Reduction emission only supported for pow2 vectors!");
  // Note: fast-math-flags flags are controlled by the builder configuration
  // and are assumed to apply to all generated arithmetic instructions. Other
  // poison generating flags (nsw/nuw/inbounds/inrange/exact) are not part
  // of the builder configuration, and since they're not passed explicitly,
  // will never be relevant here. Note that it would be generally unsound to
  // propagate these from an intrinsic call to the expansion anyways as we/
  // change the order of operations.
  // Combine TmpVec with its shuffled counterpart using either a binary op or
  // a min/max op, writing the result back into TmpVec.
  auto BuildShuffledOp = [&Builder, &Op,
                          &RdxKind](SmallVectorImpl<int> &ShuffleMask,
                                    Value *&TmpVec) -> void {
    Value *Shuf = Builder.CreateShuffleVector(V: TmpVec, Mask: ShuffleMask, Name: "rdx.shuf");
    if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
      TmpVec = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op, LHS: TmpVec, RHS: Shuf,
                                   Name: "bin.rdx");
    } else {
      assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
             "Invalid min/max");
      TmpVec = createMinMaxOp(Builder, RK: RdxKind, Left: TmpVec, Right: Shuf);
    }
  };

  Value *TmpVec = Src;
  if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) {
    // Pairwise: each round combines adjacent pairs at increasing strides.
    SmallVector<int, 32> ShuffleMask(VF);
    for (unsigned stride = 1; stride < VF; stride <<= 1) {
      // Initialise the mask with undef.
      llvm::fill(Range&: ShuffleMask, Value: -1);
      for (unsigned j = 0; j < VF; j += stride << 1) {
        ShuffleMask[j] = j + stride;
      }
      BuildShuffledOp(ShuffleMask, TmpVec);
    }
  } else {
    // Split-in-half: each round folds the upper half into the lower half.
    SmallVector<int, 32> ShuffleMask(VF);
    for (unsigned i = VF; i != 1; i >>= 1) {
      // Move the upper half of the vector to the lower half.
      for (unsigned j = 0; j != i / 2; ++j)
        ShuffleMask[j] = i / 2 + j;

      // Fill the rest of the mask with undef.
      std::fill(first: &ShuffleMask[i / 2], last: ShuffleMask.end(), value: -1);
      BuildShuffledOp(ShuffleMask, TmpVec);
    }
  }
  // The result is in the first element of the vector.
  return Builder.CreateExtractElement(Vec: TmpVec, Idx: Builder.getInt32(C: 0));
}
1386
/// Create the exit value of an any-of reduction: OR-reduce \p Src (if it is a
/// vector) and select between the loop's selected value and \p InitVal. The
/// selected value is recovered from \p OrigPhi's select user.
Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
                                  Value *InitVal, PHINode *OrigPhi) {
  Value *NewVal = nullptr;

  // First use the original phi to determine the new value we're trying to
  // select from in the loop.
  SelectInst *SI = nullptr;
  for (auto *U : OrigPhi->users()) {
    if ((SI = dyn_cast<SelectInst>(Val: U)))
      break;
  }
  assert(SI && "One user of the original phi should be a select");

  // The phi feeds one side of the select; the other side is the value the
  // loop selects when a predicate fires.
  if (SI->getTrueValue() == OrigPhi)
    NewVal = SI->getFalseValue();
  else {
    assert(SI->getFalseValue() == OrigPhi &&
           "At least one input to the select should be the original Phi");
    NewVal = SI->getTrueValue();
  }

  // If any predicate is true it means that we want to select the new value.
  Value *AnyOf =
      Src->getType()->isVectorTy() ? Builder.CreateOrReduce(Src) : Src;
  // The compares in the loop may yield poison, which propagates through the
  // bitwise ORs. Freeze it here before the condition is used.
  AnyOf = Builder.CreateFreeze(V: AnyOf);
  return Builder.CreateSelect(C: AnyOf, True: NewVal, False: InitVal, Name: "rdx.select");
}
1416
/// Return the neutral (identity) element for the reduction \p RdxID on type
/// \p Ty. For FP min/max reductions the choice depends on the fast-math
/// \p Flags and on whether the reduction propagates NaNs.
Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
                                  FastMathFlags Flags) {
  // For fmax/fmaximum the identity is the most-negative candidate value.
  bool Negative = false;
  switch (RdxID) {
  default:
    llvm_unreachable("Expecting a reduction intrinsic");
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_mul:
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_xor:
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_fadd:
  case Intrinsic::vector_reduce_fmul: {
    // Delegate to the binary operator's identity (e.g. 0 for add, 1 for mul).
    unsigned Opc = getArithmeticReductionInstruction(RdxID);
    return ConstantExpr::getBinOpIdentity(Opcode: Opc, Ty, AllowRHSConstant: false,
                                          NSZ: Flags.noSignedZeros());
  }
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_umin:
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax: {
    Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID);
    return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty);
  }
  case Intrinsic::vector_reduce_fmax:
  case Intrinsic::vector_reduce_fmaximum:
    Negative = true;
    [[fallthrough]];
  case Intrinsic::vector_reduce_fmin:
  case Intrinsic::vector_reduce_fminimum: {
    bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum ||
                         RdxID == Intrinsic::vector_reduce_fmaximum;
    const fltSemantics &Semantics = Ty->getFltSemantics();
    // Pick the weakest value that is neutral under these flags: a QNaN when
    // NaNs are possible and the op does not propagate them, else +/-inf
    // unless infs are excluded, else the largest finite value.
    return (!Flags.noNaNs() && !PropagatesNaN)
               ? ConstantFP::getQNaN(Ty, Negative)
               : !Flags.noInfs()
                     ? ConstantFP::getInfinity(Ty, Negative)
                     : ConstantFP::get(Ty, V: APFloat::getLargest(Sem: Semantics, Negative));
  }
  }
}
1458
1459Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) {
1460 assert((!(K == RecurKind::FMin || K == RecurKind::FMax) ||
1461 (FMF.noNaNs() && FMF.noSignedZeros())) &&
1462 "nnan, nsz is expected to be set for FP min/max reduction.");
1463 Intrinsic::ID RdxID = getReductionIntrinsicID(RK: K);
1464 return getReductionIdentity(RdxID, Ty: Tp, Flags: FMF);
1465}
1466
/// Create an unordered (fast) reduction of \p Src for recurrence kind
/// \p RdxKind. Most kinds use a unary vector_reduce_* intrinsic; FP add/mul
/// reductions additionally need an explicit identity start value.
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
                                   RecurKind RdxKind) {
  auto *SrcVecEltTy = cast<VectorType>(Val: Src->getType())->getElementType();
  // Identity element for the element type, respecting the builder's FMF.
  auto getIdentity = [&]() {
    return getRecurrenceIdentity(K: RdxKind, Tp: SrcVecEltTy,
                                 FMF: Builder.getFastMathFlags());
  };
  switch (RdxKind) {
  case RecurKind::AddChainWithSubs:
  case RecurKind::Sub:
  case RecurKind::Add:
  case RecurKind::Mul:
  case RecurKind::And:
  case RecurKind::Or:
  case RecurKind::Xor:
  case RecurKind::SMax:
  case RecurKind::SMin:
  case RecurKind::UMax:
  case RecurKind::UMin:
  case RecurKind::FMax:
  case RecurKind::FMin:
  case RecurKind::FMinNum:
  case RecurKind::FMaxNum:
  case RecurKind::FMinimum:
  case RecurKind::FMaximum:
  case RecurKind::FMinimumNum:
  case RecurKind::FMaximumNum:
    return Builder.CreateUnaryIntrinsic(ID: getReductionIntrinsicID(RK: RdxKind), V: Src);
  case RecurKind::FMulAdd:
  case RecurKind::FAdd:
    return Builder.CreateFAddReduce(Acc: getIdentity(), Src);
  case RecurKind::FMul:
    return Builder.CreateFMulReduce(Acc: getIdentity(), Src);
  default:
    llvm_unreachable("Unhandled opcode");
  }
}
1504
/// Create a vector-predicated (VP) reduction of \p Src with mask \p Mask and
/// explicit vector length \p EVL, seeding it with the recurrence identity.
Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
                                   RecurKind Kind, Value *Mask, Value *EVL) {
  assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
         !RecurrenceDescriptor::isFindRecurrenceKind(Kind) &&
         "AnyOf and FindIV reductions are not supported.");
  Intrinsic::ID Id = getReductionIntrinsicID(RK: Kind);
  // Translate the plain reduction intrinsic into its VP counterpart.
  auto VPID = VPIntrinsic::getForIntrinsic(Id);
  assert(VPReductionIntrinsic::isVPReduction(VPID) &&
         "No VPIntrinsic for this reduction");
  auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
  Value *Iden = getRecurrenceIdentity(K: Kind, Tp: EltTy, FMF: Builder.getFastMathFlags());
  // VP reduction operands: start value, vector, mask, explicit vector length.
  Value *Ops[] = {Iden, Src, Mask, EVL};
  return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
}
1519
/// Create an ordered (sequential) FP add reduction of vector \p Src starting
/// from scalar \p Start. Only FAdd/FMulAdd recurrences are supported.
Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
                                    Value *Src, Value *Start) {
  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
         "Unexpected reduction kind");
  assert(Src->getType()->isVectorTy() && "Expected a vector type");
  assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

  return B.CreateFAddReduce(Acc: Start, Src);
}
1529
/// Create a vector-predicated ordered FP add reduction of \p Src starting at
/// scalar \p Start, with mask \p Mask and explicit vector length \p EVL.
/// Only FAdd/FMulAdd recurrences are supported.
Value *llvm::createOrderedReduction(IRBuilderBase &Builder, RecurKind Kind,
                                    Value *Src, Value *Start, Value *Mask,
                                    Value *EVL) {
  assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
         "Unexpected reduction kind");
  assert(Src->getType()->isVectorTy() && "Expected a vector type");
  assert(!Start->getType()->isVectorTy() && "Expected a scalar type");

  // Translate the fadd reduction intrinsic into its VP counterpart.
  Intrinsic::ID Id = getReductionIntrinsicID(RK: RecurKind::FAdd);
  auto VPID = VPIntrinsic::getForIntrinsic(Id);
  assert(VPReductionIntrinsic::isVPReduction(VPID) &&
         "No VPIntrinsic for this reduction");
  auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
  // VP reduction operands: start value, vector, mask, explicit vector length.
  Value *Ops[] = {Start, Src, Mask, EVL};
  return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
}
1546
/// Set the IR flags of instruction \p I to the intersection of the flags of
/// the instructions in \p VL. If \p OpValue is non-null, its flags seed the
/// intersection and only instructions matching its opcode participate;
/// otherwise VL[0] seeds it and every instruction participates.
/// \p IncludeWrapFlags controls whether nsw/nuw are copied from the seed.
void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
                            bool IncludeWrapFlags) {
  // Nothing to do when the destination is not an instruction.
  auto *VecOp = dyn_cast<Instruction>(Val: I);
  if (!VecOp)
    return;
  auto *Intersection = (OpValue == nullptr) ? dyn_cast<Instruction>(Val: VL[0])
                                            : dyn_cast<Instruction>(Val: OpValue);
  if (!Intersection)
    return;
  const unsigned Opcode = Intersection->getOpcode();
  // Start from the seed's flags, then AND in each participant's flags.
  VecOp->copyIRFlags(V: Intersection, IncludeWrapFlags);
  for (auto *V : VL) {
    auto *Instr = dyn_cast<Instruction>(Val: V);
    if (!Instr)
      continue;
    if (OpValue == nullptr || Opcode == Instr->getOpcode())
      VecOp->andIRFlags(V);
  }
}
1566
1567bool llvm::isKnownNegativeInLoop(const SCEV *S, const Loop *L,
1568 ScalarEvolution &SE) {
1569 const SCEV *Zero = SE.getZero(Ty: S->getType());
1570 return SE.isAvailableAtLoopEntry(S, L) &&
1571 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLT, LHS: S, RHS: Zero);
1572}
1573
1574bool llvm::isKnownNonNegativeInLoop(const SCEV *S, const Loop *L,
1575 ScalarEvolution &SE) {
1576 const SCEV *Zero = SE.getZero(Ty: S->getType());
1577 return SE.isAvailableAtLoopEntry(S, L) &&
1578 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGE, LHS: S, RHS: Zero);
1579}
1580
1581bool llvm::isKnownPositiveInLoop(const SCEV *S, const Loop *L,
1582 ScalarEvolution &SE) {
1583 const SCEV *Zero = SE.getZero(Ty: S->getType());
1584 return SE.isAvailableAtLoopEntry(S, L) &&
1585 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGT, LHS: S, RHS: Zero);
1586}
1587
1588bool llvm::isKnownNonPositiveInLoop(const SCEV *S, const Loop *L,
1589 ScalarEvolution &SE) {
1590 const SCEV *Zero = SE.getZero(Ty: S->getType());
1591 return SE.isAvailableAtLoopEntry(S, L) &&
1592 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLE, LHS: S, RHS: Zero);
1593}
1594
1595bool llvm::cannotBeMinInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1596 bool Signed) {
1597 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1598 APInt Min = Signed ? APInt::getSignedMinValue(numBits: BitWidth) :
1599 APInt::getMinValue(numBits: BitWidth);
1600 auto Predicate = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
1601 return SE.isAvailableAtLoopEntry(S, L) &&
1602 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1603 RHS: SE.getConstant(Val: Min));
1604}
1605
1606bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1607 bool Signed) {
1608 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1609 APInt Max = Signed ? APInt::getSignedMaxValue(numBits: BitWidth) :
1610 APInt::getMaxValue(numBits: BitWidth);
1611 auto Predicate = Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
1612 return SE.isAvailableAtLoopEntry(S, L) &&
1613 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1614 RHS: SE.getConstant(Val: Max));
1615}
1616
1617//===----------------------------------------------------------------------===//
1618// rewriteLoopExitValues - Optimize IV users outside the loop.
1619// As a side effect, reduces the amount of IV processing within the loop.
1620//===----------------------------------------------------------------------===//
1621
1622static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
1623 SmallPtrSet<const Instruction *, 8> Visited;
1624 SmallVector<const Instruction *, 8> WorkList;
1625 Visited.insert(Ptr: I);
1626 WorkList.push_back(Elt: I);
1627 while (!WorkList.empty()) {
1628 const Instruction *Curr = WorkList.pop_back_val();
1629 // This use is outside the loop, nothing to do.
1630 if (!L->contains(Inst: Curr))
1631 continue;
1632 // Do we assume it is a "hard" use which will not be eliminated easily?
1633 if (Curr->mayHaveSideEffects())
1634 return true;
1635 // Otherwise, add all its users to worklist.
1636 for (const auto *U : Curr->users()) {
1637 auto *UI = cast<Instruction>(Val: U);
1638 if (Visited.insert(Ptr: UI).second)
1639 WorkList.push_back(Elt: UI);
1640 }
1641 }
1642 return false;
1643}
1644
// Collect information about PHI nodes which can be transformed in
// rewriteLoopExitValues.
//
// Each record names one incoming value of one exit-block PHI, together with
// the loop-invariant SCEV it can be rewritten to, where that SCEV should be
// expanded, and whether the expansion was judged high-cost.
struct RewritePhi {
  PHINode *PN;               // For which PHI node is this replacement?
  unsigned Ith;              // For which incoming value?
  const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting.
  Instruction *ExpansionPoint; // Where we'd like to expand that SCEV?
  bool HighCost;               // Is this expansion a high-cost?

  RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
             bool H)
      : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
        HighCost(H) {}
};
1659
// Check whether it is possible to delete the loop after rewriting exit
// value. If it is possible, ignore ReplaceExitValue and do rewriting
// aggressively.
//
// Requires: a preheader, exactly one exiting block and one unique exit
// block, every exit phi fed either by a value scheduled for rewriting or by
// a value with only loop-invariant operands, and no side-effecting
// instruction anywhere in the loop body.
static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) {
  BasicBlock *Preheader = L->getLoopPreheader();
  // If there is no preheader, the loop will not be deleted.
  if (!Preheader)
    return false;

  // In LoopDeletion pass Loop can be deleted when ExitingBlocks.size() > 1.
  // We obviate multiple ExitingBlocks case for simplicity.
  // TODO: If we see testcase with multiple ExitingBlocks can be deleted
  // after exit value rewriting, we can enhance the logic here.
  SmallVector<BasicBlock *, 4> ExitingBlocks;
  L->getExitingBlocks(ExitingBlocks);
  SmallVector<BasicBlock *, 8> ExitBlocks;
  L->getUniqueExitBlocks(ExitBlocks);
  if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1)
    return false;

  // Walk the leading phis of the unique exit block; each merges one value
  // leaving the loop through the unique exiting block.
  BasicBlock *ExitBlock = ExitBlocks[0];
  BasicBlock::iterator BI = ExitBlock->begin();
  while (PHINode *P = dyn_cast<PHINode>(Val&: BI)) {
    Value *Incoming = P->getIncomingValueForBlock(BB: ExitingBlocks[0]);

    // If the Incoming value of P is found in RewritePhiSet, we know it
    // could be rewritten to use a loop invariant value in transformation
    // phase later. Skip it in the loop invariant check below.
    bool found = false;
    for (const RewritePhi &Phi : RewritePhiSet) {
      unsigned i = Phi.Ith;
      if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
        found = true;
        break;
      }
    }

    // A value not scheduled for rewriting must already be computable outside
    // the loop, i.e. all its operands must be loop-invariant.
    Instruction *I;
    if (!found && (I = dyn_cast<Instruction>(Val: Incoming)))
      if (!L->hasLoopInvariantOperands(I))
        return false;

    ++BI;
  }

  // Any side effect inside the loop makes deletion unsound: removing the
  // loop would drop that effect.
  for (auto *BB : L->blocks())
    if (llvm::any_of(Range&: *BB, P: [](Instruction &I) {
          return I.mayHaveSideEffects();
        }))
      return false;

  return true;
}
1713
1714/// Checks if it is safe to call InductionDescriptor::isInductionPHI for \p Phi,
1715/// and returns true if this Phi is an induction phi in the loop. When
1716/// isInductionPHI returns true, \p ID will be also be set by isInductionPHI.
1717static bool checkIsIndPhi(PHINode *Phi, Loop *L, ScalarEvolution *SE,
1718 InductionDescriptor &ID) {
1719 if (!Phi)
1720 return false;
1721 if (!L->getLoopPreheader())
1722 return false;
1723 if (Phi->getParent() != L->getHeader())
1724 return false;
1725 return InductionDescriptor::isInductionPHI(Phi, L, SE, D&: ID);
1726}
1727
/// Rewrite uses of loop-varying values outside of \p L: for each exit-phi
/// incoming value that ScalarEvolution can evaluate to a loop-invariant
/// expression, expand that expression and substitute it. Returns the number
/// of incoming values replaced; instructions that become trivially dead are
/// appended to \p DeadInsts (deletion is deferred to the caller).
int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
                                ScalarEvolution *SE,
                                const TargetTransformInfo *TTI,
                                SCEVExpander &Rewriter, DominatorTree *DT,
                                ReplaceExitVal ReplaceExitValue,
                                SmallVector<WeakTrackingVH, 16> &DeadInsts) {
  // Check a pre-condition.
  assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
         "Indvars did not preserve LCSSA!");

  SmallVector<BasicBlock*, 8> ExitBlocks;
  L->getUniqueExitBlocks(ExitBlocks);

  SmallVector<RewritePhi, 8> RewritePhiSet;
  // Find all values that are computed inside the loop, but used outside of it.
  // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan
  // the exit blocks of the loop to find them.
  for (BasicBlock *ExitBB : ExitBlocks) {
    // If there are no PHI nodes in this exit block, then no values defined
    // inside the loop are used on this path, skip it.
    PHINode *PN = dyn_cast<PHINode>(Val: ExitBB->begin());
    if (!PN) continue;

    // Taken from the first phi; phis in one block share the block's
    // predecessors, so this count applies to every phi walked below.
    unsigned NumPreds = PN->getNumIncomingValues();

    // Iterate over all of the PHI nodes.
    BasicBlock::iterator BBI = ExitBB->begin();
    while ((PN = dyn_cast<PHINode>(Val: BBI++))) {
      if (PN->use_empty())
        continue; // dead use, don't replace it

      if (!SE->isSCEVable(Ty: PN->getType()))
        continue;

      // Iterate over all of the values in all the PHI nodes.
      for (unsigned i = 0; i != NumPreds; ++i) {
        // If the value being merged in is not integer or is not defined
        // in the loop, skip it.
        Value *InVal = PN->getIncomingValue(i);
        if (!isa<Instruction>(Val: InVal))
          continue;

        // If this pred is for a subloop, not L itself, skip it.
        if (LI->getLoopFor(BB: PN->getIncomingBlock(i)) != L)
          continue; // The Block is in a subloop, skip it.

        // Check that InVal is defined in the loop.
        Instruction *Inst = cast<Instruction>(Val: InVal);
        if (!L->contains(Inst))
          continue;

        // Find exit values which are induction variables in the loop, and are
        // unused in the loop, with the only use being the exit block PhiNode,
        // and the induction variable update binary operator.
        // The exit value can be replaced with the final value when it is cheap
        // to do so.
        if (ReplaceExitValue == UnusedIndVarInLoop) {
          InductionDescriptor ID;
          PHINode *IndPhi = dyn_cast<PHINode>(Val: Inst);
          if (IndPhi) {
            if (!checkIsIndPhi(Phi: IndPhi, L, SE, ID))
              continue;
            // This is an induction PHI. Check that the only users are PHI
            // nodes, and induction variable update binary operators.
            if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
                  if (!isa<PHINode>(Val: U) && !isa<BinaryOperator>(Val: U))
                    return true;
                  BinaryOperator *B = dyn_cast<BinaryOperator>(Val: U);
                  if (B && B != ID.getInductionBinOp())
                    return true;
                  return false;
                }))
              continue;
          } else {
            // If it is not an induction phi, it must be an induction update
            // binary operator with an induction phi user.
            BinaryOperator *B = dyn_cast<BinaryOperator>(Val: Inst);
            if (!B)
              continue;
            if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
                  PHINode *Phi = dyn_cast<PHINode>(Val: U);
                  if (Phi != PN && !checkIsIndPhi(Phi, L, SE, ID))
                    return true;
                  return false;
                }))
              continue;
            if (B != ID.getInductionBinOp())
              continue;
          }
        }

        // Okay, this instruction has a user outside of the current loop
        // and varies predictably *inside* the loop. Evaluate the value it
        // contains when the loop exits, if possible. We prefer to start with
        // expressions which are true for all exits (so as to maximize
        // expression reuse by the SCEVExpander), but resort to per-exit
        // evaluation if that fails.
        const SCEV *ExitValue = SE->getSCEVAtScope(V: Inst, L: L->getParentLoop());
        if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
            !SE->isLoopInvariant(S: ExitValue, L) ||
            !Rewriter.isSafeToExpand(S: ExitValue)) {
          // TODO: This should probably be sunk into SCEV in some way; maybe a
          // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for
          // most SCEV expressions and other recurrence types (e.g. shift
          // recurrences). Is there existing code we can reuse?
          const SCEV *ExitCount = SE->getExitCount(L, ExitingBlock: PN->getIncomingBlock(i));
          if (isa<SCEVCouldNotCompute>(Val: ExitCount))
            continue;
          // Per-exit fallback: evaluate the add-rec at this exit's trip
          // count instead of asking for a value common to all exits.
          if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: Inst)))
            if (AddRec->getLoop() == L)
              ExitValue = AddRec->evaluateAtIteration(It: ExitCount, SE&: *SE);
          if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
              !SE->isLoopInvariant(S: ExitValue, L) ||
              !Rewriter.isSafeToExpand(S: ExitValue))
            continue;
        }

        // Computing the value outside of the loop brings no benefit if it is
        // definitely used inside the loop in a way which can not be optimized
        // away. Avoid doing so unless we know we have a value which computes
        // the ExitValue already. TODO: This should be merged into SCEV
        // expander to leverage its knowledge of existing expressions.
        if (ReplaceExitValue != AlwaysRepl && !isa<SCEVConstant>(Val: ExitValue) &&
            !isa<SCEVUnknown>(Val: ExitValue) && hasHardUserWithinLoop(L, I: Inst))
          continue;

        // Check if expansions of this SCEV would count as being high cost.
        bool HighCost = Rewriter.isHighCostExpansion(
            Exprs: ExitValue, L, Budget: SCEVCheapExpansionBudget, TTI, At: Inst);

        // Note that we must not perform expansions until after
        // we query *all* the costs, because if we perform temporary expansion
        // inbetween, one that we might not intend to keep, said expansion
        // *may* affect cost calculation of the next SCEV's we'll query,
        // and next SCEV may errneously get smaller cost.

        // Collect all the candidate PHINodes to be rewritten.
        // Phi/landingpad instructions must stay grouped at the block head,
        // so expand at the first insertion point after them in that case.
        Instruction *InsertPt =
            (isa<PHINode>(Val: Inst) || isa<LandingPadInst>(Val: Inst)) ?
            &*Inst->getParent()->getFirstInsertionPt() : Inst;
        RewritePhiSet.emplace_back(Args&: PN, Args&: i, Args&: ExitValue, Args&: InsertPt, Args&: HighCost);
      }
    }
  }

  // TODO: evaluate whether it is beneficial to change how we calculate
  // high-cost: if we have SCEV 'A' which we know we will expand, should we
  // calculate the cost of other SCEV's after expanding SCEV 'A', thus
  // potentially giving cost bonus to those other SCEV's?

  bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
  int NumReplaced = 0;

  // Transformation.
  for (const RewritePhi &Phi : RewritePhiSet) {
    PHINode *PN = Phi.PN;

    // Only do the rewrite when the ExitValue can be expanded cheaply.
    // If LoopCanBeDel is true, rewrite exit value aggressively.
    if ((ReplaceExitValue == OnlyCheapRepl ||
         ReplaceExitValue == UnusedIndVarInLoop) &&
        !LoopCanBeDel && Phi.HighCost)
      continue;

    Value *ExitVal = Rewriter.expandCodeFor(
        SH: Phi.ExpansionSCEV, Ty: Phi.PN->getType(), I: Phi.ExpansionPoint);

    LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " << *ExitVal
                      << '\n'
                      << " LoopVal = " << *(Phi.ExpansionPoint) << "\n");

#ifndef NDEBUG
    // If we reuse an instruction from a loop which is neither L nor one of
    // its containing loops, we end up breaking LCSSA form for this loop by
    // creating a new use of its instruction.
    if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
      if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
        if (EVL != L)
          assert(EVL->contains(L) && "LCSSA breach detected!");
#endif

    NumReplaced++;
    Instruction *Inst = cast<Instruction>(Val: PN->getIncomingValue(i: Phi.Ith));
    PN->setIncomingValue(i: Phi.Ith, V: ExitVal);
    // It's necessary to tell ScalarEvolution about this explicitly so that
    // it can walk the def-use list and forget all SCEVs, as it may not be
    // watching the PHI itself. Once the new exit value is in place, there
    // may not be a def-use connection between the loop and every instruction
    // which got a SCEVAddRecExpr for that loop.
    SE->forgetValue(V: PN);

    // If this instruction is dead now, delete it. Don't do it now to avoid
    // invalidating iterators.
    if (isInstructionTriviallyDead(I: Inst, TLI))
      DeadInsts.push_back(Elt: Inst);

    // Replace PN with ExitVal if that is legal and does not break LCSSA.
    if (PN->getNumIncomingValues() == 1 &&
        LI->replacementPreservesLCSSAForm(From: PN, To: ExitVal)) {
      PN->replaceAllUsesWith(V: ExitVal);
      PN->eraseFromParent();
    }
  }

  // The insertion point instruction may have been deleted; clear it out
  // so that the rewriter doesn't trip over it later.
  Rewriter.clearInsertPoint();
  return NumReplaced;
}
1937
/// Utility that implements appending of loops onto a worklist.
/// Loops are added in preorder (analogous for reverse postorder for trees),
/// and the worklist is processed LIFO.
///
/// \param Loops    Range of root loops to enqueue, walked in the order given
///                 (callers pass a reversed range when defs-before-uses
///                 ordering matters, since the worklist pops LIFO).
/// \param Worklist Priority worklist that receives each loop nest in
///                 preorder.
template <typename RangeT>
void llvm::appendReversedLoopsToWorklist(
    RangeT &&Loops, SmallPriorityWorklist<Loop *, 4> &Worklist) {
  // We use an internal worklist to build up the preorder traversal without
  // recursion.
  SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist;

  // We walk the initial sequence of loops in reverse because we generally want
  // to visit defs before uses and the worklist is LIFO.
  for (Loop *RootL : Loops) {
    assert(PreOrderLoops.empty() && "Must start with an empty preorder walk.");
    assert(PreOrderWorklist.empty() &&
           "Must start with an empty preorder walk worklist.");
    PreOrderWorklist.push_back(Elt: RootL);
    do {
      // Pop a loop, queue its children, record it: yields a preorder
      // listing of the nest rooted at RootL.
      Loop *L = PreOrderWorklist.pop_back_val();
      PreOrderWorklist.append(in_start: L->begin(), in_end: L->end());
      PreOrderLoops.push_back(Elt: L);
    } while (!PreOrderWorklist.empty());

    // Insert the whole nest at once, then reuse the scratch vector for the
    // next root.
    Worklist.insert(Input: std::move(PreOrderLoops));
    PreOrderLoops.clear();
  }
}
1965
/// Append each loop nest in \p Loops to \p Worklist in preorder, reversing
/// the range first so the LIFO worklist pops the roots in original order.
template <typename RangeT>
void llvm::appendLoopsToWorklist(RangeT &&Loops,
                                 SmallPriorityWorklist<Loop *, 4> &Worklist) {
  appendReversedLoopsToWorklist(reverse(Loops), Worklist);
}
1971
// Explicit instantiations of the worklist-append template for the two range
// types callers use (a list of loops, and a single loop nest).
template LLVM_EXPORT_TEMPLATE void
llvm::appendLoopsToWorklist<ArrayRef<Loop *> &>(
    ArrayRef<Loop *> &Loops, SmallPriorityWorklist<Loop *, 4> &Worklist);

template LLVM_EXPORT_TEMPLATE void
llvm::appendLoopsToWorklist<Loop &>(Loop &L,
                                    SmallPriorityWorklist<Loop *, 4> &Worklist);
1979
/// Append all loop nests of the function described by \p LI to \p Worklist
/// in preorder, iterating LI's top-level loops directly (no extra reverse).
void llvm::appendLoopsToWorklist(LoopInfo &LI,
                                 SmallPriorityWorklist<Loop *, 4> &Worklist) {
  appendReversedLoopsToWorklist(Loops&: LI, Worklist);
}
1984
1985Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
1986 LoopInfo *LI, LPPassManager *LPM) {
1987 Loop &New = *LI->AllocateLoop();
1988 if (PL)
1989 PL->addChildLoop(NewChild: &New);
1990 else
1991 LI->addTopLevelLoop(New: &New);
1992
1993 if (LPM)
1994 LPM->addLoop(L&: New);
1995
1996 // Add all of the blocks in L to the new loop.
1997 for (BasicBlock *BB : L->blocks())
1998 if (LI->getLoopFor(BB) == L)
1999 New.addBasicBlockToLoop(NewBB: cast<BasicBlock>(Val&: VM[BB]), LI&: *LI);
2000
2001 // Add all of the subloops to the new loop.
2002 for (Loop *I : *L)
2003 cloneLoop(L: I, PL: &New, VM, LI, LPM);
2004
2005 return &New;
2006}
2007
/// IR Values for the lower and upper bounds of a pointer evolution. We
/// need to use value-handles because SCEV expansion can invalidate previously
/// expanded values. Thus expansion of a pointer can invalidate the bounds for
/// a previous one.
struct PointerBounds {
  TrackingVH<Value> Start; // First accessed byte of the range.
  TrackingVH<Value> End;   // Last accessed byte, plus one (exclusive bound).
  Value *StrideToCheck;    // If non-null, a stride whose positivity must be
                           // verified at runtime for the bounds to be valid.
};
2017
/// Expand code for the lower and upper bound of the pointer group \p CG
/// in \p TheLoop. \return the values for the bounds.
///
/// When \p HoistRuntimeChecks is set and the bounds vary in an outer loop,
/// the range may be widened to cover all outer-loop iterations so the check
/// can be hoisted; in that case a stride value may also be returned that
/// must be proven non-negative at runtime.
static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
                                  Loop *TheLoop, Instruction *Loc,
                                  SCEVExpander &Exp, bool HoistRuntimeChecks) {
  LLVMContext &Ctx = Loc->getContext();
  Type *PtrArithTy = PointerType::get(C&: Ctx, AddressSpace: CG->AddressSpace);

  Value *Start = nullptr, *End = nullptr;
  LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
  const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr;

  // If the Low and High values are themselves loop-variant, then we may want
  // to expand the range to include those covered by the outer loop as well.
  // There is a trade-off here with the advantage being that creating checks
  // using the expanded range permits the runtime memory checks to be hoisted
  // out of the outer loop. This reduces the cost of entering the inner loop,
  // which can be significant for low trip counts. The disadvantage is that
  // there is a chance we may now never enter the vectorized inner loop,
  // whereas using a restricted range check could have allowed us to enter at
  // least once. This is why the behaviour is not currently the default and is
  // controlled by the parameter 'HoistRuntimeChecks'.
  if (HoistRuntimeChecks && TheLoop->getParentLoop() &&
      isa<SCEVAddRecExpr>(Val: High) && isa<SCEVAddRecExpr>(Val: Low)) {
    auto *HighAR = cast<SCEVAddRecExpr>(Val: High);
    auto *LowAR = cast<SCEVAddRecExpr>(Val: Low);
    const Loop *OuterLoop = TheLoop->getParentLoop();
    ScalarEvolution &SE = *Exp.getSE();
    const SCEV *Recur = LowAR->getStepRecurrence(SE);
    // Widening only applies when both bounds advance by the same step of
    // the immediate outer loop.
    if (Recur == HighAR->getStepRecurrence(SE) &&
        HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
      BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
      const SCEV *OuterExitCount = SE.getExitCount(L: OuterLoop, ExitingBlock: OuterLoopLatch);
      if (!isa<SCEVCouldNotCompute>(Val: OuterExitCount) &&
          OuterExitCount->getType()->isIntegerTy()) {
        // New upper bound: the original High evaluated at the outer loop's
        // final iteration; new lower bound: Low at the first iteration.
        const SCEV *NewHigh =
            cast<SCEVAddRecExpr>(Val: High)->evaluateAtIteration(It: OuterExitCount, SE);
        if (!isa<SCEVCouldNotCompute>(Val: NewHigh)) {
          LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
                               "outer loop in order to permit hoisting\n");
          High = NewHigh;
          Low = cast<SCEVAddRecExpr>(Val: Low)->getStart();
          // If there is a possibility that the stride is negative then we have
          // to generate extra checks to ensure the stride is positive.
          if (!SE.isKnownNonNegative(
                  S: SE.applyLoopGuards(Expr: Recur, L: HighAR->getLoop()))) {
            Stride = Recur;
            LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
                                 "positive: "
                              << *Stride << '\n');
          }
        }
      }
    }
  }

  Start = Exp.expandCodeFor(SH: Low, Ty: PtrArithTy, I: Loc);
  End = Exp.expandCodeFor(SH: High, Ty: PtrArithTy, I: Loc);
  // Freeze the bounds when required so later uses see a single consistent
  // value even if the underlying computation could yield poison/undef.
  if (CG->NeedsFreeze) {
    IRBuilder<> Builder(Loc);
    Start = Builder.CreateFreeze(V: Start, Name: Start->getName() + ".fr");
    End = Builder.CreateFreeze(V: End, Name: End->getName() + ".fr");
  }
  Value *StrideVal =
      Stride ? Exp.expandCodeFor(SH: Stride, Ty: Stride->getType(), I: Loc) : nullptr;
  LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n");
  return {.Start: Start, .End: End, .StrideToCheck: StrideVal};
}
2086
2087/// Turns a collection of checks into a collection of expanded upper and
2088/// lower bounds for both pointers in the check.
2089static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
2090expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
2091 Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) {
2092 SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
2093
2094 // Here we're relying on the SCEV Expander's cache to only emit code for the
2095 // same bounds once.
2096 transform(Range: PointerChecks, d_first: std::back_inserter(x&: ChecksWithBounds),
2097 F: [&](const RuntimePointerCheck &Check) {
2098 PointerBounds First = expandBounds(CG: Check.first, TheLoop: L, Loc, Exp,
2099 HoistRuntimeChecks),
2100 Second = expandBounds(CG: Check.second, TheLoop: L, Loc, Exp,
2101 HoistRuntimeChecks);
2102 return std::make_pair(x&: First, y&: Second);
2103 });
2104
2105 return ChecksWithBounds;
2106}
2107
/// Emit, at \p Loc, IR that evaluates all pairwise memory overlap checks in
/// \p PointerChecks and ORs them into a single boolean Value which is true
/// when any pair of pointer ranges conflicts (or a widened-range stride is
/// negative). Returns nullptr when there were no checks to emit.
Value *llvm::addRuntimeChecks(
    Instruction *Loc, Loop *TheLoop,
    const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
    SCEVExpander &Exp, bool HoistRuntimeChecks) {
  // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
  // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
  auto ExpandedChecks =
      expandBounds(PointerChecks, L: TheLoop, Loc, Exp, HoistRuntimeChecks);

  LLVMContext &Ctx = Loc->getContext();
  // Use an InstSimplifyFolder so the emitted comparisons can constant-fold.
  IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
  ChkBuilder.SetInsertPoint(Loc);
  // Our instructions might fold to a constant.
  Value *MemoryRuntimeCheck = nullptr;

  for (const auto &[A, B] : ExpandedChecks) {
    // Check if two pointers (A and B) conflict where conflict is computed as:
    // start(A) <= end(B) && start(B) <= end(A)

    assert((A.Start->getType()->getPointerAddressSpace() ==
            B.End->getType()->getPointerAddressSpace()) &&
           (B.Start->getType()->getPointerAddressSpace() ==
            A.End->getType()->getPointerAddressSpace()) &&
           "Trying to bounds check pointers with different address spaces");

    // [A|B].Start points to the first accessed byte under base [A|B].
    // [A|B].End points to the last accessed byte, plus one.
    // There is no conflict when the intervals are disjoint:
    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
    //
    // bound0 = (B.Start < A.End)
    // bound1 = (A.Start < B.End)
    //  IsConflict = bound0 & bound1
    Value *Cmp0 = ChkBuilder.CreateICmpULT(LHS: A.Start, RHS: B.End, Name: "bound0");
    Value *Cmp1 = ChkBuilder.CreateICmpULT(LHS: B.Start, RHS: A.End, Name: "bound1");
    Value *IsConflict = ChkBuilder.CreateAnd(LHS: Cmp0, RHS: Cmp1, Name: "found.conflict");
    // A negative stride on either side invalidates the widened range check,
    // so treat it as a conflict (fall back to the scalar loop).
    if (A.StrideToCheck) {
      Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
          LHS: A.StrideToCheck, RHS: ConstantInt::get(Ty: A.StrideToCheck->getType(), V: 0),
          Name: "stride.check");
      IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
    }
    if (B.StrideToCheck) {
      Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
          LHS: B.StrideToCheck, RHS: ConstantInt::get(Ty: B.StrideToCheck->getType(), V: 0),
          Name: "stride.check");
      IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
    }
    // OR this pair's result into the running reduction over all pairs.
    if (MemoryRuntimeCheck) {
      IsConflict =
          ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
    }
    MemoryRuntimeCheck = IsConflict;
  }

  // Expansion may have emitted bounds that folded away; clean them up.
  Exp.eraseDeadInstructions(Root: MemoryRuntimeCheck);
  return MemoryRuntimeCheck;
}
2166
namespace {
/// Rewriter to replace SCEVPtrToIntExpr with SCEVPtrToAddrExpr when the result
/// type matches the pointer address type. This allows expressions mixing
/// ptrtoint and ptrtoaddr to simplify properly.
struct SCEVPtrToAddrRewriter : SCEVRewriteVisitor<SCEVPtrToAddrRewriter> {
  const DataLayout &DL; // Used to query a pointer type's address type.
  SCEVPtrToAddrRewriter(ScalarEvolution &SE, const DataLayout &DL)
      : SCEVRewriteVisitor(SE), DL(DL) {}

  const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *E) {
    // Rewrite the operand first so nested expressions are handled too.
    const SCEV *Op = visit(S: E->getOperand());
    // When the integer result type equals the pointer's address type, the
    // canonical form is ptrtoaddr.
    if (E->getType() == DL.getAddressType(PtrTy: E->getOperand()->getType()))
      return SE.getPtrToAddrExpr(Op);
    // Otherwise keep ptrtoint, rebuilding it only if the operand changed.
    return Op == E->getOperand() ? E : SE.getPtrToIntExpr(Op, Ty: E->getType());
  }
};
} // namespace
2184
/// Emit, at \p Loc, the "diff" form of runtime checks for \p Checks: for each
/// src/sink pair, test (SinkStart - SrcStart) <u VF * IC * AccessSize, and OR
/// all results together. \p GetVF supplies the VF value for a given scalar
/// bit width; \p IC is the interleave count. Returns the combined boolean, or
/// nullptr when no check was emitted.
Value *llvm::addDiffRuntimeChecks(
    Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
    function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) {

  LLVMContext &Ctx = Loc->getContext();
  // Use an InstSimplifyFolder so the emitted comparisons can constant-fold.
  IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
  ChkBuilder.SetInsertPoint(Loc);
  // Our instructions might fold to a constant.
  Value *MemoryRuntimeCheck = nullptr;

  auto &SE = *Expander.getSE();
  const DataLayout &DL = Loc->getDataLayout();
  // Canonicalize ptrtoint -> ptrtoaddr so mixed expressions simplify when
  // computing the pointer difference below.
  SCEVPtrToAddrRewriter Rewriter(SE, DL);
  // Map to keep track of created compares, The key is the pair of operands for
  // the compare, to allow detecting and re-using redundant compares.
  DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
  for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
    Type *Ty = SinkStart->getType();
    // Compute VF * IC * AccessSize.
    auto *VFTimesICTimesSize =
        ChkBuilder.CreateMul(LHS: GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
                             RHS: ConstantInt::get(Ty, V: IC * AccessSize));
    const SCEV *SinkStartRewritten = Rewriter.visit(S: SinkStart);
    const SCEV *SrcStartRewritten = Rewriter.visit(S: SrcStart);
    Value *Diff = Expander.expandCodeFor(
        SH: SE.getMinusSCEV(LHS: SinkStartRewritten, RHS: SrcStartRewritten), Ty, I: Loc);

    // Check if the same compare has already been created earlier. In that case,
    // there is no need to check it again.
    Value *IsConflict = SeenCompares.lookup(Val: {Diff, VFTimesICTimesSize});
    if (IsConflict)
      continue;

    IsConflict =
        ChkBuilder.CreateICmpULT(LHS: Diff, RHS: VFTimesICTimesSize, Name: "diff.check");
    SeenCompares.insert(KV: {{Diff, VFTimesICTimesSize}, IsConflict});
    // Freeze when required so a poison/undef difference yields one fixed
    // outcome for all later uses of the check.
    if (NeedsFreeze)
      IsConflict =
          ChkBuilder.CreateFreeze(V: IsConflict, Name: IsConflict->getName() + ".fr");
    // OR this pair's result into the running reduction over all pairs.
    if (MemoryRuntimeCheck) {
      IsConflict =
          ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
    }
    MemoryRuntimeCheck = IsConflict;
  }

  // Clean up any expanded expressions that folded away.
  Expander.eraseDeadInstructions(Root: MemoryRuntimeCheck);
  return MemoryRuntimeCheck;
}
2234
2235std::optional<IVConditionInfo>
2236llvm::hasPartialIVCondition(const Loop &L, unsigned MSSAThreshold,
2237 const MemorySSA &MSSA, AAResults &AA) {
2238 auto *TI = dyn_cast<CondBrInst>(Val: L.getHeader()->getTerminator());
2239 if (!TI)
2240 return {};
2241
2242 auto *CondI = dyn_cast<Instruction>(Val: TI->getCondition());
2243 // The case with the condition outside the loop should already be handled
2244 // earlier.
2245 // Allow CmpInst and TruncInsts as they may be users of load instructions
2246 // and have potential for partial unswitching
2247 if (!CondI || !isa<CmpInst, TruncInst>(Val: CondI) || !L.contains(Inst: CondI))
2248 return {};
2249
2250 SmallVector<Instruction *> InstToDuplicate;
2251 InstToDuplicate.push_back(Elt: CondI);
2252
2253 SmallVector<Value *, 4> WorkList;
2254 WorkList.append(in_start: CondI->op_begin(), in_end: CondI->op_end());
2255
2256 SmallVector<MemoryAccess *, 4> AccessesToCheck;
2257 SmallVector<MemoryLocation, 4> AccessedLocs;
2258 while (!WorkList.empty()) {
2259 Instruction *I = dyn_cast<Instruction>(Val: WorkList.pop_back_val());
2260 if (!I || !L.contains(Inst: I))
2261 continue;
2262
2263 // TODO: support additional instructions.
2264 if (!isa<LoadInst>(Val: I) && !isa<GetElementPtrInst>(Val: I))
2265 return {};
2266
2267 // Do not duplicate volatile and atomic loads.
2268 if (auto *LI = dyn_cast<LoadInst>(Val: I))
2269 if (LI->isVolatile() || LI->isAtomic())
2270 return {};
2271
2272 InstToDuplicate.push_back(Elt: I);
2273 if (MemoryAccess *MA = MSSA.getMemoryAccess(I)) {
2274 if (auto *MemUse = dyn_cast_or_null<MemoryUse>(Val: MA)) {
2275 // Queue the defining access to check for alias checks.
2276 AccessesToCheck.push_back(Elt: MemUse->getDefiningAccess());
2277 AccessedLocs.push_back(Elt: MemoryLocation::get(Inst: I));
2278 } else {
2279 // MemoryDefs may clobber the location or may be atomic memory
2280 // operations. Bail out.
2281 return {};
2282 }
2283 }
2284 WorkList.append(in_start: I->op_begin(), in_end: I->op_end());
2285 }
2286
2287 if (InstToDuplicate.empty())
2288 return {};
2289
2290 SmallVector<BasicBlock *, 4> ExitingBlocks;
2291 L.getExitingBlocks(ExitingBlocks);
2292 auto HasNoClobbersOnPath =
2293 [&L, &AA, &AccessedLocs, &ExitingBlocks, &InstToDuplicate,
2294 MSSAThreshold](BasicBlock *Succ, BasicBlock *Header,
2295 SmallVector<MemoryAccess *, 4> AccessesToCheck)
2296 -> std::optional<IVConditionInfo> {
2297 IVConditionInfo Info;
2298 // First, collect all blocks in the loop that are on a patch from Succ
2299 // to the header.
2300 SmallVector<BasicBlock *, 4> WorkList;
2301 WorkList.push_back(Elt: Succ);
2302 WorkList.push_back(Elt: Header);
2303 SmallPtrSet<BasicBlock *, 4> Seen;
2304 Seen.insert(Ptr: Header);
2305 Info.PathIsNoop &=
2306 all_of(Range&: *Header, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2307
2308 while (!WorkList.empty()) {
2309 BasicBlock *Current = WorkList.pop_back_val();
2310 if (!L.contains(BB: Current))
2311 continue;
2312 const auto &SeenIns = Seen.insert(Ptr: Current);
2313 if (!SeenIns.second)
2314 continue;
2315
2316 Info.PathIsNoop &= all_of(
2317 Range&: *Current, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2318 WorkList.append(in_start: succ_begin(BB: Current), in_end: succ_end(BB: Current));
2319 }
2320
2321 // Require at least 2 blocks on a path through the loop. This skips
2322 // paths that directly exit the loop.
2323 if (Seen.size() < 2)
2324 return {};
2325
2326 // Next, check if there are any MemoryDefs that are on the path through
2327 // the loop (in the Seen set) and they may-alias any of the locations in
2328 // AccessedLocs. If that is the case, they may modify the condition and
2329 // partial unswitching is not possible.
2330 SmallPtrSet<MemoryAccess *, 4> SeenAccesses;
2331 while (!AccessesToCheck.empty()) {
2332 MemoryAccess *Current = AccessesToCheck.pop_back_val();
2333 auto SeenI = SeenAccesses.insert(Ptr: Current);
2334 if (!SeenI.second || !Seen.contains(Ptr: Current->getBlock()))
2335 continue;
2336
2337 // Bail out if exceeded the threshold.
2338 if (SeenAccesses.size() >= MSSAThreshold)
2339 return {};
2340
2341 // MemoryUse are read-only accesses.
2342 if (isa<MemoryUse>(Val: Current))
2343 continue;
2344
2345 // For a MemoryDef, check if is aliases any of the location feeding
2346 // the original condition.
2347 if (auto *CurrentDef = dyn_cast<MemoryDef>(Val: Current)) {
2348 if (any_of(Range&: AccessedLocs, P: [&AA, CurrentDef](MemoryLocation &Loc) {
2349 return isModSet(
2350 MRI: AA.getModRefInfo(I: CurrentDef->getMemoryInst(), OptLoc: Loc));
2351 }))
2352 return {};
2353 }
2354
2355 for (Use &U : Current->uses())
2356 AccessesToCheck.push_back(Elt: cast<MemoryAccess>(Val: U.getUser()));
2357 }
2358
2359 // We could also allow loops with known trip counts without mustprogress,
2360 // but ScalarEvolution may not be available.
2361 Info.PathIsNoop &= isMustProgress(L: &L);
2362
2363 // If the path is considered a no-op so far, check if it reaches a
2364 // single exit block without any phis. This ensures no values from the
2365 // loop are used outside of the loop.
2366 if (Info.PathIsNoop) {
2367 for (auto *Exiting : ExitingBlocks) {
2368 if (!Seen.contains(Ptr: Exiting))
2369 continue;
2370 for (auto *Succ : successors(BB: Exiting)) {
2371 if (L.contains(BB: Succ))
2372 continue;
2373
2374 Info.PathIsNoop &= Succ->phis().empty() &&
2375 (!Info.ExitForPath || Info.ExitForPath == Succ);
2376 if (!Info.PathIsNoop)
2377 break;
2378 assert((!Info.ExitForPath || Info.ExitForPath == Succ) &&
2379 "cannot have multiple exit blocks");
2380 Info.ExitForPath = Succ;
2381 }
2382 }
2383 }
2384 if (!Info.ExitForPath)
2385 Info.PathIsNoop = false;
2386
2387 Info.InstToDuplicate = std::move(InstToDuplicate);
2388 return Info;
2389 };
2390
2391 // If we branch to the same successor, partial unswitching will not be
2392 // beneficial.
2393 if (TI->getSuccessor(i: 0) == TI->getSuccessor(i: 1))
2394 return {};
2395
2396 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 0), L.getHeader(),
2397 AccessesToCheck)) {
2398 Info->KnownValue = ConstantInt::getTrue(Context&: TI->getContext());
2399 return Info;
2400 }
2401 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 1), L.getHeader(),
2402 AccessesToCheck)) {
2403 Info->KnownValue = ConstantInt::getFalse(Context&: TI->getContext());
2404 return Info;
2405 }
2406
2407 return {};
2408}
2409