1 | //===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements a Loop Data Prefetching Pass. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "llvm/Transforms/Scalar/LoopDataPrefetch.h" |
14 | #include "llvm/InitializePasses.h" |
15 | |
16 | #include "llvm/ADT/DepthFirstIterator.h" |
17 | #include "llvm/ADT/Statistic.h" |
18 | #include "llvm/Analysis/AssumptionCache.h" |
19 | #include "llvm/Analysis/CodeMetrics.h" |
20 | #include "llvm/Analysis/LoopInfo.h" |
21 | #include "llvm/Analysis/OptimizationRemarkEmitter.h" |
22 | #include "llvm/Analysis/ScalarEvolution.h" |
23 | #include "llvm/Analysis/ScalarEvolutionExpressions.h" |
24 | #include "llvm/Analysis/TargetTransformInfo.h" |
25 | #include "llvm/IR/Dominators.h" |
26 | #include "llvm/IR/Function.h" |
27 | #include "llvm/IR/Module.h" |
28 | #include "llvm/Support/CommandLine.h" |
29 | #include "llvm/Support/Debug.h" |
30 | #include "llvm/Transforms/Scalar.h" |
31 | #include "llvm/Transforms/Utils.h" |
32 | #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" |
33 | |
34 | #define DEBUG_TYPE "loop-data-prefetch" |
35 | |
36 | using namespace llvm; |
37 | |
38 | // By default, we limit this to creating 16 PHIs (which is a little over half |
39 | // of the allocatable register set). |
40 | static cl::opt<bool> |
41 | PrefetchWrites("loop-prefetch-writes" , cl::Hidden, cl::init(Val: false), |
42 | cl::desc("Prefetch write addresses" )); |
43 | |
44 | static cl::opt<unsigned> |
45 | PrefetchDistance("prefetch-distance" , |
46 | cl::desc("Number of instructions to prefetch ahead" ), |
47 | cl::Hidden); |
48 | |
49 | static cl::opt<unsigned> |
50 | MinPrefetchStride("min-prefetch-stride" , |
51 | cl::desc("Min stride to add prefetches" ), cl::Hidden); |
52 | |
53 | static cl::opt<unsigned> MaxPrefetchIterationsAhead( |
54 | "max-prefetch-iters-ahead" , |
55 | cl::desc("Max number of iterations to prefetch ahead" ), cl::Hidden); |
56 | |
57 | STATISTIC(NumPrefetches, "Number of prefetches inserted" ); |
58 | |
59 | namespace { |
60 | |
61 | /// Loop prefetch implementation class. |
62 | class LoopDataPrefetch { |
63 | public: |
64 | LoopDataPrefetch(AssumptionCache *AC, DominatorTree *DT, LoopInfo *LI, |
65 | ScalarEvolution *SE, const TargetTransformInfo *TTI, |
66 | OptimizationRemarkEmitter *ORE) |
67 | : AC(AC), DT(DT), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {} |
68 | |
69 | bool run(); |
70 | |
71 | private: |
72 | bool runOnLoop(Loop *L); |
73 | |
74 | /// Check if the stride of the accesses is large enough to |
75 | /// warrant a prefetch. |
76 | bool isStrideLargeEnough(const SCEVAddRecExpr *AR, unsigned TargetMinStride); |
77 | |
78 | unsigned getMinPrefetchStride(unsigned NumMemAccesses, |
79 | unsigned NumStridedMemAccesses, |
80 | unsigned NumPrefetches, |
81 | bool HasCall) { |
82 | if (MinPrefetchStride.getNumOccurrences() > 0) |
83 | return MinPrefetchStride; |
84 | return TTI->getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses, |
85 | NumPrefetches, HasCall); |
86 | } |
87 | |
88 | unsigned getPrefetchDistance() { |
89 | if (PrefetchDistance.getNumOccurrences() > 0) |
90 | return PrefetchDistance; |
91 | return TTI->getPrefetchDistance(); |
92 | } |
93 | |
94 | unsigned getMaxPrefetchIterationsAhead() { |
95 | if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0) |
96 | return MaxPrefetchIterationsAhead; |
97 | return TTI->getMaxPrefetchIterationsAhead(); |
98 | } |
99 | |
100 | bool doPrefetchWrites() { |
101 | if (PrefetchWrites.getNumOccurrences() > 0) |
102 | return PrefetchWrites; |
103 | return TTI->enableWritePrefetching(); |
104 | } |
105 | |
106 | AssumptionCache *AC; |
107 | DominatorTree *DT; |
108 | LoopInfo *LI; |
109 | ScalarEvolution *SE; |
110 | const TargetTransformInfo *TTI; |
111 | OptimizationRemarkEmitter *ORE; |
112 | }; |
113 | |
114 | /// Legacy class for inserting loop data prefetches. |
115 | class LoopDataPrefetchLegacyPass : public FunctionPass { |
116 | public: |
117 | static char ID; // Pass ID, replacement for typeid |
118 | LoopDataPrefetchLegacyPass() : FunctionPass(ID) { |
119 | initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry()); |
120 | } |
121 | |
122 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
123 | AU.addRequired<AssumptionCacheTracker>(); |
124 | AU.addRequired<DominatorTreeWrapperPass>(); |
125 | AU.addPreserved<DominatorTreeWrapperPass>(); |
126 | AU.addRequired<LoopInfoWrapperPass>(); |
127 | AU.addPreserved<LoopInfoWrapperPass>(); |
128 | AU.addRequiredID(ID&: LoopSimplifyID); |
129 | AU.addPreservedID(ID&: LoopSimplifyID); |
130 | AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); |
131 | AU.addRequired<ScalarEvolutionWrapperPass>(); |
132 | AU.addPreserved<ScalarEvolutionWrapperPass>(); |
133 | AU.addRequired<TargetTransformInfoWrapperPass>(); |
134 | } |
135 | |
136 | bool runOnFunction(Function &F) override; |
137 | }; |
138 | } |
139 | |
// Pass identity for the legacy pass manager; the address (not the value)
// is what uniquely identifies the pass.
char LoopDataPrefetchLegacyPass::ID = 0;
// Register the pass and its analysis dependencies.  The two trailing 'false'
// flags mean: does not only look at the CFG, and is not an analysis pass.
INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch" ,
                      "Loop Data Prefetch" , false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch" ,
                    "Loop Data Prefetch" , false, false)
151 | |
152 | FunctionPass *llvm::createLoopDataPrefetchPass() { |
153 | return new LoopDataPrefetchLegacyPass(); |
154 | } |
155 | |
156 | bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR, |
157 | unsigned TargetMinStride) { |
158 | // No need to check if any stride goes. |
159 | if (TargetMinStride <= 1) |
160 | return true; |
161 | |
162 | const auto *ConstStride = dyn_cast<SCEVConstant>(Val: AR->getStepRecurrence(SE&: *SE)); |
163 | // If MinStride is set, don't prefetch unless we can ensure that stride is |
164 | // larger. |
165 | if (!ConstStride) |
166 | return false; |
167 | |
168 | unsigned AbsStride = std::abs(i: ConstStride->getAPInt().getSExtValue()); |
169 | return TargetMinStride <= AbsStride; |
170 | } |
171 | |
172 | PreservedAnalyses LoopDataPrefetchPass::run(Function &F, |
173 | FunctionAnalysisManager &AM) { |
174 | DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(IR&: F); |
175 | LoopInfo *LI = &AM.getResult<LoopAnalysis>(IR&: F); |
176 | ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(IR&: F); |
177 | AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(IR&: F); |
178 | OptimizationRemarkEmitter *ORE = |
179 | &AM.getResult<OptimizationRemarkEmitterAnalysis>(IR&: F); |
180 | const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(IR&: F); |
181 | |
182 | LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE); |
183 | bool Changed = LDP.run(); |
184 | |
185 | if (Changed) { |
186 | PreservedAnalyses PA; |
187 | PA.preserve<DominatorTreeAnalysis>(); |
188 | PA.preserve<LoopAnalysis>(); |
189 | return PA; |
190 | } |
191 | |
192 | return PreservedAnalyses::all(); |
193 | } |
194 | |
195 | bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) { |
196 | if (skipFunction(F)) |
197 | return false; |
198 | |
199 | DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); |
200 | LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); |
201 | ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); |
202 | AssumptionCache *AC = |
203 | &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); |
204 | OptimizationRemarkEmitter *ORE = |
205 | &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); |
206 | const TargetTransformInfo *TTI = |
207 | &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); |
208 | |
209 | LoopDataPrefetch LDP(AC, DT, LI, SE, TTI, ORE); |
210 | return LDP.run(); |
211 | } |
212 | |
213 | bool LoopDataPrefetch::run() { |
214 | // If PrefetchDistance is not set, don't run the pass. This gives an |
215 | // opportunity for targets to run this pass for selected subtargets only |
216 | // (whose TTI sets PrefetchDistance and CacheLineSize). |
217 | if (getPrefetchDistance() == 0 || TTI->getCacheLineSize() == 0) { |
218 | LLVM_DEBUG(dbgs() << "Please set both PrefetchDistance and CacheLineSize " |
219 | "for loop data prefetch.\n" ); |
220 | return false; |
221 | } |
222 | |
223 | bool MadeChange = false; |
224 | |
225 | for (Loop *I : *LI) |
226 | for (Loop *L : depth_first(G: I)) |
227 | MadeChange |= runOnLoop(L); |
228 | |
229 | return MadeChange; |
230 | } |
231 | |
232 | /// A record for a potential prefetch made during the initial scan of the |
233 | /// loop. This is used to let a single prefetch target multiple memory accesses. |
234 | struct Prefetch { |
235 | /// The address formula for this prefetch as returned by ScalarEvolution. |
236 | const SCEVAddRecExpr *LSCEVAddRec; |
237 | /// The point of insertion for the prefetch instruction. |
238 | Instruction *InsertPt = nullptr; |
239 | /// True if targeting a write memory access. |
240 | bool Writes = false; |
241 | /// The (first seen) prefetched instruction. |
242 | Instruction *MemI = nullptr; |
243 | |
244 | /// Constructor to create a new Prefetch for \p I. |
245 | Prefetch(const SCEVAddRecExpr *L, Instruction *I) : LSCEVAddRec(L) { |
246 | addInstruction(I); |
247 | }; |
248 | |
249 | /// Add the instruction \param I to this prefetch. If it's not the first |
250 | /// one, 'InsertPt' and 'Writes' will be updated as required. |
251 | /// \param PtrDiff the known constant address difference to the first added |
252 | /// instruction. |
253 | void addInstruction(Instruction *I, DominatorTree *DT = nullptr, |
254 | int64_t PtrDiff = 0) { |
255 | if (!InsertPt) { |
256 | MemI = I; |
257 | InsertPt = I; |
258 | Writes = isa<StoreInst>(Val: I); |
259 | } else { |
260 | BasicBlock *PrefBB = InsertPt->getParent(); |
261 | BasicBlock *InsBB = I->getParent(); |
262 | if (PrefBB != InsBB) { |
263 | BasicBlock *DomBB = DT->findNearestCommonDominator(A: PrefBB, B: InsBB); |
264 | if (DomBB != PrefBB) |
265 | InsertPt = DomBB->getTerminator(); |
266 | } |
267 | |
268 | if (isa<StoreInst>(Val: I) && PtrDiff == 0) |
269 | Writes = true; |
270 | } |
271 | } |
272 | }; |
273 | |
274 | bool LoopDataPrefetch::runOnLoop(Loop *L) { |
275 | bool MadeChange = false; |
276 | |
277 | // Only prefetch in the inner-most loop |
278 | if (!L->isInnermost()) |
279 | return MadeChange; |
280 | |
281 | SmallPtrSet<const Value *, 32> EphValues; |
282 | CodeMetrics::collectEphemeralValues(L, AC, EphValues); |
283 | |
284 | // Calculate the number of iterations ahead to prefetch |
285 | CodeMetrics Metrics; |
286 | bool HasCall = false; |
287 | for (const auto BB : L->blocks()) { |
288 | // If the loop already has prefetches, then assume that the user knows |
289 | // what they are doing and don't add any more. |
290 | for (auto &I : *BB) { |
291 | if (isa<CallInst>(Val: &I) || isa<InvokeInst>(Val: &I)) { |
292 | if (const Function *F = cast<CallBase>(Val&: I).getCalledFunction()) { |
293 | if (F->getIntrinsicID() == Intrinsic::prefetch) |
294 | return MadeChange; |
295 | if (TTI->isLoweredToCall(F)) |
296 | HasCall = true; |
297 | } else { // indirect call. |
298 | HasCall = true; |
299 | } |
300 | } |
301 | } |
302 | Metrics.analyzeBasicBlock(BB, TTI: *TTI, EphValues); |
303 | } |
304 | |
305 | if (!Metrics.NumInsts.isValid()) |
306 | return MadeChange; |
307 | |
308 | unsigned LoopSize = *Metrics.NumInsts.getValue(); |
309 | if (!LoopSize) |
310 | LoopSize = 1; |
311 | |
312 | unsigned ItersAhead = getPrefetchDistance() / LoopSize; |
313 | if (!ItersAhead) |
314 | ItersAhead = 1; |
315 | |
316 | if (ItersAhead > getMaxPrefetchIterationsAhead()) |
317 | return MadeChange; |
318 | |
319 | unsigned ConstantMaxTripCount = SE->getSmallConstantMaxTripCount(L); |
320 | if (ConstantMaxTripCount && ConstantMaxTripCount < ItersAhead + 1) |
321 | return MadeChange; |
322 | |
323 | unsigned NumMemAccesses = 0; |
324 | unsigned NumStridedMemAccesses = 0; |
325 | SmallVector<Prefetch, 16> Prefetches; |
326 | for (const auto BB : L->blocks()) |
327 | for (auto &I : *BB) { |
328 | Value *PtrValue; |
329 | Instruction *MemI; |
330 | |
331 | if (LoadInst *LMemI = dyn_cast<LoadInst>(Val: &I)) { |
332 | MemI = LMemI; |
333 | PtrValue = LMemI->getPointerOperand(); |
334 | } else if (StoreInst *SMemI = dyn_cast<StoreInst>(Val: &I)) { |
335 | if (!doPrefetchWrites()) continue; |
336 | MemI = SMemI; |
337 | PtrValue = SMemI->getPointerOperand(); |
338 | } else continue; |
339 | |
340 | unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace(); |
341 | if (!TTI->shouldPrefetchAddressSpace(AS: PtrAddrSpace)) |
342 | continue; |
343 | NumMemAccesses++; |
344 | if (L->isLoopInvariant(V: PtrValue)) |
345 | continue; |
346 | |
347 | const SCEV *LSCEV = SE->getSCEV(V: PtrValue); |
348 | const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(Val: LSCEV); |
349 | if (!LSCEVAddRec) |
350 | continue; |
351 | NumStridedMemAccesses++; |
352 | |
353 | // We don't want to double prefetch individual cache lines. If this |
354 | // access is known to be within one cache line of some other one that |
355 | // has already been prefetched, then don't prefetch this one as well. |
356 | bool DupPref = false; |
357 | for (auto &Pref : Prefetches) { |
358 | const SCEV *PtrDiff = SE->getMinusSCEV(LHS: LSCEVAddRec, RHS: Pref.LSCEVAddRec); |
359 | if (const SCEVConstant *ConstPtrDiff = |
360 | dyn_cast<SCEVConstant>(Val: PtrDiff)) { |
361 | int64_t PD = std::abs(i: ConstPtrDiff->getValue()->getSExtValue()); |
362 | if (PD < (int64_t) TTI->getCacheLineSize()) { |
363 | Pref.addInstruction(I: MemI, DT, PtrDiff: PD); |
364 | DupPref = true; |
365 | break; |
366 | } |
367 | } |
368 | } |
369 | if (!DupPref) |
370 | Prefetches.push_back(Elt: Prefetch(LSCEVAddRec, MemI)); |
371 | } |
372 | |
373 | unsigned TargetMinStride = |
374 | getMinPrefetchStride(NumMemAccesses, NumStridedMemAccesses, |
375 | NumPrefetches: Prefetches.size(), HasCall); |
376 | |
377 | LLVM_DEBUG(dbgs() << "Prefetching " << ItersAhead |
378 | << " iterations ahead (loop size: " << LoopSize << ") in " |
379 | << L->getHeader()->getParent()->getName() << ": " << *L); |
380 | LLVM_DEBUG(dbgs() << "Loop has: " |
381 | << NumMemAccesses << " memory accesses, " |
382 | << NumStridedMemAccesses << " strided memory accesses, " |
383 | << Prefetches.size() << " potential prefetch(es), " |
384 | << "a minimum stride of " << TargetMinStride << ", " |
385 | << (HasCall ? "calls" : "no calls" ) << ".\n" ); |
386 | |
387 | for (auto &P : Prefetches) { |
388 | // Check if the stride of the accesses is large enough to warrant a |
389 | // prefetch. |
390 | if (!isStrideLargeEnough(AR: P.LSCEVAddRec, TargetMinStride)) |
391 | continue; |
392 | |
393 | BasicBlock *BB = P.InsertPt->getParent(); |
394 | SCEVExpander SCEVE(*SE, BB->getDataLayout(), "prefaddr" ); |
395 | const SCEV *NextLSCEV = SE->getAddExpr(LHS: P.LSCEVAddRec, RHS: SE->getMulExpr( |
396 | LHS: SE->getConstant(Ty: P.LSCEVAddRec->getType(), V: ItersAhead), |
397 | RHS: P.LSCEVAddRec->getStepRecurrence(SE&: *SE))); |
398 | if (!SCEVE.isSafeToExpand(S: NextLSCEV)) |
399 | continue; |
400 | |
401 | unsigned PtrAddrSpace = NextLSCEV->getType()->getPointerAddressSpace(); |
402 | Type *I8Ptr = PointerType::get(C&: BB->getContext(), AddressSpace: PtrAddrSpace); |
403 | Value *PrefPtrValue = SCEVE.expandCodeFor(SH: NextLSCEV, Ty: I8Ptr, I: P.InsertPt); |
404 | |
405 | IRBuilder<> Builder(P.InsertPt); |
406 | Module *M = BB->getParent()->getParent(); |
407 | Type *I32 = Type::getInt32Ty(C&: BB->getContext()); |
408 | Function *PrefetchFunc = Intrinsic::getDeclaration( |
409 | M, id: Intrinsic::prefetch, Tys: PrefPtrValue->getType()); |
410 | Builder.CreateCall( |
411 | Callee: PrefetchFunc, |
412 | Args: {PrefPtrValue, |
413 | ConstantInt::get(Ty: I32, V: P.Writes), |
414 | ConstantInt::get(Ty: I32, V: 3), ConstantInt::get(Ty: I32, V: 1)}); |
415 | ++NumPrefetches; |
416 | LLVM_DEBUG(dbgs() << " Access: " |
417 | << *P.MemI->getOperand(isa<LoadInst>(P.MemI) ? 0 : 1) |
418 | << ", SCEV: " << *P.LSCEVAddRec << "\n" ); |
419 | ORE->emit(RemarkBuilder: [&]() { |
420 | return OptimizationRemark(DEBUG_TYPE, "Prefetched" , P.MemI) |
421 | << "prefetched memory access" ; |
422 | }); |
423 | |
424 | MadeChange = true; |
425 | } |
426 | |
427 | return MadeChange; |
428 | } |
429 | |