1//===-- LoopUtils.cpp - Loop Utility functions -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines common loop utility functions.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Utils/LoopUtils.h"
14#include "llvm/ADT/DenseSet.h"
15#include "llvm/ADT/PriorityWorklist.h"
16#include "llvm/ADT/ScopeExit.h"
17#include "llvm/ADT/SetVector.h"
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/Analysis/AliasAnalysis.h"
21#include "llvm/Analysis/BasicAliasAnalysis.h"
22#include "llvm/Analysis/DomTreeUpdater.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/InstSimplifyFolder.h"
25#include "llvm/Analysis/LoopAccessAnalysis.h"
26#include "llvm/Analysis/LoopInfo.h"
27#include "llvm/Analysis/LoopPass.h"
28#include "llvm/Analysis/MemorySSA.h"
29#include "llvm/Analysis/MemorySSAUpdater.h"
30#include "llvm/Analysis/ScalarEvolution.h"
31#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
32#include "llvm/Analysis/ScalarEvolutionExpressions.h"
33#include "llvm/IR/DIBuilder.h"
34#include "llvm/IR/Dominators.h"
35#include "llvm/IR/Instructions.h"
36#include "llvm/IR/IntrinsicInst.h"
37#include "llvm/IR/MDBuilder.h"
38#include "llvm/IR/Module.h"
39#include "llvm/IR/PatternMatch.h"
40#include "llvm/IR/ProfDataUtils.h"
41#include "llvm/IR/ValueHandle.h"
42#include "llvm/InitializePasses.h"
43#include "llvm/Pass.h"
44#include "llvm/Support/Compiler.h"
45#include "llvm/Support/Debug.h"
46#include "llvm/Transforms/Utils/BasicBlockUtils.h"
47#include "llvm/Transforms/Utils/Local.h"
48#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
49
50using namespace llvm;
51using namespace llvm::PatternMatch;
52
53#define DEBUG_TYPE "loop-utils"
54
55static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced";
56static const char *LLVMLoopDisableLICM = "llvm.licm.disable";
57namespace llvm {
58extern cl::opt<bool> ProfcheckDisableMetadataFixes;
59} // namespace llvm
60
61bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI,
62 MemorySSAUpdater *MSSAU,
63 bool PreserveLCSSA) {
64 bool Changed = false;
65
66 // We re-use a vector for the in-loop predecesosrs.
67 SmallVector<BasicBlock *, 4> InLoopPredecessors;
68
69 auto RewriteExit = [&](BasicBlock *BB) {
70 assert(InLoopPredecessors.empty() &&
71 "Must start with an empty predecessors list!");
72 llvm::scope_exit Cleanup([&] { InLoopPredecessors.clear(); });
73
74 // See if there are any non-loop predecessors of this exit block and
75 // keep track of the in-loop predecessors.
76 bool IsDedicatedExit = true;
77 for (auto *PredBB : predecessors(BB))
78 if (L->contains(BB: PredBB)) {
79 if (isa<IndirectBrInst>(Val: PredBB->getTerminator()))
80 // We cannot rewrite exiting edges from an indirectbr.
81 return false;
82
83 InLoopPredecessors.push_back(Elt: PredBB);
84 } else {
85 IsDedicatedExit = false;
86 }
87
88 assert(!InLoopPredecessors.empty() && "Must have *some* loop predecessor!");
89
90 // Nothing to do if this is already a dedicated exit.
91 if (IsDedicatedExit)
92 return false;
93
94 auto *NewExitBB = SplitBlockPredecessors(
95 BB, Preds: InLoopPredecessors, Suffix: ".loopexit", DT, LI, MSSAU, PreserveLCSSA);
96
97 if (!NewExitBB)
98 LLVM_DEBUG(
99 dbgs() << "WARNING: Can't create a dedicated exit block for loop: "
100 << *L << "\n");
101 else
102 LLVM_DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block "
103 << NewExitBB->getName() << "\n");
104 return true;
105 };
106
107 // Walk the exit blocks directly rather than building up a data structure for
108 // them, but only visit each one once.
109 SmallPtrSet<BasicBlock *, 4> Visited;
110 for (auto *BB : L->blocks())
111 for (auto *SuccBB : successors(BB)) {
112 // We're looking for exit blocks so skip in-loop successors.
113 if (L->contains(BB: SuccBB))
114 continue;
115
116 // Visit each exit block exactly once.
117 if (!Visited.insert(Ptr: SuccBB).second)
118 continue;
119
120 Changed |= RewriteExit(SuccBB);
121 }
122
123 return Changed;
124}
125
126/// Returns the instructions that use values defined in the loop.
127SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) {
128 SmallVector<Instruction *, 8> UsedOutside;
129
130 for (auto *Block : L->getBlocks())
131 // FIXME: I believe that this could use copy_if if the Inst reference could
132 // be adapted into a pointer.
133 for (auto &Inst : *Block) {
134 auto Users = Inst.users();
135 if (any_of(Range&: Users, P: [&](User *U) {
136 auto *Use = cast<Instruction>(Val: U);
137 return !L->contains(BB: Use->getParent());
138 }))
139 UsedOutside.push_back(Elt: &Inst);
140 }
141
142 return UsedOutside;
143}
144
/// Fills \p AU with the analyses that legacy loop passes require and
/// preserve.
///
/// Required and preserved: DominatorTree, LoopInfo, LoopSimplify, LCSSA,
/// LCSSA verification, alias analysis and ScalarEvolution. Additionally
/// preserved: the BasicAA, GlobalsAA and SCEVAA wrapper passes.
/// MemorySSA is intentionally not listed (see the FIXME at the end).
void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) {
  // By definition, all loop passes need the LoopInfo analysis and the
  // Dominator tree it depends on. Because they all participate in the loop
  // pass manager, they must also preserve these.
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addPreserved<DominatorTreeWrapperPass>();
  AU.addRequired<LoopInfoWrapperPass>();
  AU.addPreserved<LoopInfoWrapperPass>();

  // We must also preserve LoopSimplify and LCSSA. We locally access their IDs
  // here because users shouldn't directly get them from this header.
  extern char &LoopSimplifyID;
  extern char &LCSSAID;
  AU.addRequiredID(LoopSimplifyID);
  AU.addPreservedID(LoopSimplifyID);
  AU.addRequiredID(LCSSAID);
  AU.addPreservedID(LCSSAID);
  // This is used in the LPPassManager to perform LCSSA verification on passes
  // which preserve LCSSA form.
  AU.addRequired<LCSSAVerificationPass>();
  AU.addPreserved<LCSSAVerificationPass>();

  // Loop passes are designed to run inside of a loop pass manager, which means
  // that any function analyses they require must be required by the first loop
  // pass in the manager (so that it is computed before the loop pass manager
  // runs) and preserved by all loop passes in the manager. To make this
  // reasonably robust, the set needed for most loop passes is maintained here.
  // If your loop pass requires an analysis not listed here, you will need to
  // carefully audit the loop pass manager nesting structure that results.
  AU.addRequired<AAResultsWrapperPass>();
  AU.addPreserved<AAResultsWrapperPass>();
  AU.addPreserved<BasicAAWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
  AU.addPreserved<SCEVAAWrapperPass>();
  AU.addRequired<ScalarEvolutionWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  // FIXME: When all loop passes preserve MemorySSA, it can be required and
  // preserved here instead of the individual handling in each pass.
}
184
/// Manually defined generic "LoopPass" dependency initialization. This is used
/// to initialize the set of passes named in \c getLoopAnalysisUsage. It can be
/// used within a loop pass's initialization with:
///
///   INITIALIZE_PASS_DEPENDENCY(LoopPass)
///
/// as if "LoopPass" were a pass.
///
/// NOTE(review): MemorySSAWrapperPass is initialized here even though
/// getLoopAnalysisUsage does not require it. Presumably this is so that passes
/// which opt into MemorySSA individually find it registered; keep the two
/// lists in sync when editing either.
void llvm::initializeLoopPassPass(PassRegistry &Registry) {
  INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
  INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
  INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
}
205
206/// Create MDNode for input string.
207static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
208 LLVMContext &Context = TheLoop->getHeader()->getContext();
209 Metadata *MDs[] = {
210 MDString::get(Context, Str: Name),
211 ConstantAsMetadata::get(C: ConstantInt::get(Ty: Type::getInt32Ty(C&: Context), V))};
212 return MDNode::get(Context, MDs);
213}
214
215/// Set input string into loop metadata by keeping other values intact.
216/// If the string is already in loop metadata update value if it is
217/// different.
218void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *StringMD,
219 unsigned V) {
220 SmallVector<Metadata *, 4> MDs(1);
221 // If the loop already has metadata, retain it.
222 MDNode *LoopID = TheLoop->getLoopID();
223 if (LoopID) {
224 for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
225 MDNode *Node = cast<MDNode>(Val: LoopID->getOperand(I: i));
226 // If it is of form key = value, try to parse it.
227 if (Node->getNumOperands() == 2) {
228 MDString *S = dyn_cast<MDString>(Val: Node->getOperand(I: 0));
229 if (S && S->getString() == StringMD) {
230 ConstantInt *IntMD =
231 mdconst::extract_or_null<ConstantInt>(MD: Node->getOperand(I: 1));
232 if (IntMD && IntMD->getSExtValue() == V)
233 // It is already in place. Do nothing.
234 return;
235 // We need to update the value, so just skip it here and it will
236 // be added after copying other existed nodes.
237 continue;
238 }
239 }
240 MDs.push_back(Elt: Node);
241 }
242 }
243 // Add new metadata.
244 MDs.push_back(Elt: createStringMetadata(TheLoop, Name: StringMD, V));
245 // Replace current metadata node with new one.
246 LLVMContext &Context = TheLoop->getHeader()->getContext();
247 MDNode *NewLoopID = MDNode::get(Context, MDs);
248 // Set operand 0 to refer to the loop id itself.
249 NewLoopID->replaceOperandWith(I: 0, New: NewLoopID);
250 TheLoop->setLoopID(NewLoopID);
251}
252
253std::optional<ElementCount>
254llvm::getOptionalElementCountLoopAttribute(const Loop *TheLoop) {
255 std::optional<int> Width =
256 getOptionalIntLoopAttribute(TheLoop, Name: "llvm.loop.vectorize.width");
257
258 if (Width) {
259 std::optional<int> IsScalable = getOptionalIntLoopAttribute(
260 TheLoop, Name: "llvm.loop.vectorize.scalable.enable");
261 return ElementCount::get(MinVal: *Width, Scalable: IsScalable.value_or(u: false));
262 }
263
264 return std::nullopt;
265}
266
267std::optional<MDNode *> llvm::makeFollowupLoopID(
268 MDNode *OrigLoopID, ArrayRef<StringRef> FollowupOptions,
269 const char *InheritOptionsExceptPrefix, bool AlwaysNew) {
270 if (!OrigLoopID) {
271 if (AlwaysNew)
272 return nullptr;
273 return std::nullopt;
274 }
275
276 assert(OrigLoopID->getOperand(0) == OrigLoopID);
277
278 bool InheritAllAttrs = !InheritOptionsExceptPrefix;
279 bool InheritSomeAttrs =
280 InheritOptionsExceptPrefix && InheritOptionsExceptPrefix[0] != '\0';
281 SmallVector<Metadata *, 8> MDs;
282 MDs.push_back(Elt: nullptr);
283
284 bool Changed = false;
285 if (InheritAllAttrs || InheritSomeAttrs) {
286 for (const MDOperand &Existing : drop_begin(RangeOrContainer: OrigLoopID->operands())) {
287 MDNode *Op = cast<MDNode>(Val: Existing.get());
288
289 auto InheritThisAttribute = [InheritSomeAttrs,
290 InheritOptionsExceptPrefix](MDNode *Op) {
291 if (!InheritSomeAttrs)
292 return false;
293
294 // Skip malformatted attribute metadata nodes.
295 if (Op->getNumOperands() == 0)
296 return true;
297 Metadata *NameMD = Op->getOperand(I: 0).get();
298 if (!isa<MDString>(Val: NameMD))
299 return true;
300 StringRef AttrName = cast<MDString>(Val: NameMD)->getString();
301
302 // Do not inherit excluded attributes.
303 return !AttrName.starts_with(Prefix: InheritOptionsExceptPrefix);
304 };
305
306 if (InheritThisAttribute(Op))
307 MDs.push_back(Elt: Op);
308 else
309 Changed = true;
310 }
311 } else {
312 // Modified if we dropped at least one attribute.
313 Changed = OrigLoopID->getNumOperands() > 1;
314 }
315
316 bool HasAnyFollowup = false;
317 for (StringRef OptionName : FollowupOptions) {
318 MDNode *FollowupNode = findOptionMDForLoopID(LoopID: OrigLoopID, Name: OptionName);
319 if (!FollowupNode)
320 continue;
321
322 HasAnyFollowup = true;
323 for (const MDOperand &Option : drop_begin(RangeOrContainer: FollowupNode->operands())) {
324 MDs.push_back(Elt: Option.get());
325 Changed = true;
326 }
327 }
328
329 // Attributes of the followup loop not specified explicity, so signal to the
330 // transformation pass to add suitable attributes.
331 if (!AlwaysNew && !HasAnyFollowup)
332 return std::nullopt;
333
334 // If no attributes were added or remove, the previous loop Id can be reused.
335 if (!AlwaysNew && !Changed)
336 return OrigLoopID;
337
338 // No attributes is equivalent to having no !llvm.loop metadata at all.
339 if (MDs.size() == 1)
340 return nullptr;
341
342 // Build the new loop ID.
343 MDTuple *FollowupLoopID = MDNode::get(Context&: OrigLoopID->getContext(), MDs);
344 FollowupLoopID->replaceOperandWith(I: 0, New: FollowupLoopID);
345 return FollowupLoopID;
346}
347
348bool llvm::hasDisableAllTransformsHint(const Loop *L) {
349 return getBooleanLoopAttribute(TheLoop: L, Name: LLVMLoopDisableNonforced);
350}
351
352bool llvm::hasDisableLICMTransformsHint(const Loop *L) {
353 return getBooleanLoopAttribute(TheLoop: L, Name: LLVMLoopDisableLICM);
354}
355
356StringRef llvm::getLoopVectorizeKindPrefix(const Loop *L) {
357 bool IsVectorBody = getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.vectorize.body");
358 bool IsEpilogue = getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.vectorize.epilogue");
359 if (IsVectorBody && IsEpilogue)
360 return "vectorized epilogue ";
361 if (IsVectorBody)
362 return "vectorized ";
363 if (IsEpilogue)
364 return "epilogue ";
365 return "";
366}
367
368TransformationMode llvm::hasUnrollTransformation(const Loop *L) {
369 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.disable"))
370 return TM_SuppressedByUser;
371
372 std::optional<int> Count =
373 getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.count");
374 if (Count)
375 return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;
376
377 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.enable"))
378 return TM_ForcedByUser;
379
380 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll.full"))
381 return TM_ForcedByUser;
382
383 if (hasDisableAllTransformsHint(L))
384 return TM_Disable;
385
386 return TM_Unspecified;
387}
388
389TransformationMode llvm::hasUnrollAndJamTransformation(const Loop *L) {
390 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.disable"))
391 return TM_SuppressedByUser;
392
393 std::optional<int> Count =
394 getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.count");
395 if (Count)
396 return *Count == 1 ? TM_SuppressedByUser : TM_ForcedByUser;
397
398 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.unroll_and_jam.enable"))
399 return TM_ForcedByUser;
400
401 if (hasDisableAllTransformsHint(L))
402 return TM_Disable;
403
404 return TM_Unspecified;
405}
406
407TransformationMode llvm::hasVectorizeTransformation(const Loop *L) {
408 std::optional<bool> Enable =
409 getOptionalBoolLoopAttribute(TheLoop: L, Name: "llvm.loop.vectorize.enable");
410
411 if (Enable == false)
412 return TM_SuppressedByUser;
413
414 std::optional<ElementCount> VectorizeWidth =
415 getOptionalElementCountLoopAttribute(TheLoop: L);
416 std::optional<int> InterleaveCount =
417 getOptionalIntLoopAttribute(TheLoop: L, Name: "llvm.loop.interleave.count");
418
419 // 'Forcing' vector width and interleave count to one effectively disables
420 // this tranformation.
421 if (Enable == true && VectorizeWidth && VectorizeWidth->isScalar() &&
422 InterleaveCount == 1)
423 return TM_SuppressedByUser;
424
425 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.isvectorized"))
426 return TM_Disable;
427
428 if (Enable == true)
429 return TM_ForcedByUser;
430
431 if ((VectorizeWidth && VectorizeWidth->isScalar()) && InterleaveCount == 1)
432 return TM_Disable;
433
434 if ((VectorizeWidth && VectorizeWidth->isVector()) || InterleaveCount > 1)
435 return TM_Enable;
436
437 if (hasDisableAllTransformsHint(L))
438 return TM_Disable;
439
440 return TM_Unspecified;
441}
442
443TransformationMode llvm::hasDistributeTransformation(const Loop *L) {
444 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.distribute.enable"))
445 return TM_ForcedByUser;
446
447 if (hasDisableAllTransformsHint(L))
448 return TM_Disable;
449
450 return TM_Unspecified;
451}
452
453TransformationMode llvm::hasLICMVersioningTransformation(const Loop *L) {
454 if (getBooleanLoopAttribute(TheLoop: L, Name: "llvm.loop.licm_versioning.disable"))
455 return TM_SuppressedByUser;
456
457 if (hasDisableAllTransformsHint(L))
458 return TM_Disable;
459
460 return TM_Unspecified;
461}
462
463/// Does a BFS from a given node to all of its children inside a given loop.
464/// The returned vector of basic blocks includes the starting point.
465SmallVector<BasicBlock *, 16> llvm::collectChildrenInLoop(DominatorTree *DT,
466 DomTreeNode *N,
467 const Loop *CurLoop) {
468 SmallVector<BasicBlock *, 16> Worklist;
469 auto AddRegionToWorklist = [&](DomTreeNode *DTN) {
470 // Only include subregions in the top level loop.
471 BasicBlock *BB = DTN->getBlock();
472 if (CurLoop->contains(BB))
473 Worklist.push_back(Elt: DTN->getBlock());
474 };
475
476 AddRegionToWorklist(N);
477
478 for (size_t I = 0; I < Worklist.size(); I++) {
479 for (DomTreeNode *Child : DT->getNode(BB: Worklist[I])->children())
480 AddRegionToWorklist(Child);
481 }
482
483 return Worklist;
484}
485
486bool llvm::isAlmostDeadIV(PHINode *PN, BasicBlock *LatchBlock, Value *Cond) {
487 int LatchIdx = PN->getBasicBlockIndex(BB: LatchBlock);
488 assert(LatchIdx != -1 && "LatchBlock is not a case in this PHINode");
489 Value *IncV = PN->getIncomingValue(i: LatchIdx);
490
491 for (User *U : PN->users())
492 if (U != Cond && U != IncV) return false;
493
494 for (User *U : IncV->users())
495 if (U != Cond && U != PN) return false;
496 return true;
497}
498
499
/// Deletes \p L, which the caller has established is safe to remove.
///
/// The preheader is redirected straight to the loop's unique exit block, or
/// terminated with `unreachable` when the loop has no exit blocks. Uses of
/// loop values outside the loop (expected only in unreachable code) are
/// replaced with poison.
///
/// Each analysis argument may be null; the supplied ones are kept up to date:
///  - \p SE forgets the loop;
///  - \p DT receives the edge insertion/deletion updates, as does \p MSSA
///    through a MemorySSAUpdater;
///  - \p LI removes the blocks and destroys \p L.
/// The loop's blocks are erased only when \p LI is non-null.
///
/// Preconditions checked by assertions:
///  - LCSSA form, when \p DT is given;
///  - a preheader whose terminator has no side effects and exactly one
///    successor;
///  - either a unique, dedicated exit block or no exit blocks.
void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
                          LoopInfo *LI, MemorySSA *MSSA) {
  assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
  auto *Preheader = L->getLoopPreheader();
  assert(Preheader && "Preheader should exist!");

  std::unique_ptr<MemorySSAUpdater> MSSAU;
  if (MSSA)
    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);

  // Now that we know the removal is safe, remove the loop by changing the
  // branch from the preheader to go to the single exit block.
  //
  // Because we're deleting a large chunk of code at once, the sequence in which
  // we remove things is very important to avoid invalidation issues.

  // Tell ScalarEvolution that the loop is deleted. Do this before
  // deleting the loop so that ScalarEvolution can look at the loop
  // to determine what it needs to clean up.
  if (SE) {
    SE->forgetLoop(L);
    SE->forgetBlockAndLoopDispositions();
  }

  Instruction *OldTerm = Preheader->getTerminator();
  assert(!OldTerm->mayHaveSideEffects() &&
         "Preheader must end with a side-effect-free terminator");
  assert(OldTerm->getNumSuccessors() == 1 &&
         "Preheader must have a single successor");
  // Connect the preheader to the exit block. Keep the old edge to the header
  // around to perform the dominator tree update in two separate steps:
  //   #1 insertion of the edge preheader -> exit, and
  //   #2 deletion of the edge preheader -> header.
  //
  //
  //  0.  Preheader          1.  Preheader           2.  Preheader
  //        |                    |   |                   |
  //        V                    |   V                   |
  //      Header <--\            | Header <--\           | Header <--\
  //       |  |     |            |  |  |     |           |  |  |     |
  //       |  V     |            |  |  V     |           |  |  V     |
  //       | Body --/            |  | Body --/           |  | Body --/
  //       V                     V  V                    V  V
  //      Exit                   Exit                    Exit
  //
  // By doing this in two separate steps we can perform the dominator tree
  // update without using the batch update API.
  //
  // Even when the loop is never executed, we cannot remove the edge from the
  // source block to the exit block. Consider the case where the unexecuted loop
  // branches back to an outer loop. If we deleted the loop and removed the edge
  // coming to this inner loop, this would break the outer loop structure (by
  // deleting the backedge of the outer loop). If the outer loop is indeed a
  // non-loop, it will be deleted in a future iteration of the loop deletion
  // pass.
  IRBuilder<> Builder(OldTerm);

  auto *ExitBlock = L->getUniqueExitBlock();
  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
  if (ExitBlock) {
    // This assertion is implied by the enclosing `if`; it only documents
    // intent.
    assert(ExitBlock && "Should have a unique exit block!");
    assert(L->hasDedicatedExits() && "Loop should have dedicated exits!");

    // Step #1: branch on constant false, so the preheader keeps its edge to
    // the header (true destination) and gains an edge to the exit (false
    // destination).
    Builder.CreateCondBr(Builder.getFalse(), L->getHeader(), ExitBlock);
    // Remove the old branch. The conditional branch becomes a new terminator.
    OldTerm->eraseFromParent();

    // Rewrite phis in the exit block to get their inputs from the Preheader
    // instead of the exiting block.
    for (PHINode &P : ExitBlock->phis()) {
      // Set the zero'th element of Phi to be from the preheader and remove all
      // other incoming values. Given the loop has dedicated exits, all other
      // incoming values must be from the exiting blocks.
      int PredIndex = 0;
      P.setIncomingBlock(PredIndex, Preheader);
      // Removes all incoming values from all other exiting blocks (including
      // duplicate values from an exiting block).
      // Nuke all entries except the zero'th entry which is the preheader entry.
      P.removeIncomingValueIf([](unsigned Idx) { return Idx != 0; },
                              /* DeletePHIIfEmpty */ false);

      assert((P.getNumIncomingValues() == 1 &&
              P.getIncomingBlock(PredIndex) == Preheader) &&
             "Should have exactly one value and that's from the preheader!");
    }

    if (DT) {
      DTU.applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}});
      if (MSSA) {
        MSSAU->applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}},
                            *DT);
        if (VerifyMemorySSA)
          MSSA->verifyMemorySSA();
      }
    }

    // Step #2: disconnect the loop body by branching directly to its exit.
    // The new `br` is inserted before the conditional branch, so
    // getTerminator() below still returns the conditional branch.
    Builder.SetInsertPoint(Preheader->getTerminator());
    Builder.CreateBr(ExitBlock);
    // Remove the old branch.
    Preheader->getTerminator()->eraseFromParent();
  } else {
    assert(L->hasNoExitBlocks() &&
           "Loop should have either zero or one exit blocks.");

    // No exit: the preheader now ends in `unreachable`. It is inserted before
    // OldTerm, which getTerminator() still returns and which is then erased.
    Builder.SetInsertPoint(OldTerm);
    Builder.CreateUnreachable();
    Preheader->getTerminator()->eraseFromParent();
  }

  if (DT) {
    DTU.applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}});
    if (MSSA) {
      MSSAU->applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}},
                          *DT);
      SmallSetVector<BasicBlock *, 8> DeadBlockSet(L->block_begin(),
                                                   L->block_end());
      MSSAU->removeBlocks(DeadBlockSet);
      if (VerifyMemorySSA)
        MSSA->verifyMemorySSA();
    }
  }

  // Use a set to unique the variables and a vector to guarantee deterministic
  // ordering.
  llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet;
  llvm::SmallVector<DbgVariableRecord *, 4> DeadDbgVariableRecords;

  // Given LCSSA form is satisfied, we should not have users of instructions
  // within the dead loop outside of the loop. However, LCSSA doesn't take
  // unreachable uses into account. We handle them here.
  // We could do it after dropping all references (in that case all users in
  // the loop would already be eliminated and we would have less work to do),
  // but according to the API doc of User::dropAllReferences the only valid
  // operation after dropping references is deletion. So let's substitute all
  // usages of instructions from the loop with a poison value of the
  // corresponding type first.
  for (auto *Block : L->blocks())
    for (Instruction &I : *Block) {
      auto *Poison = PoisonValue::get(I.getType());
      for (Use &U : llvm::make_early_inc_range(I.uses())) {
        if (auto *Usr = dyn_cast<Instruction>(U.getUser()))
          if (L->contains(Usr->getParent()))
            continue;
        // If we have a DT then we can check that uses outside a loop are only
        // in unreachable blocks.
        if (DT)
          assert(!DT->isReachableFromEntry(U) &&
                 "Unexpected user in reachable block");
        U.set(Poison);
      }

      if (ExitBlock) {
        // For one of each variable encountered, preserve a debug record (set
        // to Poison) and transfer it to the loop exit. This terminates any
        // variable locations that were set during the loop. Records for
        // already-seen variables stay attached and are deleted with the loop.
        for (DbgVariableRecord &DVR :
             llvm::make_early_inc_range(filterDbgVars(I.getDbgRecordRange()))) {
          DebugVariable Key(DVR.getVariable(), DVR.getExpression(),
                            DVR.getDebugLoc().get());
          if (!DeadDebugSet.insert(Key).second)
            continue;
          // Unlinks the DVR from its container, for later insertion.
          DVR.removeFromParent();
          DeadDbgVariableRecords.push_back(&DVR);
        }
      }
    }

  if (ExitBlock) {
    // After the loop has been deleted all the values defined and modified
    // inside the loop are going to be unavailable. Values computed in the
    // loop will have been deleted, automatically causing their debug uses
    // to be replaced with undef. Loop invariant values will still be available.
    // Move dbg.values out of the loop so that earlier location ranges are still
    // terminated and loop invariant assignments are preserved.
    // NOTE(review): DIB is not referenced again in this function.
    DIBuilder DIB(*ExitBlock->getModule());
    BasicBlock::iterator InsertDbgValueBefore =
        ExitBlock->getFirstInsertionPt();
    assert(InsertDbgValueBefore != ExitBlock->end() &&
           "There should be a non-PHI instruction in exit block, else these "
           "instructions will have no parent.");

    // Due to the "head" bit in BasicBlock::iterator, we're going to insert
    // each DbgVariableRecord right at the start of the block, whereas
    // dbg.values would be repeatedly inserted before the first instruction. To
    // replicate this behaviour, do it backwards.
    for (DbgVariableRecord *DVR : llvm::reverse(DeadDbgVariableRecords))
      ExitBlock->insertDbgRecordBefore(DVR, InsertDbgValueBefore);
  }

  // Remove the block from the reference counting scheme, so that we can
  // delete it freely later.
  for (auto *Block : L->blocks())
    Block->dropAllReferences();

  if (MSSA && VerifyMemorySSA)
    MSSA->verifyMemorySSA();

  if (LI) {
    SmallPtrSet<BasicBlock *, 8> Blocks(llvm::from_range, L->blocks());

    // Erase the instructions and the blocks without having to worry
    // about ordering because we already dropped the references.
    // Remove blocks from loopinfo before erasing them, otherwise the loopinfo
    // cannot find the loop using block numbers.
    for (BasicBlock *BB : Blocks) {
      LI->removeBlock(BB);
      BB->eraseFromParent();
    }

    // The last step is to update LoopInfo now that we've eliminated this loop.
    // Note: LoopInfo::erase removes the given loop and relinks its subloops
    // with its parent, while removeLoop/removeChildLoop remove the given loop
    // without relinking its subloops, which is what we want.
    if (Loop *ParentLoop = L->getParentLoop()) {
      Loop::iterator I = find(*ParentLoop, L);
      assert(I != ParentLoop->end() && "Couldn't find loop");
      ParentLoop->removeChildLoop(I);
    } else {
      Loop::iterator I = find(*LI, L);
      assert(I != LI->end() && "Couldn't find loop");
      LI->removeLoop(I);
    }
    LI->destroy(L);
  }
}
724
725void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
726 LoopInfo &LI, MemorySSA *MSSA) {
727 auto *Latch = L->getLoopLatch();
728 assert(Latch && "multiple latches not yet supported");
729 auto *Header = L->getHeader();
730 Loop *OutermostLoop = L->getOutermostLoop();
731
732 SE.forgetLoop(L);
733 SE.forgetBlockAndLoopDispositions();
734
735 std::unique_ptr<MemorySSAUpdater> MSSAU;
736 if (MSSA)
737 MSSAU = std::make_unique<MemorySSAUpdater>(args&: MSSA);
738
739 // Update the CFG and domtree. We chose to special case a couple of
740 // of common cases for code quality and test readability reasons.
741 [&]() -> void {
742 if (auto *BI = dyn_cast<UncondBrInst>(Val: Latch->getTerminator())) {
743 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
744 (void)changeToUnreachable(I: BI, /*PreserveLCSSA*/ true, DTU: &DTU, MSSAU: MSSAU.get());
745 return;
746 }
747 if (auto *BI = dyn_cast<CondBrInst>(Val: Latch->getTerminator())) {
748 // Conditional latch/exit - note that latch can be shared by inner
749 // and outer loop so the other target doesn't need to an exit
750 if (L->isLoopExiting(BB: Latch)) {
751 // TODO: Generalize ConstantFoldTerminator so that it can be used
752 // here without invalidating LCSSA or MemorySSA. (Tricky case for
753 // LCSSA: header is an exit block of a preceeding sibling loop w/o
754 // dedicated exits.)
755 const unsigned ExitIdx = L->contains(BB: BI->getSuccessor(i: 0)) ? 1 : 0;
756 BasicBlock *ExitBB = BI->getSuccessor(i: ExitIdx);
757
758 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
759 Header->removePredecessor(Pred: Latch, KeepOneInputPHIs: true);
760
761 IRBuilder<> Builder(BI);
762 auto *NewBI = Builder.CreateBr(Dest: ExitBB);
763 // Transfer the metadata to the new branch instruction (minus the
764 // loop info since this is no longer a loop)
765 NewBI->copyMetadata(SrcInst: *BI, WL: {LLVMContext::MD_dbg,
766 LLVMContext::MD_annotation});
767
768 BI->eraseFromParent();
769 DTU.applyUpdates(Updates: {{DominatorTree::Delete, Latch, Header}});
770 if (MSSA)
771 MSSAU->applyUpdates(Updates: {{DominatorTree::Delete, Latch, Header}}, DT);
772 return;
773 }
774 }
775
776 // General case. By splitting the backedge, and then explicitly making it
777 // unreachable we gracefully handle corner cases such as switch and invoke
778 // termiantors.
779 auto *BackedgeBB = SplitEdge(From: Latch, To: Header, DT: &DT, LI: &LI, MSSAU: MSSAU.get());
780
781 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
782 (void)changeToUnreachable(I: BackedgeBB->getTerminator(),
783 /*PreserveLCSSA*/ true, DTU: &DTU, MSSAU: MSSAU.get());
784 }();
785
786 // Erase (and destroy) this loop instance. Handles relinking sub-loops
787 // and blocks within the loop as needed.
788 LI.erase(L);
789
790 // If the loop we broke had a parent, then changeToUnreachable might have
791 // caused a block to be removed from the parent loop (see loop_nest_lcssa
792 // test case in zero-btc.ll for an example), thus changing the parent's
793 // exit blocks. If that happened, we need to rebuild LCSSA on the outermost
794 // loop which might have a had a block removed.
795 if (OutermostLoop != L)
796 formLCSSARecursively(L&: *OutermostLoop, DT, LI: &LI, SE: &SE);
797}
798
799
800/// Checks if \p L has an exiting latch branch. There may also be other
801/// exiting blocks. Returns branch instruction terminating the loop
802/// latch if above check is successful, nullptr otherwise.
803static CondBrInst *getExpectedExitLoopLatchBranch(Loop *L) {
804 BasicBlock *Latch = L->getLoopLatch();
805 if (!Latch)
806 return nullptr;
807
808 CondBrInst *LatchBR = dyn_cast<CondBrInst>(Val: Latch->getTerminator());
809 if (!LatchBR || !L->isLoopExiting(BB: Latch))
810 return nullptr;
811
812 assert((LatchBR->getSuccessor(0) == L->getHeader() ||
813 LatchBR->getSuccessor(1) == L->getHeader()) &&
814 "At least one edge out of the latch must go to the header");
815
816 return LatchBR;
817}
818
819struct DbgLoop {
820 const Loop *L;
821 explicit DbgLoop(const Loop *L) : L(L) {}
822};
823
824#ifndef NDEBUG
825static inline raw_ostream &operator<<(raw_ostream &OS, DbgLoop D) {
826 OS << "function ";
827 D.L->getHeader()->getParent()->printAsOperand(OS, /*PrintType=*/false);
828 return OS << " " << *D.L;
829}
830#endif // NDEBUG
831
832static std::optional<unsigned> estimateLoopTripCount(Loop *L) {
833 // Currently we take the estimate exit count only from the loop latch,
834 // ignoring other exiting blocks. This can overestimate the trip count
835 // if we exit through another exit, but can never underestimate it.
836 // TODO: incorporate information from other exits
837 CondBrInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
838 if (!ExitingBranch) {
839 LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to find exiting "
840 << "latch branch of required form in " << DbgLoop(L)
841 << "\n");
842 return std::nullopt;
843 }
844
845 // To estimate the number of times the loop body was executed, we want to
846 // know the number of times the backedge was taken, vs. the number of times
847 // we exited the loop.
848 uint64_t LoopWeight, ExitWeight;
849 if (!extractBranchWeights(I: *ExitingBranch, TrueVal&: LoopWeight, FalseVal&: ExitWeight)) {
850 LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed to extract branch "
851 << "weights for " << DbgLoop(L) << "\n");
852 return std::nullopt;
853 }
854
855 if (L->contains(BB: ExitingBranch->getSuccessor(i: 1)))
856 std::swap(a&: LoopWeight, b&: ExitWeight);
857
858 if (!ExitWeight) {
859 // Don't have a way to return predicated infinite
860 LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Failed because of zero exit "
861 << "probability for " << DbgLoop(L) << "\n");
862 return std::nullopt;
863 }
864
865 // Estimated exit count is a ratio of the loop weight by the weight of the
866 // edge exiting the loop, rounded to nearest.
867 uint64_t ExitCount = llvm::divideNearest(Numerator: LoopWeight, Denominator: ExitWeight);
868
869 // When ExitCount + 1 would wrap in unsigned, saturate at UINT_MAX.
870 if (ExitCount >= std::numeric_limits<unsigned>::max())
871 return std::numeric_limits<unsigned>::max();
872
873 // Estimated trip count is one plus estimated exit count.
874 uint64_t TC = ExitCount + 1;
875 LLVM_DEBUG(dbgs() << "estimateLoopTripCount: Estimated trip count of " << TC
876 << " for " << DbgLoop(L) << "\n");
877 return TC;
878}
879
880std::optional<unsigned>
881llvm::getLoopEstimatedTripCount(Loop *L,
882 unsigned *EstimatedLoopInvocationWeight) {
883 // If EstimatedLoopInvocationWeight, we do not support this loop if
884 // getExpectedExitLoopLatchBranch returns nullptr.
885 //
886 // FIXME: Also, this is a stop-gap solution for nested loops. It avoids
887 // mistaking LLVMLoopEstimatedTripCount metadata to be for an outer loop when
888 // it was created for an inner loop. The problem is that loop metadata is
889 // attached to the branch instruction in the loop latch block, but that can be
890 // shared by the loops. A solution is to attach loop metadata to loop headers
891 // instead, but that would be a large change to LLVM.
892 //
893 // Until that happens, we work around the problem as follows.
894 // getExpectedExitLoopLatchBranch (which also guards
895 // setLoopEstimatedTripCount) returns nullptr for a loop unless the loop has
896 // one latch and that latch has exactly two successors one of which is an exit
897 // from the loop. If the latch is shared by nested loops, then that condition
898 // might hold for the inner loop but cannot hold for the outer loop:
899 // - Because the latch is shared, it must have at least two successors: the
900 // inner loop header and the outer loop header, which is also an exit for
901 // the inner loop. That satisifies the condition for the inner loop.
902 // - To satsify the condition for the outer loop, the latch must have a third
903 // successor that is an exit for the outer loop. But that violates the
904 // condition for both loops.
905 CondBrInst *ExitingBranch = getExpectedExitLoopLatchBranch(L);
906 if (!ExitingBranch)
907 return std::nullopt;
908
909 // If requested, either compute *EstimatedLoopInvocationWeight or return
910 // nullopt if cannot.
911 //
912 // TODO: Eventually, once all passes have migrated away from setting branch
913 // weights to indicate estimated trip counts, this function will drop the
914 // EstimatedLoopInvocationWeight parameter.
915 if (EstimatedLoopInvocationWeight) {
916 uint64_t LoopWeight = 0, ExitWeight = 0; // Inits expected to be unused.
917 if (!extractBranchWeights(I: *ExitingBranch, TrueVal&: LoopWeight, FalseVal&: ExitWeight))
918 return std::nullopt;
919 if (L->contains(BB: ExitingBranch->getSuccessor(i: 1)))
920 std::swap(a&: LoopWeight, b&: ExitWeight);
921 if (!ExitWeight)
922 return std::nullopt;
923 *EstimatedLoopInvocationWeight = ExitWeight;
924 }
925
926 // Return the estimated trip count from metadata unless the metadata is
927 // missing or has no value.
928 //
929 // Some passes set llvm.loop.estimated_trip_count to 0. For example, after
930 // peeling 10 or more iterations from a loop with an estimated trip count of
931 // 10, llvm.loop.estimated_trip_count becomes 0 on the remaining loop. It
932 // indicates that, each time execution reaches the peeled iterations,
933 // execution is estimated to exit them without reaching the remaining loop's
934 // header.
935 //
936 // Even if the probability of reaching a loop's header is low, if it is
937 // reached, it is the start of an iteration. Consequently, some passes
938 // historically assume that llvm::getLoopEstimatedTripCount always returns a
939 // positive count or std::nullopt. Thus, return std::nullopt when
940 // llvm.loop.estimated_trip_count is 0.
941 if (std::optional<unsigned> TC =
942 getOptionalIntLoopAttribute(TheLoop: L, Name: LLVMLoopEstimatedTripCount)) {
943 LLVM_DEBUG(dbgs() << "getLoopEstimatedTripCount: "
944 << LLVMLoopEstimatedTripCount << " metadata has trip "
945 << "count of " << *TC
946 << (*TC == 0 ? " (returning std::nullopt)" : "")
947 << " for " << DbgLoop(L) << "\n");
948 return *TC == 0 ? std::nullopt : TC;
949 }
950
951 // Estimate the trip count from latch branch weights.
952 return estimateLoopTripCount(L);
953}
954
955bool llvm::setLoopEstimatedTripCount(
956 Loop *L, unsigned EstimatedTripCount,
957 std::optional<unsigned> EstimatedloopInvocationWeight) {
958 // If EstimatedLoopInvocationWeight, we do not support this loop if
959 // getExpectedExitLoopLatchBranch returns nullptr.
960 //
961 // FIXME: See comments in getLoopEstimatedTripCount for why this is required
962 // here regardless of EstimatedLoopInvocationWeight.
963 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
964 if (!LatchBranch)
965 return false;
966
967 // Set the metadata.
968 addStringMetadataToLoop(TheLoop: L, StringMD: LLVMLoopEstimatedTripCount, V: EstimatedTripCount);
969
970 // At the moment, we currently support changing the estimated trip count in
971 // the latch branch's branch weights only. We could extend this API to
972 // manipulate estimated trip counts for any exit.
973 //
974 // TODO: Eventually, once all passes have migrated away from setting branch
975 // weights to indicate estimated trip counts, we will not set branch weights
976 // here at all.
977 if (!EstimatedloopInvocationWeight)
978 return true;
979
980 // Calculate taken and exit weights.
981 unsigned LatchExitWeight = ProfcheckDisableMetadataFixes ? 0 : 1;
982 unsigned BackedgeTakenWeight = 0;
983
984 if (EstimatedTripCount != 0) {
985 LatchExitWeight = *EstimatedloopInvocationWeight;
986 BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight;
987 }
988
989 // Make a swap if back edge is taken when condition is "false".
990 if (LatchBranch->getSuccessor(i: 0) != L->getHeader())
991 std::swap(a&: BackedgeTakenWeight, b&: LatchExitWeight);
992
993 // Set/Update profile metadata.
994 setBranchWeights(I&: *LatchBranch, Weights: {BackedgeTakenWeight, LatchExitWeight},
995 /*IsExpected=*/false);
996
997 return true;
998}
999
1000BranchProbability llvm::getLoopProbability(Loop *L) {
1001 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
1002 if (!LatchBranch)
1003 return BranchProbability::getUnknown();
1004 bool FirstTargetIsLoop = LatchBranch->getSuccessor(i: 0) == L->getHeader();
1005 return getBranchProbability(B: LatchBranch, ForFirstTarget: FirstTargetIsLoop);
1006}
1007
1008bool llvm::setLoopProbability(Loop *L, BranchProbability P) {
1009 CondBrInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
1010 if (!LatchBranch)
1011 return false;
1012 bool FirstTargetIsLoop = LatchBranch->getSuccessor(i: 0) == L->getHeader();
1013 setBranchProbability(B: LatchBranch, P, ForFirstTarget: FirstTargetIsLoop);
1014 return true;
1015}
1016
1017BranchProbability llvm::getBranchProbability(CondBrInst *B,
1018 bool ForFirstTarget) {
1019 uint64_t Weight0, Weight1;
1020 if (!extractBranchWeights(I: *B, TrueVal&: Weight0, FalseVal&: Weight1))
1021 return BranchProbability::getUnknown();
1022 uint64_t Denominator = Weight0 + Weight1;
1023 if (Denominator == 0)
1024 return BranchProbability::getUnknown();
1025 if (!ForFirstTarget)
1026 std::swap(a&: Weight0, b&: Weight1);
1027 return BranchProbability::getBranchProbability(Numerator: Weight0, Denominator);
1028}
1029
1030BranchProbability llvm::getBranchProbability(BasicBlock *Src, BasicBlock *Dst) {
1031 assert(Src != Dst && "Passed in same source as destination");
1032
1033 Instruction *TI = Src->getTerminator();
1034 if (!TI || TI->getNumSuccessors() == 0)
1035 return BranchProbability::getZero();
1036
1037 SmallVector<uint32_t, 4> Weights;
1038
1039 if (!extractBranchWeights(I: *TI, Weights)) {
1040 // No metadata
1041 return BranchProbability::getUnknown();
1042 }
1043 assert(TI->getNumSuccessors() == Weights.size() &&
1044 "Missing weights in branch_weights");
1045
1046 uint64_t Total = 0;
1047 uint32_t Numerator = 0;
1048 for (auto [i, Weight] : llvm::enumerate(First&: Weights)) {
1049 if (TI->getSuccessor(Idx: i) == Dst)
1050 Numerator += Weight;
1051 Total += Weight;
1052 }
1053
1054 // Total of edges might be 0 if the metadata is incorrect/set by hand
1055 // or missing. In such case return here to avoid division by 0 later on.
1056 // There might also be a case where the value of Total cannot fit into
1057 // uint32_t, in such case, just bail out.
1058 if (Total == 0 || Total > std::numeric_limits<uint32_t>::max())
1059 return BranchProbability::getUnknown();
1060
1061 return BranchProbability(Numerator, Total);
1062}
1063
1064void llvm::setBranchProbability(CondBrInst *B, BranchProbability P,
1065 bool ForFirstTarget) {
1066 BranchProbability Prob0 = P;
1067 BranchProbability Prob1 = P.getCompl();
1068 if (!ForFirstTarget)
1069 std::swap(a&: Prob0, b&: Prob1);
1070 setBranchWeights(I&: *B, Weights: {Prob0.getNumerator(), Prob1.getNumerator()},
1071 /*IsExpected=*/false);
1072}
1073
1074bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
1075 ScalarEvolution &SE) {
1076 Loop *OuterL = InnerLoop->getParentLoop();
1077 if (!OuterL)
1078 return true;
1079
1080 // Get the backedge taken count for the inner loop
1081 BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch();
1082 const SCEV *InnerLoopBECountSC = SE.getExitCount(L: InnerLoop, ExitingBlock: InnerLoopLatch);
1083 if (isa<SCEVCouldNotCompute>(Val: InnerLoopBECountSC) ||
1084 !InnerLoopBECountSC->getType()->isIntegerTy())
1085 return false;
1086
1087 // Get whether count is invariant to the outer loop
1088 ScalarEvolution::LoopDisposition LD =
1089 SE.getLoopDisposition(S: InnerLoopBECountSC, L: OuterL);
1090 if (LD != ScalarEvolution::LoopInvariant)
1091 return false;
1092
1093 return true;
1094}
1095
1096constexpr Intrinsic::ID llvm::getReductionIntrinsicID(RecurKind RK) {
1097 switch (RK) {
1098 default:
1099 llvm_unreachable("Unexpected recurrence kind");
1100 case RecurKind::AddChainWithSubs:
1101 case RecurKind::Sub:
1102 case RecurKind::Add:
1103 return Intrinsic::vector_reduce_add;
1104 case RecurKind::Mul:
1105 return Intrinsic::vector_reduce_mul;
1106 case RecurKind::And:
1107 return Intrinsic::vector_reduce_and;
1108 case RecurKind::Or:
1109 return Intrinsic::vector_reduce_or;
1110 case RecurKind::Xor:
1111 return Intrinsic::vector_reduce_xor;
1112 case RecurKind::FMulAdd:
1113 case RecurKind::FAddChainWithSubs:
1114 case RecurKind::FSub:
1115 case RecurKind::FAdd:
1116 return Intrinsic::vector_reduce_fadd;
1117 case RecurKind::FMul:
1118 return Intrinsic::vector_reduce_fmul;
1119 case RecurKind::SMax:
1120 return Intrinsic::vector_reduce_smax;
1121 case RecurKind::SMin:
1122 return Intrinsic::vector_reduce_smin;
1123 case RecurKind::UMax:
1124 return Intrinsic::vector_reduce_umax;
1125 case RecurKind::UMin:
1126 return Intrinsic::vector_reduce_umin;
1127 case RecurKind::FMax:
1128 case RecurKind::FMaxNum:
1129 return Intrinsic::vector_reduce_fmax;
1130 case RecurKind::FMin:
1131 case RecurKind::FMinNum:
1132 return Intrinsic::vector_reduce_fmin;
1133 case RecurKind::FMaximum:
1134 return Intrinsic::vector_reduce_fmaximum;
1135 case RecurKind::FMinimum:
1136 return Intrinsic::vector_reduce_fminimum;
1137 case RecurKind::FMaximumNum:
1138 return Intrinsic::vector_reduce_fmax;
1139 case RecurKind::FMinimumNum:
1140 return Intrinsic::vector_reduce_fmin;
1141 }
1142}
1143
1144Intrinsic::ID llvm::getMinMaxReductionIntrinsicID(Intrinsic::ID IID) {
1145 switch (IID) {
1146 default:
1147 llvm_unreachable("Unexpected intrinsic id");
1148 case Intrinsic::umin:
1149 return Intrinsic::vector_reduce_umin;
1150 case Intrinsic::umax:
1151 return Intrinsic::vector_reduce_umax;
1152 case Intrinsic::smin:
1153 return Intrinsic::vector_reduce_smin;
1154 case Intrinsic::smax:
1155 return Intrinsic::vector_reduce_smax;
1156 }
1157}
1158
1159// This is the inverse to getReductionForBinop
1160unsigned llvm::getArithmeticReductionInstruction(Intrinsic::ID RdxID) {
1161 switch (RdxID) {
1162 case Intrinsic::vector_reduce_fadd:
1163 return Instruction::FAdd;
1164 case Intrinsic::vector_reduce_fmul:
1165 return Instruction::FMul;
1166 case Intrinsic::vector_reduce_add:
1167 return Instruction::Add;
1168 case Intrinsic::vector_reduce_mul:
1169 return Instruction::Mul;
1170 case Intrinsic::vector_reduce_and:
1171 return Instruction::And;
1172 case Intrinsic::vector_reduce_or:
1173 return Instruction::Or;
1174 case Intrinsic::vector_reduce_xor:
1175 return Instruction::Xor;
1176 case Intrinsic::vector_reduce_smax:
1177 case Intrinsic::vector_reduce_smin:
1178 case Intrinsic::vector_reduce_umax:
1179 case Intrinsic::vector_reduce_umin:
1180 return Instruction::ICmp;
1181 case Intrinsic::vector_reduce_fmax:
1182 case Intrinsic::vector_reduce_fmin:
1183 return Instruction::FCmp;
1184 default:
1185 llvm_unreachable("Unexpected ID");
1186 }
1187}
1188
1189// This is the inverse to getArithmeticReductionInstruction
1190Intrinsic::ID llvm::getReductionForBinop(Instruction::BinaryOps Opc) {
1191 switch (Opc) {
1192 default:
1193 break;
1194 case Instruction::Add:
1195 return Intrinsic::vector_reduce_add;
1196 case Instruction::Mul:
1197 return Intrinsic::vector_reduce_mul;
1198 case Instruction::And:
1199 return Intrinsic::vector_reduce_and;
1200 case Instruction::Or:
1201 return Intrinsic::vector_reduce_or;
1202 case Instruction::Xor:
1203 return Intrinsic::vector_reduce_xor;
1204 case Instruction::FAdd:
1205 return Intrinsic::vector_reduce_fadd;
1206 case Instruction::FMul:
1207 return Intrinsic::vector_reduce_fmul;
1208 }
1209 return Intrinsic::not_intrinsic;
1210}
1211
1212Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(Intrinsic::ID RdxID) {
1213 switch (RdxID) {
1214 default:
1215 llvm_unreachable("Unknown min/max recurrence kind");
1216 case Intrinsic::vector_reduce_umin:
1217 return Intrinsic::umin;
1218 case Intrinsic::vector_reduce_umax:
1219 return Intrinsic::umax;
1220 case Intrinsic::vector_reduce_smin:
1221 return Intrinsic::smin;
1222 case Intrinsic::vector_reduce_smax:
1223 return Intrinsic::smax;
1224 case Intrinsic::vector_reduce_fmin:
1225 return Intrinsic::minnum;
1226 case Intrinsic::vector_reduce_fmax:
1227 return Intrinsic::maxnum;
1228 case Intrinsic::vector_reduce_fminimum:
1229 return Intrinsic::minimum;
1230 case Intrinsic::vector_reduce_fmaximum:
1231 return Intrinsic::maximum;
1232 }
1233}
1234
1235Intrinsic::ID llvm::getMinMaxReductionIntrinsicOp(RecurKind RK) {
1236 switch (RK) {
1237 default:
1238 llvm_unreachable("Unknown min/max recurrence kind");
1239 case RecurKind::UMin:
1240 return Intrinsic::umin;
1241 case RecurKind::UMax:
1242 return Intrinsic::umax;
1243 case RecurKind::SMin:
1244 return Intrinsic::smin;
1245 case RecurKind::SMax:
1246 return Intrinsic::smax;
1247 case RecurKind::FMin:
1248 case RecurKind::FMinNum:
1249 return Intrinsic::minnum;
1250 case RecurKind::FMax:
1251 case RecurKind::FMaxNum:
1252 return Intrinsic::maxnum;
1253 case RecurKind::FMinimum:
1254 return Intrinsic::minimum;
1255 case RecurKind::FMaximum:
1256 return Intrinsic::maximum;
1257 case RecurKind::FMinimumNum:
1258 return Intrinsic::minimumnum;
1259 case RecurKind::FMaximumNum:
1260 return Intrinsic::maximumnum;
1261 }
1262}
1263
1264RecurKind llvm::getMinMaxReductionRecurKind(Intrinsic::ID RdxID) {
1265 switch (RdxID) {
1266 case Intrinsic::vector_reduce_smax:
1267 return RecurKind::SMax;
1268 case Intrinsic::vector_reduce_smin:
1269 return RecurKind::SMin;
1270 case Intrinsic::vector_reduce_umax:
1271 return RecurKind::UMax;
1272 case Intrinsic::vector_reduce_umin:
1273 return RecurKind::UMin;
1274 case Intrinsic::vector_reduce_fmax:
1275 return RecurKind::FMax;
1276 case Intrinsic::vector_reduce_fmin:
1277 return RecurKind::FMin;
1278 default:
1279 return RecurKind::None;
1280 }
1281}
1282
1283CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
1284 switch (RK) {
1285 default:
1286 llvm_unreachable("Unknown min/max recurrence kind");
1287 case RecurKind::UMin:
1288 return CmpInst::ICMP_ULT;
1289 case RecurKind::UMax:
1290 return CmpInst::ICMP_UGT;
1291 case RecurKind::SMin:
1292 return CmpInst::ICMP_SLT;
1293 case RecurKind::SMax:
1294 return CmpInst::ICMP_SGT;
1295 case RecurKind::FMin:
1296 return CmpInst::FCMP_OLT;
1297 case RecurKind::FMax:
1298 return CmpInst::FCMP_OGT;
1299 // We do not add FMinimum/FMaximum recurrence kind here since there is no
1300 // equivalent predicate which compares signed zeroes according to the
1301 // semantics of the intrinsics (llvm.minimum/maximum).
1302 }
1303}
1304
1305Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
1306 Value *Right) {
1307 Type *Ty = Left->getType();
1308 if (Ty->isIntOrIntVectorTy() ||
1309 (RK == RecurKind::FMinNum || RK == RecurKind::FMaxNum ||
1310 RK == RecurKind::FMinimum || RK == RecurKind::FMaximum ||
1311 RK == RecurKind::FMinimumNum || RK == RecurKind::FMaximumNum)) {
1312 Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RK);
1313 return Builder.CreateIntrinsic(RetTy: Ty, ID: Id, Args: {Left, Right}, FMFSource: nullptr,
1314 Name: "rdx.minmax");
1315 }
1316 CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
1317 Value *Cmp = Builder.CreateCmp(Pred, LHS: Left, RHS: Right, Name: "rdx.minmax.cmp");
1318 Value *Select = Builder.CreateSelect(C: Cmp, True: Left, False: Right, Name: "rdx.minmax.select");
1319 // This select is synthesized fresh, not lowered from an existing branch, so
1320 // it carries no real profile. Mark its weights as explicitly unknown.
1321 if (auto *SI = dyn_cast<SelectInst>(Val: Select))
1322 setExplicitlyUnknownBranchWeightsIfProfiled(I&: *SI, DEBUG_TYPE);
1323 return Select;
1324}
1325
1326// Helper to generate an ordered reduction.
1327Value *llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
1328 unsigned Op, RecurKind RdxKind) {
1329 unsigned VF = cast<FixedVectorType>(Val: Src->getType())->getNumElements();
1330
1331 // Extract and apply reduction ops in ascending order:
1332 // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ) ... + Scl[VF-1]
1333 Value *Result = Acc;
1334 for (unsigned ExtractIdx = 0; ExtractIdx != VF; ++ExtractIdx) {
1335 Value *Ext =
1336 Builder.CreateExtractElement(Vec: Src, Idx: Builder.getInt32(C: ExtractIdx));
1337
1338 if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
1339 Result = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op, LHS: Result, RHS: Ext,
1340 Name: "bin.rdx");
1341 } else {
1342 assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
1343 "Invalid min/max");
1344 Result = createMinMaxOp(Builder, RK: RdxKind, Left: Result, Right: Ext);
1345 }
1346 }
1347
1348 return Result;
1349}
1350
1351Value *llvm::expandReductionViaLoop(IRBuilderBase &Builder, Value *Vec,
1352 unsigned RdxOpcode, Value *Acc,
1353 DominatorTree *DT, LoopInfo *LI) {
1354 auto *VTy = cast<VectorType>(Val: Vec->getType());
1355 Type *EltTy = VTy->getElementType();
1356 Function *F = Builder.GetInsertBlock()->getParent();
1357
1358 const DataLayout &DL = F->getDataLayout();
1359 Type *IdxTy = DL.getIndexType(C&: EltTy->getContext(), AddressSpace: 0);
1360 unsigned MinElts = VTy->getElementCount().getKnownMinValue();
1361 Value *NumElts = Builder.CreateVScale(Ty: IdxTy);
1362 NumElts = Builder.CreateMul(LHS: NumElts, RHS: ConstantInt::get(Ty: IdxTy, V: MinElts));
1363
1364 BasicBlock *EntryBB = Builder.GetInsertBlock();
1365 BasicBlock *LoopBB = BasicBlock::Create(Context&: F->getContext(), Name: "rdx.loop", Parent: F);
1366 BasicBlock *ExitBB = SplitBlock(Old: EntryBB, SplitPt: Builder.GetInsertPoint(), DT, LI,
1367 MSSAU: nullptr, BBName: "rdx.exit");
1368
1369 EntryBB->getTerminator()->eraseFromParent();
1370 Builder.SetInsertPoint(EntryBB);
1371 Builder.CreateBr(Dest: LoopBB);
1372
1373 Builder.SetInsertPoint(LoopBB);
1374 PHINode *IV = Builder.CreatePHI(Ty: IdxTy, NumReservedValues: 2, Name: "rdx.iv");
1375 PHINode *AccPhi = Builder.CreatePHI(Ty: EltTy, NumReservedValues: 2, Name: "rdx.acc");
1376 IV->addIncoming(V: ConstantInt::get(Ty: IdxTy, V: 0), BB: EntryBB);
1377 AccPhi->addIncoming(V: Acc, BB: EntryBB);
1378
1379 Value *Elt = Builder.CreateExtractElement(Vec, Idx: IV);
1380 Value *Res = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)RdxOpcode, LHS: AccPhi,
1381 RHS: Elt, Name: "rdx.op");
1382
1383 Value *NextIV =
1384 Builder.CreateNUWAdd(LHS: IV, RHS: ConstantInt::get(Ty: IdxTy, V: 1), Name: "rdx.next");
1385 IV->addIncoming(V: NextIV, BB: LoopBB);
1386 AccPhi->addIncoming(V: Res, BB: LoopBB);
1387
1388 Value *Done = Builder.CreateICmpEQ(LHS: NextIV, RHS: NumElts, Name: "rdx.done");
1389 Builder.CreateCondBr(Cond: Done, True: ExitBB, False: LoopBB);
1390
1391 // SplitBlock above updated DT/LI for EntryBB -> ExitBB. Now update
1392 // for replacing that edge with EntryBB -> LoopBB -> {ExitBB, LoopBB}.
1393 if (DT)
1394 DT->applyUpdates(Updates: {{DominatorTree::Insert, EntryBB, LoopBB},
1395 {DominatorTree::Insert, LoopBB, LoopBB},
1396 {DominatorTree::Insert, LoopBB, ExitBB},
1397 {DominatorTree::Delete, EntryBB, ExitBB}});
1398
1399 if (LI) {
1400 Loop *NewLoop = LI->AllocateLoop();
1401 if (Loop *ParentLoop = LI->getLoopFor(BB: EntryBB))
1402 ParentLoop->addChildLoop(NewChild: NewLoop);
1403 else
1404 LI->addTopLevelLoop(New: NewLoop);
1405 NewLoop->addBasicBlockToLoop(NewBB: LoopBB, LI&: *LI);
1406 }
1407
1408 Builder.SetInsertPoint(TheBB: ExitBB, IP: ExitBB->begin());
1409 return Res;
1410}
1411
1412// Helper to generate a log2 shuffle reduction.
1413Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
1414 unsigned Op,
1415 TargetTransformInfo::ReductionShuffle RS,
1416 RecurKind RdxKind) {
1417 unsigned VF = cast<FixedVectorType>(Val: Src->getType())->getNumElements();
1418 // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
1419 // and vector ops, reducing the set of values being computed by half each
1420 // round.
1421 assert(isPowerOf2_32(VF) &&
1422 "Reduction emission only supported for pow2 vectors!");
1423 // Note: fast-math-flags flags are controlled by the builder configuration
1424 // and are assumed to apply to all generated arithmetic instructions. Other
1425 // poison generating flags (nsw/nuw/inbounds/inrange/exact) are not part
1426 // of the builder configuration, and since they're not passed explicitly,
1427 // will never be relevant here. Note that it would be generally unsound to
1428 // propagate these from an intrinsic call to the expansion anyways as we/
1429 // change the order of operations.
1430 auto BuildShuffledOp = [&Builder, &Op,
1431 &RdxKind](SmallVectorImpl<int> &ShuffleMask,
1432 Value *&TmpVec) -> void {
1433 Value *Shuf = Builder.CreateShuffleVector(V: TmpVec, Mask: ShuffleMask, Name: "rdx.shuf");
1434 if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
1435 TmpVec = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op, LHS: TmpVec, RHS: Shuf,
1436 Name: "bin.rdx");
1437 } else {
1438 assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind) &&
1439 "Invalid min/max");
1440 TmpVec = createMinMaxOp(Builder, RK: RdxKind, Left: TmpVec, Right: Shuf);
1441 }
1442 };
1443
1444 Value *TmpVec = Src;
1445 if (TargetTransformInfo::ReductionShuffle::Pairwise == RS) {
1446 SmallVector<int, 32> ShuffleMask(VF);
1447 for (unsigned stride = 1; stride < VF; stride <<= 1) {
1448 // Initialise the mask with undef.
1449 llvm::fill(Range&: ShuffleMask, Value: -1);
1450 for (unsigned j = 0; j < VF; j += stride << 1) {
1451 ShuffleMask[j] = j + stride;
1452 }
1453 BuildShuffledOp(ShuffleMask, TmpVec);
1454 }
1455 } else {
1456 SmallVector<int, 32> ShuffleMask(VF);
1457 for (unsigned i = VF; i != 1; i >>= 1) {
1458 // Move the upper half of the vector to the lower half.
1459 for (unsigned j = 0; j != i / 2; ++j)
1460 ShuffleMask[j] = i / 2 + j;
1461
1462 // Fill the rest of the mask with undef.
1463 std::fill(first: &ShuffleMask[i / 2], last: ShuffleMask.end(), value: -1);
1464 BuildShuffledOp(ShuffleMask, TmpVec);
1465 }
1466 }
1467 // The result is in the first element of the vector.
1468 return Builder.CreateExtractElement(Vec: TmpVec, Idx: Builder.getInt32(C: 0));
1469}
1470
1471Value *llvm::createAnyOfReduction(IRBuilderBase &Builder, Value *Src,
1472 Value *InitVal, PHINode *OrigPhi) {
1473 Value *NewVal = nullptr;
1474
1475 // First use the original phi to determine the new value we're trying to
1476 // select from in the loop.
1477 SelectInst *SI = nullptr;
1478 for (auto *U : OrigPhi->users()) {
1479 if ((SI = dyn_cast<SelectInst>(Val: U)))
1480 break;
1481 }
1482 assert(SI && "One user of the original phi should be a select");
1483
1484 if (SI->getTrueValue() == OrigPhi)
1485 NewVal = SI->getFalseValue();
1486 else {
1487 assert(SI->getFalseValue() == OrigPhi &&
1488 "At least one input to the select should be the original Phi");
1489 NewVal = SI->getTrueValue();
1490 }
1491
1492 // If any predicate is true it means that we want to select the new value.
1493 Value *AnyOf =
1494 Src->getType()->isVectorTy() ? Builder.CreateOrReduce(Src) : Src;
1495 // The compares in the loop may yield poison, which propagates through the
1496 // bitwise ORs. Freeze it here before the condition is used.
1497 AnyOf = Builder.CreateFreeze(V: AnyOf);
1498 return Builder.CreateSelect(C: AnyOf, True: NewVal, False: InitVal, Name: "rdx.select");
1499}
1500
1501Value *llvm::getReductionIdentity(Intrinsic::ID RdxID, Type *Ty,
1502 FastMathFlags Flags) {
1503 bool Negative = false;
1504 switch (RdxID) {
1505 default:
1506 llvm_unreachable("Expecting a reduction intrinsic");
1507 case Intrinsic::vector_reduce_add:
1508 case Intrinsic::vector_reduce_mul:
1509 case Intrinsic::vector_reduce_or:
1510 case Intrinsic::vector_reduce_xor:
1511 case Intrinsic::vector_reduce_and:
1512 case Intrinsic::vector_reduce_fadd:
1513 case Intrinsic::vector_reduce_fmul: {
1514 unsigned Opc = getArithmeticReductionInstruction(RdxID);
1515 return ConstantExpr::getBinOpIdentity(Opcode: Opc, Ty, AllowRHSConstant: false,
1516 NSZ: Flags.noSignedZeros());
1517 }
1518 case Intrinsic::vector_reduce_umax:
1519 case Intrinsic::vector_reduce_umin:
1520 case Intrinsic::vector_reduce_smin:
1521 case Intrinsic::vector_reduce_smax: {
1522 Intrinsic::ID ScalarID = getMinMaxReductionIntrinsicOp(RdxID);
1523 return ConstantExpr::getIntrinsicIdentity(ScalarID, Ty);
1524 }
1525 case Intrinsic::vector_reduce_fmax:
1526 case Intrinsic::vector_reduce_fmaximum:
1527 Negative = true;
1528 [[fallthrough]];
1529 case Intrinsic::vector_reduce_fmin:
1530 case Intrinsic::vector_reduce_fminimum: {
1531 bool PropagatesNaN = RdxID == Intrinsic::vector_reduce_fminimum ||
1532 RdxID == Intrinsic::vector_reduce_fmaximum;
1533 const fltSemantics &Semantics = Ty->getFltSemantics();
1534 return (!Flags.noNaNs() && !PropagatesNaN)
1535 ? ConstantFP::getQNaN(Ty, Negative)
1536 : !Flags.noInfs()
1537 ? ConstantFP::getInfinity(Ty, Negative)
1538 : ConstantFP::get(Ty, V: APFloat::getLargest(Sem: Semantics, Negative));
1539 }
1540 }
1541}
1542
1543Value *llvm::getRecurrenceIdentity(RecurKind K, Type *Tp, FastMathFlags FMF) {
1544 assert((!(K == RecurKind::FMin || K == RecurKind::FMax) ||
1545 (FMF.noNaNs() && FMF.noSignedZeros())) &&
1546 "nnan, nsz is expected to be set for FP min/max reduction.");
1547 Intrinsic::ID RdxID = getReductionIntrinsicID(RK: K);
1548 return getReductionIdentity(RdxID, Ty: Tp, Flags: FMF);
1549}
1550
1551Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
1552 RecurKind RdxKind) {
1553 auto *SrcVecEltTy = cast<VectorType>(Val: Src->getType())->getElementType();
1554 auto getIdentity = [&]() {
1555 return getRecurrenceIdentity(K: RdxKind, Tp: SrcVecEltTy,
1556 FMF: Builder.getFastMathFlags());
1557 };
1558 switch (RdxKind) {
1559 case RecurKind::AddChainWithSubs:
1560 case RecurKind::Sub:
1561 case RecurKind::Add:
1562 case RecurKind::Mul:
1563 case RecurKind::And:
1564 case RecurKind::Or:
1565 case RecurKind::Xor:
1566 case RecurKind::SMax:
1567 case RecurKind::SMin:
1568 case RecurKind::UMax:
1569 case RecurKind::UMin:
1570 case RecurKind::FMax:
1571 case RecurKind::FMin:
1572 case RecurKind::FMinNum:
1573 case RecurKind::FMaxNum:
1574 case RecurKind::FMinimum:
1575 case RecurKind::FMaximum:
1576 case RecurKind::FMinimumNum:
1577 case RecurKind::FMaximumNum:
1578 return Builder.CreateUnaryIntrinsic(ID: getReductionIntrinsicID(RK: RdxKind), Op: Src);
1579 case RecurKind::FMulAdd:
1580 case RecurKind::FAddChainWithSubs:
1581 case RecurKind::FSub:
1582 case RecurKind::FAdd:
1583 return Builder.CreateFAddReduce(Acc: getIdentity(), Src);
1584 case RecurKind::FMul:
1585 return Builder.CreateFMulReduce(Acc: getIdentity(), Src);
1586 default:
1587 llvm_unreachable("Unhandled opcode");
1588 }
1589}
1590
1591Value *llvm::createSimpleReduction(IRBuilderBase &Builder, Value *Src,
1592 RecurKind Kind, Value *Mask, Value *EVL) {
1593 assert(!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
1594 !RecurrenceDescriptor::isFindRecurrenceKind(Kind) &&
1595 "AnyOf and FindIV reductions are not supported.");
1596 Intrinsic::ID Id = getReductionIntrinsicID(RK: Kind);
1597 auto VPID = VPIntrinsic::getForIntrinsic(Id);
1598 assert(VPReductionIntrinsic::isVPReduction(VPID) &&
1599 "No VPIntrinsic for this reduction");
1600 auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
1601 Value *Iden = getRecurrenceIdentity(K: Kind, Tp: EltTy, FMF: Builder.getFastMathFlags());
1602 Value *Ops[] = {Iden, Src, Mask, EVL};
1603 return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
1604}
1605
1606Value *llvm::createOrderedReduction(IRBuilderBase &B, RecurKind Kind,
1607 Value *Src, Value *Start) {
1608 assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
1609 "Unexpected reduction kind");
1610 assert(Src->getType()->isVectorTy() && "Expected a vector type");
1611 assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
1612
1613 return B.CreateFAddReduce(Acc: Start, Src);
1614}
1615
1616Value *llvm::createOrderedReduction(IRBuilderBase &Builder, RecurKind Kind,
1617 Value *Src, Value *Start, Value *Mask,
1618 Value *EVL) {
1619 assert((Kind == RecurKind::FAdd || Kind == RecurKind::FMulAdd) &&
1620 "Unexpected reduction kind");
1621 assert(Src->getType()->isVectorTy() && "Expected a vector type");
1622 assert(!Start->getType()->isVectorTy() && "Expected a scalar type");
1623
1624 Intrinsic::ID Id = getReductionIntrinsicID(RK: RecurKind::FAdd);
1625 auto VPID = VPIntrinsic::getForIntrinsic(Id);
1626 assert(VPReductionIntrinsic::isVPReduction(VPID) &&
1627 "No VPIntrinsic for this reduction");
1628 auto *EltTy = cast<VectorType>(Val: Src->getType())->getElementType();
1629 Value *Ops[] = {Start, Src, Mask, EVL};
1630 return Builder.CreateIntrinsic(RetTy: EltTy, ID: VPID, Args: Ops);
1631}
1632
1633void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue,
1634 bool IncludeWrapFlags) {
1635 auto *VecOp = dyn_cast<Instruction>(Val: I);
1636 if (!VecOp)
1637 return;
1638 auto *Intersection = (OpValue == nullptr) ? dyn_cast<Instruction>(Val: VL[0])
1639 : dyn_cast<Instruction>(Val: OpValue);
1640 if (!Intersection)
1641 return;
1642 const unsigned Opcode = Intersection->getOpcode();
1643 VecOp->copyIRFlags(V: Intersection, IncludeWrapFlags);
1644 for (auto *V : VL) {
1645 auto *Instr = dyn_cast<Instruction>(Val: V);
1646 if (!Instr)
1647 continue;
1648 if (OpValue == nullptr || Opcode == Instr->getOpcode())
1649 VecOp->andIRFlags(V);
1650 }
1651}
1652
1653bool llvm::isKnownNegativeInLoop(const SCEV *S, const Loop *L,
1654 ScalarEvolution &SE) {
1655 const SCEV *Zero = SE.getZero(Ty: S->getType());
1656 return SE.isAvailableAtLoopEntry(S, L) &&
1657 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLT, LHS: S, RHS: Zero);
1658}
1659
1660bool llvm::isKnownNonNegativeInLoop(const SCEV *S, const Loop *L,
1661 ScalarEvolution &SE) {
1662 const SCEV *Zero = SE.getZero(Ty: S->getType());
1663 return SE.isAvailableAtLoopEntry(S, L) &&
1664 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGE, LHS: S, RHS: Zero);
1665}
1666
1667bool llvm::isKnownPositiveInLoop(const SCEV *S, const Loop *L,
1668 ScalarEvolution &SE) {
1669 const SCEV *Zero = SE.getZero(Ty: S->getType());
1670 return SE.isAvailableAtLoopEntry(S, L) &&
1671 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SGT, LHS: S, RHS: Zero);
1672}
1673
1674bool llvm::isKnownNonPositiveInLoop(const SCEV *S, const Loop *L,
1675 ScalarEvolution &SE) {
1676 const SCEV *Zero = SE.getZero(Ty: S->getType());
1677 return SE.isAvailableAtLoopEntry(S, L) &&
1678 SE.isLoopEntryGuardedByCond(L, Pred: ICmpInst::ICMP_SLE, LHS: S, RHS: Zero);
1679}
1680
1681bool llvm::cannotBeMinInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1682 bool Signed) {
1683 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1684 APInt Min = Signed ? APInt::getSignedMinValue(numBits: BitWidth) :
1685 APInt::getMinValue(numBits: BitWidth);
1686 auto Predicate = Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
1687 return SE.isAvailableAtLoopEntry(S, L) &&
1688 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1689 RHS: SE.getConstant(Val: Min));
1690}
1691
1692bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
1693 bool Signed) {
1694 unsigned BitWidth = cast<IntegerType>(Val: S->getType())->getBitWidth();
1695 APInt Max = Signed ? APInt::getSignedMaxValue(numBits: BitWidth) :
1696 APInt::getMaxValue(numBits: BitWidth);
1697 auto Predicate = Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
1698 return SE.isAvailableAtLoopEntry(S, L) &&
1699 SE.isLoopEntryGuardedByCond(L, Pred: Predicate, LHS: S,
1700 RHS: SE.getConstant(Val: Max));
1701}
1702
1703//===----------------------------------------------------------------------===//
1704// rewriteLoopExitValues - Optimize IV users outside the loop.
1705// As a side effect, reduces the amount of IV processing within the loop.
1706//===----------------------------------------------------------------------===//
1707
1708static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
1709 SmallPtrSet<const Instruction *, 8> Visited;
1710 SmallVector<const Instruction *, 8> WorkList;
1711 Visited.insert(Ptr: I);
1712 WorkList.push_back(Elt: I);
1713 while (!WorkList.empty()) {
1714 const Instruction *Curr = WorkList.pop_back_val();
1715 // This use is outside the loop, nothing to do.
1716 if (!L->contains(Inst: Curr))
1717 continue;
1718 // Do we assume it is a "hard" use which will not be eliminated easily?
1719 if (Curr->mayHaveSideEffects())
1720 return true;
1721 // Otherwise, add all its users to worklist.
1722 for (const auto *U : Curr->users()) {
1723 auto *UI = cast<Instruction>(Val: U);
1724 if (Visited.insert(Ptr: UI).second)
1725 WorkList.push_back(Elt: UI);
1726 }
1727 }
1728 return false;
1729}
1730
1731// Collect information about PHI nodes which can be transformed in
1732// rewriteLoopExitValues.
1733struct RewritePhi {
1734 PHINode *PN; // For which PHI node is this replacement?
1735 unsigned Ith; // For which incoming value?
1736 const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting.
1737 Instruction *ExpansionPoint; // Where we'd like to expand that SCEV?
1738 bool HighCost; // Is this expansion a high-cost?
1739
1740 RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
1741 bool H)
1742 : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
1743 HighCost(H) {}
1744};
1745
1746// Check whether it is possible to delete the loop after rewriting exit
1747// value. If it is possible, ignore ReplaceExitValue and do rewriting
1748// aggressively.
1749static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) {
1750 BasicBlock *Preheader = L->getLoopPreheader();
1751 // If there is no preheader, the loop will not be deleted.
1752 if (!Preheader)
1753 return false;
1754
1755 // In LoopDeletion pass Loop can be deleted when ExitingBlocks.size() > 1.
1756 // We obviate multiple ExitingBlocks case for simplicity.
1757 // TODO: If we see testcase with multiple ExitingBlocks can be deleted
1758 // after exit value rewriting, we can enhance the logic here.
1759 SmallVector<BasicBlock *, 4> ExitingBlocks;
1760 L->getExitingBlocks(ExitingBlocks);
1761 SmallVector<BasicBlock *, 8> ExitBlocks;
1762 L->getUniqueExitBlocks(ExitBlocks);
1763 if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1)
1764 return false;
1765
1766 BasicBlock *ExitBlock = ExitBlocks[0];
1767 BasicBlock::iterator BI = ExitBlock->begin();
1768 while (PHINode *P = dyn_cast<PHINode>(Val&: BI)) {
1769 Value *Incoming = P->getIncomingValueForBlock(BB: ExitingBlocks[0]);
1770
1771 // If the Incoming value of P is found in RewritePhiSet, we know it
1772 // could be rewritten to use a loop invariant value in transformation
1773 // phase later. Skip it in the loop invariant check below.
1774 bool found = false;
1775 for (const RewritePhi &Phi : RewritePhiSet) {
1776 unsigned i = Phi.Ith;
1777 if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
1778 found = true;
1779 break;
1780 }
1781 }
1782
1783 Instruction *I;
1784 if (!found && (I = dyn_cast<Instruction>(Val: Incoming)))
1785 if (!L->hasLoopInvariantOperands(I))
1786 return false;
1787
1788 ++BI;
1789 }
1790
1791 for (auto *BB : L->blocks())
1792 if (llvm::any_of(Range&: *BB, P: [](Instruction &I) {
1793 return I.mayHaveSideEffects();
1794 }))
1795 return false;
1796
1797 return true;
1798}
1799
1800/// Checks if it is safe to call InductionDescriptor::isInductionPHI for \p Phi,
1801/// and returns true if this Phi is an induction phi in the loop. When
1802/// isInductionPHI returns true, \p ID will be also be set by isInductionPHI.
1803static bool checkIsIndPhi(PHINode *Phi, Loop *L, ScalarEvolution *SE,
1804 InductionDescriptor &ID) {
1805 if (!Phi)
1806 return false;
1807 if (!L->getLoopPreheader())
1808 return false;
1809 if (Phi->getParent() != L->getHeader())
1810 return false;
1811 return InductionDescriptor::isInductionPHI(Phi, L, SE, D&: ID);
1812}
1813
/// Rewrite the exit values of loop \p L as expansions of loop-invariant SCEVs.
///
/// \p L must be in LCSSA form, so every value defined in the loop and used
/// outside it reaches its users through a PHI node in one of the loop's
/// unique exit blocks. Candidates are PHI incoming values that are
/// instructions of \p L itself (not of a subloop). For each candidate, the
/// exit value is computed first with getSCEVAtScope in the parent loop.
/// Failing that, an add-recurrence of \p L is evaluated at the exit count of
/// the incoming block. The candidate is kept if the result is loop invariant
/// and safe to expand. All candidates are collected and costed before any
/// expansion, then expanded with \p Rewriter and substituted into the PHIs.
///
/// \p ReplaceExitValue controls how aggressive the rewrite is:
///  - UnusedIndVarInLoop only considers induction PHIs and their induction
///    update binary operators whose users are restricted as checked below;
///  - any mode other than AlwaysRepl skips values with a side-effecting
///    ("hard") user inside the loop, unless the exit value is a SCEVConstant
///    or SCEVUnknown;
///  - OnlyCheapRepl and UnusedIndVarInLoop skip high-cost expansions unless
///    canLoopBeDeleted reports that the loop could be deleted afterwards.
///
/// Instructions that become trivially dead are appended to \p DeadInsts
/// rather than erased here. A rewritten PHI with a single incoming value is
/// replaced by the expanded value when that preserves LCSSA.
///
/// \returns the number of PHI incoming values that were replaced.
int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
                                ScalarEvolution *SE,
                                const TargetTransformInfo *TTI,
                                SCEVExpander &Rewriter, DominatorTree *DT,
                                ReplaceExitVal ReplaceExitValue,
                                SmallVector<WeakTrackingVH, 16> &DeadInsts) {
  // Check a pre-condition.
  assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
         "Caller did not preserve LCSSA!");

  SmallVector<BasicBlock*, 8> ExitBlocks;
  L->getUniqueExitBlocks(ExitBlocks);

  SmallVector<RewritePhi, 8> RewritePhiSet;
  // Find all values that are computed inside the loop, but used outside of it.
  // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan
  // the exit blocks of the loop to find them.
  for (BasicBlock *ExitBB : ExitBlocks) {
    // If there are no PHI nodes in this exit block, then no values defined
    // inside the loop are used on this path, skip it.
    PHINode *PN = dyn_cast<PHINode>(Val: ExitBB->begin());
    if (!PN) continue;

    // The incoming-value count of the first PHI is reused for every PHI in
    // this block (relies on all PHIs of a block having one entry per
    // predecessor).
    unsigned NumPreds = PN->getNumIncomingValues();

    // Iterate over all of the PHI nodes.
    BasicBlock::iterator BBI = ExitBB->begin();
    while ((PN = dyn_cast<PHINode>(Val: BBI++))) {
      if (PN->use_empty())
        continue; // dead use, don't replace it

      if (!SE->isSCEVable(Ty: PN->getType()))
        continue;

      // Iterate over all of the values in all the PHI nodes.
      for (unsigned i = 0; i != NumPreds; ++i) {
        // If the value being merged in is not an instruction, skip it.
        // Whether it is defined in L itself is checked below.
        Value *InVal = PN->getIncomingValue(i);
        if (!isa<Instruction>(Val: InVal))
          continue;

        // If this pred is for a subloop, not L itself, skip it.
        if (LI->getLoopFor(BB: PN->getIncomingBlock(i)) != L)
          continue; // The Block is in a subloop, skip it.

        // Check that InVal is defined in the loop.
        Instruction *Inst = cast<Instruction>(Val: InVal);
        if (!L->contains(Inst))
          continue;

        // Find exit values which are induction variables in the loop, and are
        // unused in the loop, with the only use being the exit block PhiNode,
        // and the induction variable update binary operator.
        // The exit value can be replaced with the final value when it is cheap
        // to do so.
        if (ReplaceExitValue == UnusedIndVarInLoop) {
          InductionDescriptor ID;
          PHINode *IndPhi = dyn_cast<PHINode>(Val: Inst);
          if (IndPhi) {
            if (!checkIsIndPhi(Phi: IndPhi, L, SE, ID))
              continue;
            // This is an induction PHI. Check that the only users are PHI
            // nodes, and induction variable update binary operators.
            if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
                  if (!isa<PHINode>(Val: U) && !isa<BinaryOperator>(Val: U))
                    return true;
                  BinaryOperator *B = dyn_cast<BinaryOperator>(Val: U);
                  if (B && B != ID.getInductionBinOp())
                    return true;
                  return false;
                }))
              continue;
          } else {
            // If it is not an induction phi, it must be an induction update
            // binary operator with an induction phi user.
            BinaryOperator *B = dyn_cast<BinaryOperator>(Val: Inst);
            if (!B)
              continue;
            // Every user must be either this exit PHI or an induction PHI of
            // L. checkIsIndPhi fills ID for the induction PHIs it accepts,
            // which is what the getInductionBinOp() check below relies on.
            if (llvm::any_of(Range: Inst->users(), P: [&](User *U) {
                  PHINode *Phi = dyn_cast<PHINode>(Val: U);
                  if (Phi != PN && !checkIsIndPhi(Phi, L, SE, ID))
                    return true;
                  return false;
                }))
              continue;
            if (B != ID.getInductionBinOp())
              continue;
          }
        }

        // Okay, this instruction has a user outside of the current loop
        // and varies predictably *inside* the loop. Evaluate the value it
        // contains when the loop exits, if possible. We prefer to start with
        // expressions which are true for all exits (so as to maximize
        // expression reuse by the SCEVExpander), but resort to per-exit
        // evaluation if that fails.
        const SCEV *ExitValue = SE->getSCEVAtScope(V: Inst, L: L->getParentLoop());
        if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
            !SE->isLoopInvariant(S: ExitValue, L) ||
            !Rewriter.isSafeToExpand(S: ExitValue)) {
          // TODO: This should probably be sunk into SCEV in some way; maybe a
          // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for
          // most SCEV expressions and other recurrence types (e.g. shift
          // recurrences). Is there existing code we can reuse?
          const SCEV *ExitCount = SE->getExitCount(L, ExitingBlock: PN->getIncomingBlock(i));
          if (isa<SCEVCouldNotCompute>(Val: ExitCount))
            continue;
          // Per-exit fallback: only add-recurrences of L itself are evaluated
          // at this exit's count.
          if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Val: SE->getSCEV(V: Inst)))
            if (AddRec->getLoop() == L)
              ExitValue = AddRec->evaluateAtIteration(It: ExitCount, SE&: *SE);
          if (isa<SCEVCouldNotCompute>(Val: ExitValue) ||
              !SE->isLoopInvariant(S: ExitValue, L) ||
              !Rewriter.isSafeToExpand(S: ExitValue))
            continue;
        }

        // Computing the value outside of the loop brings no benefit if it is
        // definitely used inside the loop in a way which can not be optimized
        // away. Avoid doing so unless we know we have a value which computes
        // the ExitValue already. TODO: This should be merged into SCEV
        // expander to leverage its knowledge of existing expressions.
        if (ReplaceExitValue != AlwaysRepl && !isa<SCEVConstant>(Val: ExitValue) &&
            !isa<SCEVUnknown>(Val: ExitValue) && hasHardUserWithinLoop(L, I: Inst))
          continue;

        // Check if expansions of this SCEV would count as being high cost.
        bool HighCost = Rewriter.isHighCostExpansion(
            Exprs: ExitValue, L, Budget: SCEVCheapExpansionBudget, TTI, At: Inst);

        // Note that we must not perform expansions until after
        // we query *all* the costs, because if we perform temporary expansion
        // in between, one that we might not intend to keep, said expansion
        // *may* affect cost calculation of the next SCEV's we'll query,
        // and next SCEV may erroneously get smaller cost.

        // Collect all the candidate PHINodes to be rewritten.
        // For PHIs and landing pads, expansion happens at the first insertion
        // point of their block rather than at the instruction itself.
        Instruction *InsertPt =
          (isa<PHINode>(Val: Inst) || isa<LandingPadInst>(Val: Inst)) ?
          &*Inst->getParent()->getFirstInsertionPt() : Inst;
        RewritePhiSet.emplace_back(Args&: PN, Args&: i, Args&: ExitValue, Args&: InsertPt, Args&: HighCost);
      }
    }
  }

  // TODO: evaluate whether it is beneficial to change how we calculate
  // high-cost: if we have SCEV 'A' which we know we will expand, should we
  // calculate the cost of other SCEV's after expanding SCEV 'A', thus
  // potentially giving cost bonus to those other SCEV's?

  bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
  int NumReplaced = 0;

  // Transformation.
  for (const RewritePhi &Phi : RewritePhiSet) {
    PHINode *PN = Phi.PN;

    // Only do the rewrite when the ExitValue can be expanded cheaply.
    // If LoopCanBeDel is true, rewrite exit value aggressively.
    if ((ReplaceExitValue == OnlyCheapRepl ||
         ReplaceExitValue == UnusedIndVarInLoop) &&
        !LoopCanBeDel && Phi.HighCost)
      continue;

    Value *ExitVal = Rewriter.expandCodeFor(
        SH: Phi.ExpansionSCEV, Ty: Phi.PN->getType(), I: Phi.ExpansionPoint);

    LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " << *ExitVal
                      << '\n'
                      << " LoopVal = " << *(Phi.ExpansionPoint) << "\n");

#ifndef NDEBUG
    // If we reuse an instruction from a loop which is neither L nor one of
    // its containing loops, we end up breaking LCSSA form for this loop by
    // creating a new use of its instruction.
    if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
      if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
        if (EVL != L)
          assert(EVL->contains(L) && "LCSSA breach detected!");
#endif

    NumReplaced++;
    Instruction *Inst = cast<Instruction>(Val: PN->getIncomingValue(i: Phi.Ith));
    PN->setIncomingValue(i: Phi.Ith, V: ExitVal);
    // It's necessary to tell ScalarEvolution about this explicitly so that
    // it can walk the def-use list and forget all SCEVs, as it may not be
    // watching the PHI itself. Once the new exit value is in place, there
    // may not be a def-use connection between the loop and every instruction
    // which got a SCEVAddRecExpr for that loop.
    SE->forgetValue(V: PN);

    // If the replaced instruction is now dead, queue it for deletion. It is
    // not erased here to avoid invalidating iterators.
    if (isInstructionTriviallyDead(I: Inst, TLI))
      DeadInsts.push_back(Elt: Inst);

    // Replace PN with ExitVal if that is legal and does not break LCSSA.
    if (PN->getNumIncomingValues() == 1 &&
        LI->replacementPreservesLCSSAForm(From: PN, To: ExitVal)) {
      PN->replaceAllUsesWith(V: ExitVal);
      PN->eraseFromParent();
    }
  }

  // The insertion point instruction may have been deleted; clear it out
  // so that the rewriter doesn't trip over it later.
  Rewriter.clearInsertPoint();
  return NumReplaced;
}
2023
2024/// Utility that implements appending of loops onto a worklist.
2025/// Loops are added in preorder (analogous for reverse postorder for trees),
2026/// and the worklist is processed LIFO.
2027template <typename RangeT>
2028void llvm::appendReversedLoopsToWorklist(
2029 RangeT &&Loops, SmallPriorityWorklist<Loop *, 4> &Worklist) {
2030 // We use an internal worklist to build up the preorder traversal without
2031 // recursion.
2032 SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist;
2033
2034 // We walk the initial sequence of loops in reverse because we generally want
2035 // to visit defs before uses and the worklist is LIFO.
2036 for (Loop *RootL : Loops) {
2037 assert(PreOrderLoops.empty() && "Must start with an empty preorder walk.");
2038 assert(PreOrderWorklist.empty() &&
2039 "Must start with an empty preorder walk worklist.");
2040 PreOrderWorklist.push_back(Elt: RootL);
2041 do {
2042 Loop *L = PreOrderWorklist.pop_back_val();
2043 PreOrderWorklist.append(in_start: L->begin(), in_end: L->end());
2044 PreOrderLoops.push_back(Elt: L);
2045 } while (!PreOrderWorklist.empty());
2046
2047 Worklist.insert(Input: std::move(PreOrderLoops));
2048 PreOrderLoops.clear();
2049 }
2050}
2051
2052template <typename RangeT>
2053void llvm::appendLoopsToWorklist(RangeT &&Loops,
2054 SmallPriorityWorklist<Loop *, 4> &Worklist) {
2055 appendReversedLoopsToWorklist(reverse(Loops), Worklist);
2056}
2057
// Explicit instantiation of the range-based appendLoopsToWorklist for a
// reference to an ArrayRef of loops. The template is defined in this .cpp
// file, so the specializations other translation units need must be
// instantiated here. (LLVM_EXPORT_TEMPLATE presumably handles symbol export
// for shared-library builds.)
template LLVM_EXPORT_TEMPLATE void
llvm::appendLoopsToWorklist<ArrayRef<Loop *> &>(
    ArrayRef<Loop *> &Loops, SmallPriorityWorklist<Loop *, 4> &Worklist);

// Explicit instantiation for a single Loop, whose subloops form the range.
template LLVM_EXPORT_TEMPLATE void
llvm::appendLoopsToWorklist<Loop &>(Loop &L,
                                    SmallPriorityWorklist<Loop *, 4> &Worklist);
2065
2066void llvm::appendLoopsToWorklist(LoopInfo &LI,
2067 SmallPriorityWorklist<Loop *, 4> &Worklist) {
2068 appendReversedLoopsToWorklist(Loops&: LI, Worklist);
2069}
2070
2071Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
2072 LoopInfo *LI, LPPassManager *LPM) {
2073 Loop &New = *LI->AllocateLoop();
2074 if (PL)
2075 PL->addChildLoop(NewChild: &New);
2076 else
2077 LI->addTopLevelLoop(New: &New);
2078
2079 if (LPM)
2080 LPM->addLoop(L&: New);
2081
2082 // Add all of the blocks in L to the new loop.
2083 for (BasicBlock *BB : L->blocks())
2084 if (LI->getLoopFor(BB) == L)
2085 New.addBasicBlockToLoop(NewBB: cast<BasicBlock>(Val&: VM[BB]), LI&: *LI);
2086
2087 // Add all of the subloops to the new loop.
2088 for (Loop *I : *L)
2089 cloneLoop(L: I, PL: &New, VM, LI, LPM);
2090
2091 return &New;
2092}
2093
/// IR Values for the lower and upper bounds of a pointer evolution. We
/// need to use value-handles because SCEV expansion can invalidate previously
/// expanded values. Thus expansion of a pointer can invalidate the bounds for
/// a previous one.
struct PointerBounds {
  /// Expanded lower bound: the first accessed byte of the pointer group
  /// (frozen when the group requires it).
  TrackingVH<Value> Start;
  /// Expanded upper bound: one past the last accessed byte of the group
  /// (frozen when the group requires it).
  TrackingVH<Value> End;
  /// A stride that must be tested at runtime for being negative, or null
  /// when no such test is needed. Unlike Start/End this is a plain pointer,
  /// not a value handle.
  Value *StrideToCheck;
};
2103
2104/// Expand code for the lower and upper bound of the pointer group \p CG
2105/// in \p TheLoop. \return the values for the bounds.
2106static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
2107 Loop *TheLoop, Instruction *Loc,
2108 SCEVExpander &Exp, bool HoistRuntimeChecks) {
2109 LLVMContext &Ctx = Loc->getContext();
2110 Type *PtrArithTy = PointerType::get(C&: Ctx, AddressSpace: CG->AddressSpace);
2111
2112 Value *Start = nullptr, *End = nullptr;
2113 LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
2114 const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr;
2115
2116 // If the Low and High values are themselves loop-variant, then we may want
2117 // to expand the range to include those covered by the outer loop as well.
2118 // There is a trade-off here with the advantage being that creating checks
2119 // using the expanded range permits the runtime memory checks to be hoisted
2120 // out of the outer loop. This reduces the cost of entering the inner loop,
2121 // which can be significant for low trip counts. The disadvantage is that
2122 // there is a chance we may now never enter the vectorized inner loop,
2123 // whereas using a restricted range check could have allowed us to enter at
2124 // least once. This is why the behaviour is not currently the default and is
2125 // controlled by the parameter 'HoistRuntimeChecks'.
2126 if (HoistRuntimeChecks && TheLoop->getParentLoop() &&
2127 isa<SCEVAddRecExpr>(Val: High) && isa<SCEVAddRecExpr>(Val: Low)) {
2128 auto *HighAR = cast<SCEVAddRecExpr>(Val: High);
2129 auto *LowAR = cast<SCEVAddRecExpr>(Val: Low);
2130 const Loop *OuterLoop = TheLoop->getParentLoop();
2131 ScalarEvolution &SE = *Exp.getSE();
2132 const SCEV *Recur = LowAR->getStepRecurrence(SE);
2133 if (Recur == HighAR->getStepRecurrence(SE) &&
2134 HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
2135 BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
2136 const SCEV *OuterExitCount = SE.getExitCount(L: OuterLoop, ExitingBlock: OuterLoopLatch);
2137 if (!isa<SCEVCouldNotCompute>(Val: OuterExitCount) &&
2138 OuterExitCount->getType()->isIntegerTy()) {
2139 const SCEV *NewHigh =
2140 cast<SCEVAddRecExpr>(Val: High)->evaluateAtIteration(It: OuterExitCount, SE);
2141 if (!isa<SCEVCouldNotCompute>(Val: NewHigh)) {
2142 LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
2143 "outer loop in order to permit hoisting\n");
2144 High = NewHigh;
2145 Low = cast<SCEVAddRecExpr>(Val: Low)->getStart();
2146 // If there is a possibility that the stride is negative then we have
2147 // to generate extra checks to ensure the stride is positive.
2148 if (!SE.isKnownNonNegative(
2149 S: SE.applyLoopGuards(Expr: Recur, L: HighAR->getLoop()))) {
2150 Stride = Recur;
2151 LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
2152 "positive: "
2153 << *Stride << '\n');
2154 }
2155 }
2156 }
2157 }
2158 }
2159
2160 Start = Exp.expandCodeFor(SH: Low, Ty: PtrArithTy, I: Loc);
2161 End = Exp.expandCodeFor(SH: High, Ty: PtrArithTy, I: Loc);
2162 if (CG->NeedsFreeze) {
2163 IRBuilder<> Builder(Loc);
2164 Start = Builder.CreateFreeze(V: Start, Name: Start->getName() + ".fr");
2165 End = Builder.CreateFreeze(V: End, Name: End->getName() + ".fr");
2166 }
2167 Value *StrideVal =
2168 Stride ? Exp.expandCodeFor(SH: Stride, Ty: Stride->getType(), I: Loc) : nullptr;
2169 LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n");
2170 return {.Start: Start, .End: End, .StrideToCheck: StrideVal};
2171}
2172
2173/// Turns a collection of checks into a collection of expanded upper and
2174/// lower bounds for both pointers in the check.
2175static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
2176expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
2177 Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) {
2178 SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
2179
2180 // Here we're relying on the SCEV Expander's cache to only emit code for the
2181 // same bounds once.
2182 transform(Range: PointerChecks, d_first: std::back_inserter(x&: ChecksWithBounds),
2183 F: [&](const RuntimePointerCheck &Check) {
2184 PointerBounds First = expandBounds(CG: Check.first, TheLoop: L, Loc, Exp,
2185 HoistRuntimeChecks),
2186 Second = expandBounds(CG: Check.second, TheLoop: L, Loc, Exp,
2187 HoistRuntimeChecks);
2188 return std::make_pair(x&: First, y&: Second);
2189 });
2190
2191 return ChecksWithBounds;
2192}
2193
2194Value *llvm::addRuntimeChecks(
2195 Instruction *Loc, Loop *TheLoop,
2196 const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
2197 SCEVExpander &Exp, bool HoistRuntimeChecks) {
2198 // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
2199 // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
2200 auto ExpandedChecks =
2201 expandBounds(PointerChecks, L: TheLoop, Loc, Exp, HoistRuntimeChecks);
2202
2203 LLVMContext &Ctx = Loc->getContext();
2204 IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
2205 ChkBuilder.SetInsertPoint(Loc);
2206 // Our instructions might fold to a constant.
2207 Value *MemoryRuntimeCheck = nullptr;
2208
2209 for (const auto &[A, B] : ExpandedChecks) {
2210 // Check if two pointers (A and B) conflict where conflict is computed as:
2211 // start(A) <= end(B) && start(B) <= end(A)
2212
2213 assert((A.Start->getType()->getPointerAddressSpace() ==
2214 B.End->getType()->getPointerAddressSpace()) &&
2215 (B.Start->getType()->getPointerAddressSpace() ==
2216 A.End->getType()->getPointerAddressSpace()) &&
2217 "Trying to bounds check pointers with different address spaces");
2218
2219 // [A|B].Start points to the first accessed byte under base [A|B].
2220 // [A|B].End points to the last accessed byte, plus one.
2221 // There is no conflict when the intervals are disjoint:
2222 // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
2223 //
2224 // bound0 = (B.Start < A.End)
2225 // bound1 = (A.Start < B.End)
2226 // IsConflict = bound0 & bound1
2227 Value *Cmp0 = ChkBuilder.CreateICmpULT(LHS: A.Start, RHS: B.End, Name: "bound0");
2228 Value *Cmp1 = ChkBuilder.CreateICmpULT(LHS: B.Start, RHS: A.End, Name: "bound1");
2229 Value *IsConflict = ChkBuilder.CreateAnd(LHS: Cmp0, RHS: Cmp1, Name: "found.conflict");
2230 if (A.StrideToCheck) {
2231 Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
2232 LHS: A.StrideToCheck, RHS: ConstantInt::get(Ty: A.StrideToCheck->getType(), V: 0),
2233 Name: "stride.check");
2234 IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
2235 }
2236 if (B.StrideToCheck) {
2237 Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
2238 LHS: B.StrideToCheck, RHS: ConstantInt::get(Ty: B.StrideToCheck->getType(), V: 0),
2239 Name: "stride.check");
2240 IsConflict = ChkBuilder.CreateOr(LHS: IsConflict, RHS: IsNegativeStride);
2241 }
2242 if (MemoryRuntimeCheck) {
2243 IsConflict =
2244 ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
2245 }
2246 MemoryRuntimeCheck = IsConflict;
2247 }
2248
2249 Exp.eraseDeadInstructions(Root: MemoryRuntimeCheck);
2250 return MemoryRuntimeCheck;
2251}
2252
2253namespace {
2254/// Rewriter to replace SCEVPtrToIntExpr with SCEVPtrToAddrExpr when the result
2255/// type matches the pointer address type. This allows expressions mixing
2256/// ptrtoint and ptrtoaddr to simplify properly.
2257struct SCEVPtrToAddrRewriter : SCEVRewriteVisitor<SCEVPtrToAddrRewriter> {
2258 const DataLayout &DL;
2259 SCEVPtrToAddrRewriter(ScalarEvolution &SE, const DataLayout &DL)
2260 : SCEVRewriteVisitor(SE), DL(DL) {}
2261
2262 const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *E) {
2263 const SCEV *Op = visit(S: E->getOperand());
2264 if (E->getType() == DL.getAddressType(PtrTy: E->getOperand()->getType()))
2265 return SE.getPtrToAddrExpr(Op);
2266 return Op == E->getOperand() ? E : SE.getPtrToIntExpr(Op, Ty: E->getType());
2267 }
2268};
2269} // namespace
2270
2271Value *llvm::addDiffRuntimeChecks(
2272 Instruction *Loc, ArrayRef<PointerDiffInfo> Checks, SCEVExpander &Expander,
2273 function_ref<Value *(IRBuilderBase &, unsigned)> GetVF, unsigned IC) {
2274
2275 LLVMContext &Ctx = Loc->getContext();
2276 IRBuilder ChkBuilder(Ctx, InstSimplifyFolder(Loc->getDataLayout()));
2277 ChkBuilder.SetInsertPoint(Loc);
2278 // Our instructions might fold to a constant.
2279 Value *MemoryRuntimeCheck = nullptr;
2280
2281 auto &SE = *Expander.getSE();
2282 const DataLayout &DL = Loc->getDataLayout();
2283 SCEVPtrToAddrRewriter Rewriter(SE, DL);
2284 // Map to keep track of created compares, The key is the pair of operands for
2285 // the compare, to allow detecting and re-using redundant compares.
2286 DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
2287 for (const auto &[SrcStart, SinkStart, AccessSize, NeedsFreeze] : Checks) {
2288 Type *Ty = SinkStart->getType();
2289 // Compute VF * IC * AccessSize.
2290 auto *VFTimesICTimesSize =
2291 ChkBuilder.CreateMul(LHS: GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
2292 RHS: ConstantInt::get(Ty, V: IC * AccessSize));
2293 const SCEV *SinkStartRewritten = Rewriter.visit(S: SinkStart);
2294 const SCEV *SrcStartRewritten = Rewriter.visit(S: SrcStart);
2295 Value *Diff = Expander.expandCodeFor(
2296 SH: SE.getMinusSCEV(LHS: SinkStartRewritten, RHS: SrcStartRewritten), Ty, I: Loc);
2297
2298 // Check if the same compare has already been created earlier. In that case,
2299 // there is no need to check it again.
2300 Value *IsConflict = SeenCompares.lookup(Val: {Diff, VFTimesICTimesSize});
2301 if (IsConflict)
2302 continue;
2303
2304 IsConflict =
2305 ChkBuilder.CreateICmpULT(LHS: Diff, RHS: VFTimesICTimesSize, Name: "diff.check");
2306 SeenCompares.insert(KV: {{Diff, VFTimesICTimesSize}, IsConflict});
2307 if (NeedsFreeze)
2308 IsConflict =
2309 ChkBuilder.CreateFreeze(V: IsConflict, Name: IsConflict->getName() + ".fr");
2310 if (MemoryRuntimeCheck) {
2311 IsConflict =
2312 ChkBuilder.CreateOr(LHS: MemoryRuntimeCheck, RHS: IsConflict, Name: "conflict.rdx");
2313 }
2314 MemoryRuntimeCheck = IsConflict;
2315 }
2316
2317 Expander.eraseDeadInstructions(Root: MemoryRuntimeCheck);
2318 return MemoryRuntimeCheck;
2319}
2320
2321std::optional<IVConditionInfo>
2322llvm::hasPartialIVCondition(const Loop &L, unsigned MSSAThreshold,
2323 const MemorySSA &MSSA, AAResults &AA) {
2324 auto *TI = dyn_cast<CondBrInst>(Val: L.getHeader()->getTerminator());
2325 if (!TI)
2326 return {};
2327
2328 auto *CondI = dyn_cast<Instruction>(Val: TI->getCondition());
2329 // The case with the condition outside the loop should already be handled
2330 // earlier.
2331 // Allow CmpInst and TruncInsts as they may be users of load instructions
2332 // and have potential for partial unswitching
2333 if (!CondI || !isa<CmpInst, TruncInst>(Val: CondI) || !L.contains(Inst: CondI))
2334 return {};
2335
2336 SmallVector<Instruction *> InstToDuplicate;
2337 InstToDuplicate.push_back(Elt: CondI);
2338
2339 SmallVector<Value *, 4> WorkList;
2340 WorkList.append(in_start: CondI->op_begin(), in_end: CondI->op_end());
2341
2342 SmallVector<MemoryAccess *, 4> AccessesToCheck;
2343 SmallVector<MemoryLocation, 4> AccessedLocs;
2344 while (!WorkList.empty()) {
2345 Instruction *I = dyn_cast<Instruction>(Val: WorkList.pop_back_val());
2346 if (!I || !L.contains(Inst: I))
2347 continue;
2348
2349 // TODO: support additional instructions.
2350 if (!isa<LoadInst>(Val: I) && !isa<GetElementPtrInst>(Val: I))
2351 return {};
2352
2353 // Do not duplicate volatile and atomic loads.
2354 if (auto *LI = dyn_cast<LoadInst>(Val: I))
2355 if (LI->isVolatile() || LI->isAtomic())
2356 return {};
2357
2358 InstToDuplicate.push_back(Elt: I);
2359 if (MemoryAccess *MA = MSSA.getMemoryAccess(I)) {
2360 if (auto *MemUse = dyn_cast_or_null<MemoryUse>(Val: MA)) {
2361 // Queue the defining access to check for alias checks.
2362 AccessesToCheck.push_back(Elt: MemUse->getDefiningAccess());
2363 AccessedLocs.push_back(Elt: MemoryLocation::get(Inst: I));
2364 } else {
2365 // MemoryDefs may clobber the location or may be atomic memory
2366 // operations. Bail out.
2367 return {};
2368 }
2369 }
2370 WorkList.append(in_start: I->op_begin(), in_end: I->op_end());
2371 }
2372
2373 if (InstToDuplicate.empty())
2374 return {};
2375
2376 SmallVector<BasicBlock *, 4> ExitingBlocks;
2377 L.getExitingBlocks(ExitingBlocks);
2378 auto HasNoClobbersOnPath =
2379 [&L, &AA, &AccessedLocs, &ExitingBlocks, &InstToDuplicate,
2380 MSSAThreshold](BasicBlock *Succ, BasicBlock *Header,
2381 SmallVector<MemoryAccess *, 4> AccessesToCheck)
2382 -> std::optional<IVConditionInfo> {
2383 IVConditionInfo Info;
2384 // First, collect all blocks in the loop that are on a patch from Succ
2385 // to the header.
2386 SmallVector<BasicBlock *, 4> WorkList;
2387 WorkList.push_back(Elt: Succ);
2388 WorkList.push_back(Elt: Header);
2389 SmallPtrSet<BasicBlock *, 4> Seen;
2390 Seen.insert(Ptr: Header);
2391 Info.PathIsNoop &=
2392 all_of(Range&: *Header, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2393
2394 while (!WorkList.empty()) {
2395 BasicBlock *Current = WorkList.pop_back_val();
2396 if (!L.contains(BB: Current))
2397 continue;
2398 const auto &SeenIns = Seen.insert(Ptr: Current);
2399 if (!SeenIns.second)
2400 continue;
2401
2402 Info.PathIsNoop &= all_of(
2403 Range&: *Current, P: [](Instruction &I) { return !I.mayHaveSideEffects(); });
2404 WorkList.append(in_start: succ_begin(BB: Current), in_end: succ_end(BB: Current));
2405 }
2406
2407 // Require at least 2 blocks on a path through the loop. This skips
2408 // paths that directly exit the loop.
2409 if (Seen.size() < 2)
2410 return {};
2411
2412 // Next, check if there are any MemoryDefs that are on the path through
2413 // the loop (in the Seen set) and they may-alias any of the locations in
2414 // AccessedLocs. If that is the case, they may modify the condition and
2415 // partial unswitching is not possible.
2416 SmallPtrSet<MemoryAccess *, 4> SeenAccesses;
2417 while (!AccessesToCheck.empty()) {
2418 MemoryAccess *Current = AccessesToCheck.pop_back_val();
2419 auto SeenI = SeenAccesses.insert(Ptr: Current);
2420 if (!SeenI.second || !Seen.contains(Ptr: Current->getBlock()))
2421 continue;
2422
2423 // Bail out if exceeded the threshold.
2424 if (SeenAccesses.size() >= MSSAThreshold)
2425 return {};
2426
2427 // MemoryUse are read-only accesses.
2428 if (isa<MemoryUse>(Val: Current))
2429 continue;
2430
2431 // For a MemoryDef, check if is aliases any of the location feeding
2432 // the original condition.
2433 if (auto *CurrentDef = dyn_cast<MemoryDef>(Val: Current)) {
2434 if (any_of(Range&: AccessedLocs, P: [&AA, CurrentDef](MemoryLocation &Loc) {
2435 return isModSet(
2436 MRI: AA.getModRefInfo(I: CurrentDef->getMemoryInst(), OptLoc: Loc));
2437 }))
2438 return {};
2439 }
2440
2441 for (Use &U : Current->uses())
2442 AccessesToCheck.push_back(Elt: cast<MemoryAccess>(Val: U.getUser()));
2443 }
2444
2445 // We could also allow loops with known trip counts without mustprogress,
2446 // but ScalarEvolution may not be available.
2447 Info.PathIsNoop &= isMustProgress(L: &L);
2448
2449 // If the path is considered a no-op so far, check if it reaches a
2450 // single exit block without any phis. This ensures no values from the
2451 // loop are used outside of the loop.
2452 if (Info.PathIsNoop) {
2453 for (auto *Exiting : ExitingBlocks) {
2454 if (!Seen.contains(Ptr: Exiting))
2455 continue;
2456 for (auto *Succ : successors(BB: Exiting)) {
2457 if (L.contains(BB: Succ))
2458 continue;
2459
2460 Info.PathIsNoop &= Succ->phis().empty() &&
2461 (!Info.ExitForPath || Info.ExitForPath == Succ);
2462 if (!Info.PathIsNoop)
2463 break;
2464 assert((!Info.ExitForPath || Info.ExitForPath == Succ) &&
2465 "cannot have multiple exit blocks");
2466 Info.ExitForPath = Succ;
2467 }
2468 }
2469 }
2470 if (!Info.ExitForPath)
2471 Info.PathIsNoop = false;
2472
2473 Info.InstToDuplicate = std::move(InstToDuplicate);
2474 return Info;
2475 };
2476
2477 // If we branch to the same successor, partial unswitching will not be
2478 // beneficial.
2479 if (TI->getSuccessor(i: 0) == TI->getSuccessor(i: 1))
2480 return {};
2481
2482 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 0), L.getHeader(),
2483 AccessesToCheck)) {
2484 Info->KnownValue = ConstantInt::getTrue(Context&: TI->getContext());
2485 return Info;
2486 }
2487 if (auto Info = HasNoClobbersOnPath(TI->getSuccessor(i: 1), L.getHeader(),
2488 AccessesToCheck)) {
2489 Info->KnownValue = ConstantInt::getFalse(Context&: TI->getContext());
2490 return Info;
2491 }
2492
2493 return {};
2494}
2495