1//===- LoopUnrollAndJam.cpp - Loop unroll and jam pass --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements an unroll and jam pass. Most of the work is done by
10// Utils/UnrollLoopAndJam.cpp.
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h"
14#include "llvm/ADT/ArrayRef.h"
15#include "llvm/ADT/PriorityWorklist.h"
16#include "llvm/ADT/SmallPtrSet.h"
17#include "llvm/ADT/StringRef.h"
18#include "llvm/Analysis/AssumptionCache.h"
19#include "llvm/Analysis/CodeMetrics.h"
20#include "llvm/Analysis/DependenceAnalysis.h"
21#include "llvm/Analysis/LoopAnalysisManager.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/Analysis/LoopNestAnalysis.h"
24#include "llvm/Analysis/LoopPass.h"
25#include "llvm/Analysis/OptimizationRemarkEmitter.h"
26#include "llvm/Analysis/ScalarEvolution.h"
27#include "llvm/Analysis/TargetTransformInfo.h"
28#include "llvm/IR/BasicBlock.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/Instructions.h"
33#include "llvm/IR/Metadata.h"
34#include "llvm/IR/PassManager.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/CommandLine.h"
37#include "llvm/Support/Debug.h"
38#include "llvm/Support/raw_ostream.h"
39#include "llvm/Transforms/Scalar/LoopPassManager.h"
40#include "llvm/Transforms/Utils/LoopPeel.h"
41#include "llvm/Transforms/Utils/LoopUtils.h"
42#include "llvm/Transforms/Utils/UnrollLoop.h"
43#include <cassert>
44#include <cstdint>
45
// Forward declarations of the LLVM IR classes named in this file.
namespace llvm {
class Instruction;
class Value;
} // namespace llvm

using namespace llvm;

// Debug category name for this file; must be defined before the LLVM_DEBUG
// uses below.
#define DEBUG_TYPE "loop-unroll-and-jam"
54
/// @{
/// Metadata attribute names
///
/// These are the followup attribute names handed to makeFollowupLoopID() to
/// build the loop IDs of the loops left behind by unroll-and-jam.
static constexpr const char *LLVMLoopUnrollAndJamFollowupAll =
    "llvm.loop.unroll_and_jam.followup_all";
static constexpr const char *LLVMLoopUnrollAndJamFollowupInner =
    "llvm.loop.unroll_and_jam.followup_inner";
static constexpr const char *LLVMLoopUnrollAndJamFollowupOuter =
    "llvm.loop.unroll_and_jam.followup_outer";
static constexpr const char *LLVMLoopUnrollAndJamFollowupRemainderInner =
    "llvm.loop.unroll_and_jam.followup_remainder_inner";
static constexpr const char *LLVMLoopUnrollAndJamFollowupRemainderOuter =
    "llvm.loop.unroll_and_jam.followup_remainder_outer";
/// @}
68
69static cl::opt<bool>
70 AllowUnrollAndJam("allow-unroll-and-jam", cl::Hidden,
71 cl::desc("Allows loops to be unroll-and-jammed."));
72
73static cl::opt<unsigned> UnrollAndJamCount(
74 "unroll-and-jam-count", cl::Hidden,
75 cl::desc("Use this unroll count for all loops including those with "
76 "unroll_and_jam_count pragma values, for testing purposes"));
77
78static cl::opt<unsigned> UnrollAndJamThreshold(
79 "unroll-and-jam-threshold", cl::init(Val: 60), cl::Hidden,
80 cl::desc("Threshold to use for inner loop when doing unroll and jam."));
81
82static cl::opt<unsigned> PragmaUnrollAndJamThreshold(
83 "pragma-unroll-and-jam-threshold", cl::init(Val: 1024), cl::Hidden,
84 cl::desc("Unrolled size limit for loops with an unroll_and_jam(full) or "
85 "unroll_count pragma."));
86
87// Returns true if the loop has any metadata starting with Prefix. For example a
88// Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata.
89static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
90 if (MDNode *LoopID = L->getLoopID()) {
91 // First operand should refer to the loop id itself.
92 assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
93 assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
94
95 for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I) {
96 MDNode *MD = dyn_cast<MDNode>(Val: LoopID->getOperand(I));
97 if (!MD)
98 continue;
99
100 MDString *S = dyn_cast<MDString>(Val: MD->getOperand(I: 0));
101 if (!S)
102 continue;
103
104 if (S->getString().starts_with(Prefix))
105 return true;
106 }
107 }
108 return false;
109}
110
111// Returns true if the loop has an unroll_and_jam(enable) pragma.
112static bool hasUnrollAndJamEnablePragma(const Loop *L) {
113 return getUnrollMetadataForLoop(L, Name: "llvm.loop.unroll_and_jam.enable");
114}
115
116// If loop has an unroll_and_jam_count pragma return the (necessarily
117// positive) value from the pragma. Otherwise return 0.
118static unsigned unrollAndJamCountPragmaValue(const Loop *L) {
119 MDNode *MD = getUnrollMetadataForLoop(L, Name: "llvm.loop.unroll_and_jam.count");
120 if (MD) {
121 assert(MD->getNumOperands() == 2 &&
122 "Unroll count hint metadata should have two operands.");
123 unsigned Count =
124 mdconst::extract<ConstantInt>(MD: MD->getOperand(I: 1))->getZExtValue();
125 assert(Count >= 1 && "Unroll count must be positive.");
126 return Count;
127 }
128 return 0;
129}
130
131// Returns loop size estimation for unrolled loop.
132static uint64_t
133getUnrollAndJammedLoopSize(unsigned LoopSize,
134 TargetTransformInfo::UnrollingPreferences &UP) {
135 assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!");
136 return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns;
137}
138
139// Calculates unroll and jam count and writes it to UP.Count. Returns true if
140// unroll count was set explicitly.
141static bool computeUnrollAndJamCount(
142 Loop *L, Loop *SubLoop, const TargetTransformInfo &TTI, DominatorTree &DT,
143 LoopInfo *LI, AssumptionCache *AC, ScalarEvolution &SE,
144 const SmallPtrSetImpl<const Value *> &EphValues,
145 OptimizationRemarkEmitter *ORE, unsigned OuterTripCount,
146 unsigned OuterTripMultiple, const UnrollCostEstimator &OuterUCE,
147 unsigned InnerTripCount, unsigned InnerLoopSize,
148 TargetTransformInfo::UnrollingPreferences &UP,
149 TargetTransformInfo::PeelingPreferences &PP) {
150 unsigned OuterLoopSize = OuterUCE.getRolledLoopSize();
151 // First up use computeUnrollCount from the loop unroller to get a count
152 // for unrolling the outer loop, plus any loops requiring explicit
153 // unrolling we leave to the unroller. This uses UP.Threshold /
154 // UP.PartialThreshold / UP.MaxCount to come up with sensible loop values.
155 // We have already checked that the loop has no unroll.* pragmas.
156 computeUnrollCount(L, TTI, DT, LI, AC, SE, EphValues, ORE, TripCount: OuterTripCount,
157 /*MaxTripCount*/ 0, /*MaxOrZero*/ false, TripMultiple: OuterTripMultiple,
158 UCE: OuterUCE, UP, PP);
159
160 // Override with any explicit Count from the "unroll-and-jam-count" option.
161 bool UserUnrollCount = UnrollAndJamCount.getNumOccurrences() > 0;
162 if (UserUnrollCount) {
163 UP.Count = UnrollAndJamCount;
164 UP.Force = true;
165 if (UP.AllowRemainder &&
166 getUnrollAndJammedLoopSize(LoopSize: OuterLoopSize, UP) < UP.Threshold &&
167 getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) <
168 UP.UnrollAndJamInnerLoopThreshold)
169 return true;
170 }
171
172 // Check for unroll_and_jam pragmas
173 unsigned PragmaCount = unrollAndJamCountPragmaValue(L);
174 if (PragmaCount > 0) {
175 UP.Count = PragmaCount;
176 UP.Runtime = true;
177 UP.Force = true;
178 if ((UP.AllowRemainder || (OuterTripMultiple % PragmaCount == 0)) &&
179 getUnrollAndJammedLoopSize(LoopSize: OuterLoopSize, UP) < UP.Threshold &&
180 getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) <
181 UP.UnrollAndJamInnerLoopThreshold)
182 return true;
183 }
184
185 bool PragmaEnableUnroll = hasUnrollAndJamEnablePragma(L);
186 bool ExplicitUnrollAndJamCount = PragmaCount > 0 || UserUnrollCount;
187 bool ExplicitUnrollAndJam = PragmaEnableUnroll || ExplicitUnrollAndJamCount;
188
189 // If the loop has an unrolling pragma, we want to be more aggressive with
190 // unrolling limits.
191 if (ExplicitUnrollAndJam)
192 UP.UnrollAndJamInnerLoopThreshold = PragmaUnrollAndJamThreshold;
193
194 if (!UP.AllowRemainder && getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) >=
195 UP.UnrollAndJamInnerLoopThreshold) {
196 LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't create remainder and "
197 "inner loop too large\n");
198 UP.Count = 0;
199 return false;
200 }
201
202 // We have a sensible limit for the outer loop, now adjust it for the inner
203 // loop and UP.UnrollAndJamInnerLoopThreshold. If the outer limit was set
204 // explicitly, we want to stick to it.
205 if (!ExplicitUnrollAndJamCount && UP.AllowRemainder) {
206 while (UP.Count != 0 && getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) >=
207 UP.UnrollAndJamInnerLoopThreshold)
208 UP.Count--;
209 }
210
211 // If we are explicitly unroll and jamming, we are done. Otherwise there are a
212 // number of extra performance heuristics to check.
213 if (ExplicitUnrollAndJam)
214 return true;
215
216 // If the inner loop count is known and small, leave the entire loop nest to
217 // be the unroller
218 if (InnerTripCount && InnerLoopSize * InnerTripCount < UP.Threshold) {
219 LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; small inner loop count is "
220 "being left for the unroller\n");
221 UP.Count = 0;
222 return false;
223 }
224
225 // Check for situations where UnJ is likely to be unprofitable. Including
226 // subloops with more than 1 block.
227 if (SubLoop->getBlocks().size() != 1) {
228 LLVM_DEBUG(
229 dbgs() << "Won't unroll-and-jam; More than one inner loop block\n");
230 UP.Count = 0;
231 return false;
232 }
233
234 // Limit to loops where there is something to gain from unrolling and
235 // jamming the loop. In this case, look for loads that are invariant in the
236 // outer loop and can become shared.
237 unsigned NumInvariant = 0;
238 for (BasicBlock *BB : SubLoop->getBlocks()) {
239 for (Instruction &I : *BB) {
240 if (auto *Ld = dyn_cast<LoadInst>(Val: &I)) {
241 Value *V = Ld->getPointerOperand();
242 const SCEV *LSCEV = SE.getSCEVAtScope(V, L);
243 if (SE.isLoopInvariant(S: LSCEV, L))
244 NumInvariant++;
245 }
246 }
247 }
248 if (NumInvariant == 0) {
249 LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; No loop invariant loads\n");
250 UP.Count = 0;
251 return false;
252 }
253
254 return false;
255}
256
257static LoopUnrollResult
258tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
259 ScalarEvolution &SE, const TargetTransformInfo &TTI,
260 AssumptionCache &AC, DependenceInfo &DI,
261 OptimizationRemarkEmitter &ORE, int OptLevel) {
262 TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
263 L, SE, TTI, BFI: nullptr, PSI: nullptr, ORE, OptLevel, UserThreshold: std::nullopt, UserCount: std::nullopt,
264 UserAllowPartial: std::nullopt, UserRuntime: std::nullopt, UserUpperBound: std::nullopt, UserFullUnrollMaxCount: std::nullopt);
265 TargetTransformInfo::PeelingPreferences PP =
266 gatherPeelingPreferences(L, SE, TTI, UserAllowPeeling: std::nullopt, UserAllowProfileBasedPeeling: std::nullopt);
267
268 TransformationMode EnableMode = hasUnrollAndJamTransformation(L);
269 if (EnableMode & TM_Disable)
270 return LoopUnrollResult::Unmodified;
271 if (EnableMode & TM_ForcedByUser)
272 UP.UnrollAndJam = true;
273
274 if (AllowUnrollAndJam.getNumOccurrences() > 0)
275 UP.UnrollAndJam = AllowUnrollAndJam;
276 if (UnrollAndJamThreshold.getNumOccurrences() > 0)
277 UP.UnrollAndJamInnerLoopThreshold = UnrollAndJamThreshold;
278 // Exit early if unrolling is disabled.
279 if (!UP.UnrollAndJam || UP.UnrollAndJamInnerLoopThreshold == 0)
280 return LoopUnrollResult::Unmodified;
281
282 LLVM_DEBUG(dbgs() << "Loop Unroll and Jam: F["
283 << L->getHeader()->getParent()->getName() << "] Loop %"
284 << L->getHeader()->getName() << "\n");
285
286 // A loop with any unroll pragma (enabling/disabling/count/etc) is left for
287 // the unroller, so long as it does not explicitly have unroll_and_jam
288 // metadata. This means #pragma nounroll will disable unroll and jam as well
289 // as unrolling
290 if (hasAnyUnrollPragma(L, Prefix: "llvm.loop.unroll.") &&
291 !hasAnyUnrollPragma(L, Prefix: "llvm.loop.unroll_and_jam.")) {
292 LLVM_DEBUG(dbgs() << " Disabled due to pragma.\n");
293 return LoopUnrollResult::Unmodified;
294 }
295
296 if (!isSafeToUnrollAndJam(L, SE, DT, DI, LI&: *LI)) {
297 LLVM_DEBUG(dbgs() << " Disabled due to not being safe.\n");
298 return LoopUnrollResult::Unmodified;
299 }
300
301 // Approximate the loop size and collect useful info
302 SmallPtrSet<const Value *, 32> EphValues;
303 CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);
304 Loop *SubLoop = L->getSubLoops()[0];
305 UnrollCostEstimator InnerUCE(SubLoop, TTI, EphValues, UP.BEInsns);
306 UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns);
307
308 if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) {
309 LLVM_DEBUG(dbgs() << " Loop not considered unrollable\n");
310 return LoopUnrollResult::Unmodified;
311 }
312
313 unsigned InnerLoopSize = InnerUCE.getRolledLoopSize();
314 LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterUCE.getRolledLoopSize()
315 << "\n");
316 LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSize << "\n");
317
318 if (InnerUCE.NumInlineCandidates != 0 || OuterUCE.NumInlineCandidates != 0) {
319 LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n");
320 return LoopUnrollResult::Unmodified;
321 }
322 // FIXME: The call to canUnroll() allows some controlled convergent
323 // operations, but we block them here for future changes.
324 if (InnerUCE.Convergence != ConvergenceKind::None ||
325 OuterUCE.Convergence != ConvergenceKind::None) {
326 LLVM_DEBUG(
327 dbgs() << " Not unrolling loop with convergent instructions.\n");
328 return LoopUnrollResult::Unmodified;
329 }
330
331 // Save original loop IDs for after the transformation.
332 MDNode *OrigOuterLoopID = L->getLoopID();
333 MDNode *OrigSubLoopID = SubLoop->getLoopID();
334
335 // To assign the loop id of the epilogue, assign it before unrolling it so it
336 // is applied to every inner loop of the epilogue. We later apply the loop ID
337 // for the jammed inner loop.
338 std::optional<MDNode *> NewInnerEpilogueLoopID = makeFollowupLoopID(
339 OrigLoopID: OrigOuterLoopID, FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll,
340 LLVMLoopUnrollAndJamFollowupRemainderInner});
341 if (NewInnerEpilogueLoopID)
342 SubLoop->setLoopID(*NewInnerEpilogueLoopID);
343
344 // Find trip count and trip multiple
345 BasicBlock *Latch = L->getLoopLatch();
346 BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
347 unsigned OuterTripCount = SE.getSmallConstantTripCount(L, ExitingBlock: Latch);
348 unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, ExitingBlock: Latch);
349 unsigned InnerTripCount = SE.getSmallConstantTripCount(L: SubLoop, ExitingBlock: SubLoopLatch);
350
351 // Decide if, and by how much, to unroll
352 bool IsCountSetExplicitly = computeUnrollAndJamCount(
353 L, SubLoop, TTI, DT, LI, AC: &AC, SE, EphValues, ORE: &ORE, OuterTripCount,
354 OuterTripMultiple, OuterUCE, InnerTripCount, InnerLoopSize, UP, PP);
355 if (UP.Count <= 1)
356 return LoopUnrollResult::Unmodified;
357 // Unroll factor (Count) must be less or equal to TripCount.
358 if (OuterTripCount && UP.Count > OuterTripCount)
359 UP.Count = OuterTripCount;
360
361 Loop *EpilogueOuterLoop = nullptr;
362 LoopUnrollResult UnrollResult = UnrollAndJamLoop(
363 L, Count: UP.Count, TripCount: OuterTripCount, TripMultiple: OuterTripMultiple, UnrollRemainder: UP.UnrollRemainder, LI,
364 SE: &SE, DT: &DT, AC: &AC, TTI: &TTI, ORE: &ORE, EpilogueLoop: &EpilogueOuterLoop);
365
366 // Assign new loop attributes.
367 if (EpilogueOuterLoop) {
368 std::optional<MDNode *> NewOuterEpilogueLoopID = makeFollowupLoopID(
369 OrigLoopID: OrigOuterLoopID, FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll,
370 LLVMLoopUnrollAndJamFollowupRemainderOuter});
371 if (NewOuterEpilogueLoopID)
372 EpilogueOuterLoop->setLoopID(*NewOuterEpilogueLoopID);
373 }
374
375 std::optional<MDNode *> NewInnerLoopID =
376 makeFollowupLoopID(OrigLoopID: OrigOuterLoopID, FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll,
377 LLVMLoopUnrollAndJamFollowupInner});
378 if (NewInnerLoopID)
379 SubLoop->setLoopID(*NewInnerLoopID);
380 else
381 SubLoop->setLoopID(OrigSubLoopID);
382
383 if (UnrollResult == LoopUnrollResult::PartiallyUnrolled) {
384 std::optional<MDNode *> NewOuterLoopID = makeFollowupLoopID(
385 OrigLoopID: OrigOuterLoopID,
386 FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupOuter});
387 if (NewOuterLoopID) {
388 L->setLoopID(*NewOuterLoopID);
389
390 // Do not setLoopAlreadyUnrolled if a followup was given.
391 return UnrollResult;
392 }
393 }
394
395 // If loop has an unroll count pragma or unrolled by explicitly set count
396 // mark loop as unrolled to prevent unrolling beyond that requested.
397 if (UnrollResult != LoopUnrollResult::FullyUnrolled && IsCountSetExplicitly)
398 L->setLoopAlreadyUnrolled();
399
400 return UnrollResult;
401}
402
403static bool tryToUnrollAndJamLoop(LoopNest &LN, DominatorTree &DT, LoopInfo &LI,
404 ScalarEvolution &SE,
405 const TargetTransformInfo &TTI,
406 AssumptionCache &AC, DependenceInfo &DI,
407 OptimizationRemarkEmitter &ORE, int OptLevel,
408 LPMUpdater &U, bool &AnyLoopRemoved) {
409 bool DidSomething = false;
410 ArrayRef<Loop *> Loops = LN.getLoops();
411 Loop *OutmostLoop = &LN.getOutermostLoop();
412
413 // Add the loop nests in the reverse order of LN. See method
414 // declaration.
415 SmallPriorityWorklist<Loop *, 4> Worklist;
416 appendLoopsToWorklist(Loops, Worklist);
417 while (!Worklist.empty()) {
418 Loop *L = Worklist.pop_back_val();
419 std::string LoopName = std::string(L->getName());
420 LoopUnrollResult Result =
421 tryToUnrollAndJamLoop(L, DT, LI: &LI, SE, TTI, AC, DI, ORE, OptLevel);
422 if (Result != LoopUnrollResult::Unmodified)
423 DidSomething = true;
424 if (Result == LoopUnrollResult::FullyUnrolled) {
425 if (L == OutmostLoop)
426 U.markLoopAsDeleted(L&: *L, Name: LoopName);
427 AnyLoopRemoved = true;
428 }
429 }
430
431 return DidSomething;
432}
433
434PreservedAnalyses LoopUnrollAndJamPass::run(LoopNest &LN,
435 LoopAnalysisManager &AM,
436 LoopStandardAnalysisResults &AR,
437 LPMUpdater &U) {
438 Function &F = *LN.getParent();
439
440 DependenceInfo DI(&F, &AR.AA, &AR.SE, &AR.LI);
441 OptimizationRemarkEmitter ORE(&F);
442
443 bool AnyLoopRemoved = false;
444 if (!tryToUnrollAndJamLoop(LN, DT&: AR.DT, LI&: AR.LI, SE&: AR.SE, TTI: AR.TTI, AC&: AR.AC, DI, ORE,
445 OptLevel, U, AnyLoopRemoved))
446 return PreservedAnalyses::all();
447
448 auto PA = getLoopPassPreservedAnalyses();
449 if (!AnyLoopRemoved)
450 PA.preserve<LoopNestAnalysis>();
451 return PA;
452}
453