1//===- LoopUnrollAndJam.cpp - Loop unroll and jam pass --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the loop unroll-and-jam pass. Most of the work is done
// by Utils/UnrollLoopAndJam.cpp.
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Scalar/LoopUnrollAndJamPass.h"
14#include "llvm/ADT/ArrayRef.h"
15#include "llvm/ADT/PriorityWorklist.h"
16#include "llvm/ADT/SmallPtrSet.h"
17#include "llvm/ADT/StringRef.h"
18#include "llvm/Analysis/AssumptionCache.h"
19#include "llvm/Analysis/CodeMetrics.h"
20#include "llvm/Analysis/DependenceAnalysis.h"
21#include "llvm/Analysis/LoopAnalysisManager.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/Analysis/LoopNestAnalysis.h"
24#include "llvm/Analysis/LoopPass.h"
25#include "llvm/Analysis/OptimizationRemarkEmitter.h"
26#include "llvm/Analysis/ScalarEvolution.h"
27#include "llvm/Analysis/TargetTransformInfo.h"
28#include "llvm/IR/BasicBlock.h"
29#include "llvm/IR/Constants.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/Instructions.h"
33#include "llvm/IR/Metadata.h"
34#include "llvm/IR/PassManager.h"
35#include "llvm/Support/Casting.h"
36#include "llvm/Support/CommandLine.h"
37#include "llvm/Support/Debug.h"
38#include "llvm/Support/raw_ostream.h"
39#include "llvm/Transforms/Scalar/LoopPassManager.h"
40#include "llvm/Transforms/Utils/LoopPeel.h"
41#include "llvm/Transforms/Utils/LoopUtils.h"
42#include "llvm/Transforms/Utils/UnrollLoop.h"
43#include <cassert>
44#include <cstdint>
45
// Forward declarations of the IR classes that the code below names directly.
namespace llvm {
class Value;
class Instruction;
} // namespace llvm
50
51using namespace llvm;
52
53#define DEBUG_TYPE "loop-unroll-and-jam"
54
/// @{
/// Names of the follow-up loop attributes that are looked up on the original
/// outer loop ID when new loop IDs are built for the loops that result from
/// unroll-and-jam.
static constexpr const char *LLVMLoopUnrollAndJamFollowupAll =
    "llvm.loop.unroll_and_jam.followup_all";
static constexpr const char *LLVMLoopUnrollAndJamFollowupInner =
    "llvm.loop.unroll_and_jam.followup_inner";
static constexpr const char *LLVMLoopUnrollAndJamFollowupOuter =
    "llvm.loop.unroll_and_jam.followup_outer";
static constexpr const char *LLVMLoopUnrollAndJamFollowupRemainderInner =
    "llvm.loop.unroll_and_jam.followup_remainder_inner";
static constexpr const char *LLVMLoopUnrollAndJamFollowupRemainderOuter =
    "llvm.loop.unroll_and_jam.followup_remainder_outer";
/// @}
68
69static cl::opt<bool>
70 AllowUnrollAndJam("allow-unroll-and-jam", cl::Hidden,
71 cl::desc("Allows loops to be unroll-and-jammed."));
72
73static cl::opt<unsigned> UnrollAndJamCount(
74 "unroll-and-jam-count", cl::Hidden,
75 cl::desc("Use this unroll count for all loops including those with "
76 "unroll_and_jam_count pragma values, for testing purposes"));
77
78static cl::opt<unsigned> UnrollAndJamThreshold(
79 "unroll-and-jam-threshold", cl::init(Val: 60), cl::Hidden,
80 cl::desc("Threshold to use for inner loop when doing unroll and jam."));
81
82static cl::opt<unsigned> PragmaUnrollAndJamThreshold(
83 "pragma-unroll-and-jam-threshold", cl::init(Val: 1024), cl::Hidden,
84 cl::desc("Unrolled size limit for loops with an unroll_and_jam(full) or "
85 "unroll_count pragma."));
86
87// Returns true if the loop has any metadata starting with Prefix. For example a
88// Prefix of "llvm.loop.unroll." returns true if we have any unroll metadata.
89static bool hasAnyUnrollPragma(const Loop *L, StringRef Prefix) {
90 if (MDNode *LoopID = L->getLoopID()) {
91 // First operand should refer to the loop id itself.
92 assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
93 assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
94
95 for (unsigned I = 1, E = LoopID->getNumOperands(); I < E; ++I) {
96 MDNode *MD = dyn_cast<MDNode>(Val: LoopID->getOperand(I));
97 if (!MD)
98 continue;
99
100 MDString *S = dyn_cast<MDString>(Val: MD->getOperand(I: 0));
101 if (!S)
102 continue;
103
104 if (S->getString().starts_with(Prefix))
105 return true;
106 }
107 }
108 return false;
109}
110
111// Returns true if the loop has an unroll_and_jam(enable) pragma.
112static bool hasUnrollAndJamEnablePragma(const Loop *L) {
113 return getUnrollMetadataForLoop(L, Name: "llvm.loop.unroll_and_jam.enable");
114}
115
116// If loop has an unroll_and_jam_count pragma return the (necessarily
117// positive) value from the pragma. Otherwise return 0.
118static unsigned unrollAndJamCountPragmaValue(const Loop *L) {
119 MDNode *MD = getUnrollMetadataForLoop(L, Name: "llvm.loop.unroll_and_jam.count");
120 if (MD) {
121 assert(MD->getNumOperands() == 2 &&
122 "Unroll count hint metadata should have two operands.");
123 unsigned Count =
124 mdconst::extract<ConstantInt>(MD: MD->getOperand(I: 1))->getZExtValue();
125 assert(Count >= 1 && "Unroll count must be positive.");
126 return Count;
127 }
128 return 0;
129}
130
131// Returns loop size estimation for unrolled loop.
132static uint64_t
133getUnrollAndJammedLoopSize(unsigned LoopSize,
134 TargetTransformInfo::UnrollingPreferences &UP) {
135 assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!");
136 return static_cast<uint64_t>(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns;
137}
138
139// Calculates unroll and jam count and writes it to UP.Count. Returns true if
140// unroll count was set explicitly.
141static bool computeUnrollAndJamCount(
142 Loop *L, Loop *SubLoop, const TargetTransformInfo &TTI, DominatorTree &DT,
143 LoopInfo *LI, AssumptionCache *AC, ScalarEvolution &SE,
144 const SmallPtrSetImpl<const Value *> &EphValues,
145 OptimizationRemarkEmitter *ORE, unsigned OuterTripCount,
146 unsigned OuterTripMultiple, const UnrollCostEstimator &OuterUCE,
147 unsigned InnerTripCount, unsigned InnerLoopSize,
148 TargetTransformInfo::UnrollingPreferences &UP,
149 TargetTransformInfo::PeelingPreferences &PP) {
150 unsigned OuterLoopSize = OuterUCE.getRolledLoopSize();
151 // Use computeUnrollCount from the loop unroller to get a count for
152 // unrolling the outer loop. This uses UP.Threshold / UP.PartialThreshold /
153 // UP.MaxCount to come up with sensible loop values.
154 // We have already checked that the loop has no unroll.* pragmas.
155 computeUnrollCount(L, TTI, DT, LI, AC, SE, EphValues, ORE, TripCount: OuterTripCount,
156 /*MaxTripCount*/ 0, /*MaxOrZero*/ false, TripMultiple: OuterTripMultiple,
157 UCE: OuterUCE, UP, PP);
158
159 // Override with any explicit Count from the "unroll-and-jam-count" option.
160 bool UserUnrollCount = UnrollAndJamCount.getNumOccurrences() > 0;
161 if (UserUnrollCount) {
162 UP.Count = UnrollAndJamCount;
163 UP.Force = true;
164 if (UP.AllowRemainder &&
165 getUnrollAndJammedLoopSize(LoopSize: OuterLoopSize, UP) < UP.Threshold &&
166 getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) <
167 UP.UnrollAndJamInnerLoopThreshold)
168 return true;
169 }
170
171 // Check for unroll_and_jam pragmas
172 unsigned PragmaCount = unrollAndJamCountPragmaValue(L);
173 if (PragmaCount > 0) {
174 UP.Count = PragmaCount;
175 UP.Runtime = true;
176 UP.Force = true;
177 if ((UP.AllowRemainder || (OuterTripMultiple % PragmaCount == 0)) &&
178 getUnrollAndJammedLoopSize(LoopSize: OuterLoopSize, UP) < UP.Threshold &&
179 getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) <
180 UP.UnrollAndJamInnerLoopThreshold)
181 return true;
182 }
183
184 bool PragmaEnableUnroll = hasUnrollAndJamEnablePragma(L);
185 bool ExplicitUnrollAndJamCount = PragmaCount > 0 || UserUnrollCount;
186 bool ExplicitUnrollAndJam = PragmaEnableUnroll || ExplicitUnrollAndJamCount;
187
188 // If the loop has an unrolling pragma, we want to be more aggressive with
189 // unrolling limits.
190 if (ExplicitUnrollAndJam)
191 UP.UnrollAndJamInnerLoopThreshold = PragmaUnrollAndJamThreshold;
192
193 if (!UP.AllowRemainder && getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) >=
194 UP.UnrollAndJamInnerLoopThreshold) {
195 LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; can't create remainder and "
196 "inner loop too large\n");
197 UP.Count = 0;
198 return false;
199 }
200
201 // We have a sensible limit for the outer loop, now adjust it for the inner
202 // loop and UP.UnrollAndJamInnerLoopThreshold. If the outer limit was set
203 // explicitly, we want to stick to it.
204 if (!ExplicitUnrollAndJamCount && UP.AllowRemainder) {
205 while (UP.Count != 0 && getUnrollAndJammedLoopSize(LoopSize: InnerLoopSize, UP) >=
206 UP.UnrollAndJamInnerLoopThreshold)
207 UP.Count--;
208 }
209
210 // If we are explicitly unroll and jamming, we are done. Otherwise there are a
211 // number of extra performance heuristics to check.
212 if (ExplicitUnrollAndJam)
213 return true;
214
215 // If the inner loop count is known and small, leave the entire loop nest to
216 // be the unroller
217 if (InnerTripCount && InnerLoopSize * InnerTripCount < UP.Threshold) {
218 LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; small inner loop count is "
219 "being left for the unroller\n");
220 UP.Count = 0;
221 return false;
222 }
223
224 // Check for situations where UnJ is likely to be unprofitable. Including
225 // subloops with more than 1 block.
226 if (SubLoop->getBlocks().size() != 1) {
227 LLVM_DEBUG(
228 dbgs() << "Won't unroll-and-jam; More than one inner loop block\n");
229 UP.Count = 0;
230 return false;
231 }
232
233 // Limit to loops where there is something to gain from unrolling and
234 // jamming the loop. In this case, look for loads that are invariant in the
235 // outer loop and can become shared.
236 unsigned NumInvariant = 0;
237 for (BasicBlock *BB : SubLoop->getBlocks()) {
238 for (Instruction &I : *BB) {
239 if (auto *Ld = dyn_cast<LoadInst>(Val: &I)) {
240 Value *V = Ld->getPointerOperand();
241 const SCEV *LSCEV = SE.getSCEVAtScope(V, L);
242 if (SE.isLoopInvariant(S: LSCEV, L))
243 NumInvariant++;
244 }
245 }
246 }
247 if (NumInvariant == 0) {
248 LLVM_DEBUG(dbgs() << "Won't unroll-and-jam; No loop invariant loads\n");
249 UP.Count = 0;
250 return false;
251 }
252
253 return false;
254}
255
256static LoopUnrollResult
257tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
258 ScalarEvolution &SE, const TargetTransformInfo &TTI,
259 AssumptionCache &AC, DependenceInfo &DI,
260 OptimizationRemarkEmitter &ORE, int OptLevel) {
261 TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences(
262 L, SE, TTI, BFI: nullptr, PSI: nullptr, ORE, OptLevel, UserThreshold: std::nullopt, UserCount: std::nullopt,
263 UserAllowPartial: std::nullopt, UserRuntime: std::nullopt, UserUpperBound: std::nullopt, UserFullUnrollMaxCount: std::nullopt);
264 TargetTransformInfo::PeelingPreferences PP =
265 gatherPeelingPreferences(L, SE, TTI, UserAllowPeeling: std::nullopt, UserAllowProfileBasedPeeling: std::nullopt);
266
267 TransformationMode EnableMode = hasUnrollAndJamTransformation(L);
268 if (EnableMode & TM_Disable)
269 return LoopUnrollResult::Unmodified;
270 if (EnableMode & TM_ForcedByUser)
271 UP.UnrollAndJam = true;
272
273 if (AllowUnrollAndJam.getNumOccurrences() > 0)
274 UP.UnrollAndJam = AllowUnrollAndJam;
275 if (UnrollAndJamThreshold.getNumOccurrences() > 0)
276 UP.UnrollAndJamInnerLoopThreshold = UnrollAndJamThreshold;
277 // Exit early if unrolling is disabled.
278 if (!UP.UnrollAndJam || UP.UnrollAndJamInnerLoopThreshold == 0)
279 return LoopUnrollResult::Unmodified;
280
281 LLVM_DEBUG(dbgs() << "Loop Unroll and Jam: F["
282 << L->getHeader()->getParent()->getName() << "] Loop %"
283 << L->getHeader()->getName() << "\n");
284
285 // A loop with any unroll pragma (enabling/disabling/count/etc) is left for
286 // the unroller, so long as it does not explicitly have unroll_and_jam
287 // metadata. This means #pragma nounroll will disable unroll and jam as well
288 // as unrolling
289 if (hasAnyUnrollPragma(L, Prefix: "llvm.loop.unroll.") &&
290 !hasAnyUnrollPragma(L, Prefix: "llvm.loop.unroll_and_jam.")) {
291 LLVM_DEBUG(dbgs() << " Disabled due to pragma.\n");
292 return LoopUnrollResult::Unmodified;
293 }
294
295 if (!isSafeToUnrollAndJam(L, SE, DT, DI, LI&: *LI)) {
296 LLVM_DEBUG(dbgs() << " Disabled due to not being safe.\n");
297 return LoopUnrollResult::Unmodified;
298 }
299
300 // Approximate the loop size and collect useful info
301 SmallPtrSet<const Value *, 32> EphValues;
302 CodeMetrics::collectEphemeralValues(L, AC: &AC, EphValues);
303 Loop *SubLoop = L->getSubLoops()[0];
304 UnrollCostEstimator InnerUCE(SubLoop, TTI, EphValues, UP.BEInsns);
305 UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns);
306
307 if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) {
308 LLVM_DEBUG(dbgs() << " Loop not considered unrollable\n");
309 return LoopUnrollResult::Unmodified;
310 }
311
312 unsigned InnerLoopSize = InnerUCE.getRolledLoopSize();
313 LLVM_DEBUG(dbgs() << " Outer Loop Size: " << OuterUCE.getRolledLoopSize()
314 << "\n");
315 LLVM_DEBUG(dbgs() << " Inner Loop Size: " << InnerLoopSize << "\n");
316
317 if (InnerUCE.NumInlineCandidates != 0 || OuterUCE.NumInlineCandidates != 0) {
318 LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n");
319 return LoopUnrollResult::Unmodified;
320 }
321 // FIXME: The call to canUnroll() allows some controlled convergent
322 // operations, but we block them here for future changes.
323 if (InnerUCE.Convergence != ConvergenceKind::None ||
324 OuterUCE.Convergence != ConvergenceKind::None) {
325 LLVM_DEBUG(
326 dbgs() << " Not unrolling loop with convergent instructions.\n");
327 return LoopUnrollResult::Unmodified;
328 }
329
330 // Save original loop IDs for after the transformation.
331 MDNode *OrigOuterLoopID = L->getLoopID();
332 MDNode *OrigSubLoopID = SubLoop->getLoopID();
333
334 // To assign the loop id of the epilogue, assign it before unrolling it so it
335 // is applied to every inner loop of the epilogue. We later apply the loop ID
336 // for the jammed inner loop.
337 std::optional<MDNode *> NewInnerEpilogueLoopID = makeFollowupLoopID(
338 OrigLoopID: OrigOuterLoopID, FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll,
339 LLVMLoopUnrollAndJamFollowupRemainderInner});
340 if (NewInnerEpilogueLoopID)
341 SubLoop->setLoopID(*NewInnerEpilogueLoopID);
342
343 // Find trip count and trip multiple
344 BasicBlock *Latch = L->getLoopLatch();
345 BasicBlock *SubLoopLatch = SubLoop->getLoopLatch();
346 unsigned OuterTripCount = SE.getSmallConstantTripCount(L, ExitingBlock: Latch);
347 unsigned OuterTripMultiple = SE.getSmallConstantTripMultiple(L, ExitingBlock: Latch);
348 unsigned InnerTripCount = SE.getSmallConstantTripCount(L: SubLoop, ExitingBlock: SubLoopLatch);
349
350 // Decide if, and by how much, to unroll
351 bool IsCountSetExplicitly = computeUnrollAndJamCount(
352 L, SubLoop, TTI, DT, LI, AC: &AC, SE, EphValues, ORE: &ORE, OuterTripCount,
353 OuterTripMultiple, OuterUCE, InnerTripCount, InnerLoopSize, UP, PP);
354 if (UP.Count <= 1)
355 return LoopUnrollResult::Unmodified;
356 // Unroll factor (Count) must be less or equal to TripCount.
357 if (OuterTripCount && UP.Count > OuterTripCount)
358 UP.Count = OuterTripCount;
359
360 Loop *EpilogueOuterLoop = nullptr;
361 LoopUnrollResult UnrollResult = UnrollAndJamLoop(
362 L, Count: UP.Count, TripCount: OuterTripCount, TripMultiple: OuterTripMultiple, UnrollRemainder: UP.UnrollRemainder, LI,
363 SE: &SE, DT: &DT, AC: &AC, TTI: &TTI, ORE: &ORE, EpilogueLoop: &EpilogueOuterLoop);
364
365 // Assign new loop attributes.
366 if (EpilogueOuterLoop) {
367 std::optional<MDNode *> NewOuterEpilogueLoopID = makeFollowupLoopID(
368 OrigLoopID: OrigOuterLoopID, FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll,
369 LLVMLoopUnrollAndJamFollowupRemainderOuter});
370 if (NewOuterEpilogueLoopID)
371 EpilogueOuterLoop->setLoopID(*NewOuterEpilogueLoopID);
372 }
373
374 std::optional<MDNode *> NewInnerLoopID =
375 makeFollowupLoopID(OrigLoopID: OrigOuterLoopID, FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll,
376 LLVMLoopUnrollAndJamFollowupInner});
377 if (NewInnerLoopID)
378 SubLoop->setLoopID(*NewInnerLoopID);
379 else
380 SubLoop->setLoopID(OrigSubLoopID);
381
382 if (UnrollResult == LoopUnrollResult::PartiallyUnrolled) {
383 std::optional<MDNode *> NewOuterLoopID = makeFollowupLoopID(
384 OrigLoopID: OrigOuterLoopID,
385 FollowupAttrs: {LLVMLoopUnrollAndJamFollowupAll, LLVMLoopUnrollAndJamFollowupOuter});
386 if (NewOuterLoopID) {
387 L->setLoopID(*NewOuterLoopID);
388
389 // Do not setLoopAlreadyUnrolled if a followup was given.
390 return UnrollResult;
391 }
392 }
393
394 // If loop has an unroll count pragma or unrolled by explicitly set count
395 // mark loop as unrolled to prevent unrolling beyond that requested.
396 if (UnrollResult != LoopUnrollResult::FullyUnrolled && IsCountSetExplicitly)
397 L->setLoopAlreadyUnrolled();
398
399 return UnrollResult;
400}
401
402static bool tryToUnrollAndJamLoop(LoopNest &LN, DominatorTree &DT, LoopInfo &LI,
403 ScalarEvolution &SE,
404 const TargetTransformInfo &TTI,
405 AssumptionCache &AC, DependenceInfo &DI,
406 OptimizationRemarkEmitter &ORE, int OptLevel,
407 LPMUpdater &U, bool &AnyLoopRemoved) {
408 bool DidSomething = false;
409 ArrayRef<Loop *> Loops = LN.getLoops();
410 Loop *OutmostLoop = &LN.getOutermostLoop();
411
412 // Add the loop nests in the reverse order of LN. See method
413 // declaration.
414 SmallPriorityWorklist<Loop *, 4> Worklist;
415 appendLoopsToWorklist(Loops, Worklist);
416 while (!Worklist.empty()) {
417 Loop *L = Worklist.pop_back_val();
418 std::string LoopName = std::string(L->getName());
419 LoopUnrollResult Result =
420 tryToUnrollAndJamLoop(L, DT, LI: &LI, SE, TTI, AC, DI, ORE, OptLevel);
421 if (Result != LoopUnrollResult::Unmodified)
422 DidSomething = true;
423 if (Result == LoopUnrollResult::FullyUnrolled) {
424 if (L == OutmostLoop)
425 U.markLoopAsDeleted(L&: *L, Name: LoopName);
426 AnyLoopRemoved = true;
427 }
428 }
429
430 return DidSomething;
431}
432
433PreservedAnalyses LoopUnrollAndJamPass::run(LoopNest &LN,
434 LoopAnalysisManager &AM,
435 LoopStandardAnalysisResults &AR,
436 LPMUpdater &U) {
437 Function &F = *LN.getParent();
438
439 DependenceInfo DI(&F, &AR.AA, &AR.SE, &AR.LI);
440 OptimizationRemarkEmitter ORE(&F);
441
442 bool AnyLoopRemoved = false;
443 if (!tryToUnrollAndJamLoop(LN, DT&: AR.DT, LI&: AR.LI, SE&: AR.SE, TTI: AR.TTI, AC&: AR.AC, DI, ORE,
444 OptLevel, U, AnyLoopRemoved))
445 return PreservedAnalyses::all();
446
447 auto PA = getLoopPassPreservedAnalyses();
448 if (!AnyLoopRemoved)
449 PA.preserve<LoopNestAnalysis>();
450 return PA;
451}
452