1 | //===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file provides the implementation of the MIRSampleProfile loader, mainly |
10 | // for flow sensitive SampleFDO. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/CodeGen/MIRSampleProfile.h" |
15 | #include "llvm/ADT/DenseMap.h" |
16 | #include "llvm/ADT/DenseSet.h" |
17 | #include "llvm/Analysis/BlockFrequencyInfoImpl.h" |
18 | #include "llvm/CodeGen/MachineBlockFrequencyInfo.h" |
19 | #include "llvm/CodeGen/MachineBranchProbabilityInfo.h" |
20 | #include "llvm/CodeGen/MachineDominators.h" |
21 | #include "llvm/CodeGen/MachineInstr.h" |
22 | #include "llvm/CodeGen/MachineLoopInfo.h" |
23 | #include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h" |
24 | #include "llvm/CodeGen/MachinePostDominators.h" |
25 | #include "llvm/CodeGen/Passes.h" |
26 | #include "llvm/IR/Function.h" |
27 | #include "llvm/IR/PseudoProbe.h" |
28 | #include "llvm/InitializePasses.h" |
29 | #include "llvm/Support/CommandLine.h" |
30 | #include "llvm/Support/Debug.h" |
31 | #include "llvm/Support/VirtualFileSystem.h" |
32 | #include "llvm/Support/raw_ostream.h" |
33 | #include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h" |
34 | #include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h" |
35 | #include <optional> |
36 | |
37 | using namespace llvm; |
38 | using namespace sampleprof; |
39 | using namespace llvm::sampleprofutil; |
40 | using ProfileCount = Function::ProfileCount; |
41 | |
42 | #define DEBUG_TYPE "fs-profile-loader" |
43 | |
44 | static cl::opt<bool> ShowFSBranchProb( |
45 | "show-fs-branchprob" , cl::Hidden, cl::init(Val: false), |
46 | cl::desc("Print setting flow sensitive branch probabilities" )); |
47 | static cl::opt<unsigned> FSProfileDebugProbDiffThreshold( |
48 | "fs-profile-debug-prob-diff-threshold" , cl::init(Val: 10), |
49 | cl::desc("Only show debug message if the branch probility is greater than " |
50 | "this value (in percentage)." )); |
51 | |
52 | static cl::opt<unsigned> FSProfileDebugBWThreshold( |
53 | "fs-profile-debug-bw-threshold" , cl::init(Val: 10000), |
54 | cl::desc("Only show debug message if the source branch weight is greater " |
55 | " than this value." )); |
56 | |
57 | static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before" , cl::Hidden, |
58 | cl::init(Val: false), |
59 | cl::desc("View BFI before MIR loader" )); |
60 | static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after" , cl::Hidden, |
61 | cl::init(Val: false), |
62 | cl::desc("View BFI after MIR loader" )); |
63 | |
64 | namespace llvm { |
65 | extern cl::opt<bool> ImprovedFSDiscriminator; |
66 | } |
67 | char MIRProfileLoaderPass::ID = 0; |
68 | |
69 | INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE, |
70 | "Load MIR Sample Profile" , |
71 | /* cfg = */ false, /* is_analysis = */ false) |
72 | INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass) |
73 | INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) |
74 | INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass) |
75 | INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass) |
76 | INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass) |
77 | INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile" , |
78 | /* cfg = */ false, /* is_analysis = */ false) |
79 | |
80 | char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID; |
81 | |
82 | FunctionPass * |
83 | llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile, |
84 | FSDiscriminatorPass P, |
85 | IntrusiveRefCntPtr<vfs::FileSystem> FS) { |
86 | return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS)); |
87 | } |
88 | |
89 | namespace llvm { |
90 | |
91 | // Internal option used to control BFI display only after MBP pass. |
92 | // Defined in CodeGen/MachineBlockFrequencyInfo.cpp: |
93 | // -view-block-layout-with-bfi={none | fraction | integer | count} |
94 | extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI; |
95 | |
96 | // Command line option to specify the name of the function for CFG dump |
97 | // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= |
98 | extern cl::opt<std::string> ViewBlockFreqFuncName; |
99 | |
100 | std::optional<PseudoProbe> (const MachineInstr &MI) { |
101 | if (MI.isPseudoProbe()) { |
102 | PseudoProbe Probe; |
103 | Probe.Id = MI.getOperand(i: 1).getImm(); |
104 | Probe.Type = MI.getOperand(i: 2).getImm(); |
105 | Probe.Attr = MI.getOperand(i: 3).getImm(); |
106 | Probe.Factor = 1; |
107 | DILocation *DebugLoc = MI.getDebugLoc(); |
108 | Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0; |
109 | return Probe; |
110 | } |
111 | |
112 | // Ignore callsite probes since they do not have FS discriminators. |
113 | return std::nullopt; |
114 | } |
115 | |
116 | namespace afdo_detail { |
117 | template <> struct IRTraits<MachineBasicBlock> { |
118 | using InstructionT = MachineInstr; |
119 | using BasicBlockT = MachineBasicBlock; |
120 | using FunctionT = MachineFunction; |
121 | using BlockFrequencyInfoT = MachineBlockFrequencyInfo; |
122 | using LoopT = MachineLoop; |
123 | using LoopInfoPtrT = MachineLoopInfo *; |
124 | using DominatorTreePtrT = MachineDominatorTree *; |
125 | using PostDominatorTreePtrT = MachinePostDominatorTree *; |
126 | using PostDominatorTreeT = MachinePostDominatorTree; |
127 | using = MachineOptimizationRemarkEmitter; |
128 | using = MachineOptimizationRemarkAnalysis; |
129 | using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; |
130 | using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>; |
131 | static Function &getFunction(MachineFunction &F) { return F.getFunction(); } |
132 | static const MachineBasicBlock *getEntryBB(const MachineFunction *F) { |
133 | return GraphTraits<const MachineFunction *>::getEntryNode(F); |
134 | } |
135 | static PredRangeT getPredecessors(MachineBasicBlock *BB) { |
136 | return BB->predecessors(); |
137 | } |
138 | static SuccRangeT getSuccessors(MachineBasicBlock *BB) { |
139 | return BB->successors(); |
140 | } |
141 | }; |
142 | } // namespace afdo_detail |
143 | |
144 | class MIRProfileLoader final |
145 | : public SampleProfileLoaderBaseImpl<MachineFunction> { |
146 | public: |
147 | void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT, |
148 | MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI, |
149 | MachineOptimizationRemarkEmitter *MORE) { |
150 | DT = MDT; |
151 | PDT = MPDT; |
152 | LI = MLI; |
153 | BFI = MBFI; |
154 | ORE = MORE; |
155 | } |
156 | void setFSPass(FSDiscriminatorPass Pass) { |
157 | P = Pass; |
158 | LowBit = getFSPassBitBegin(P); |
159 | HighBit = getFSPassBitEnd(P); |
160 | assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit" ); |
161 | } |
162 | |
163 | MIRProfileLoader(StringRef Name, StringRef RemapName, |
164 | IntrusiveRefCntPtr<vfs::FileSystem> FS) |
165 | : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName), |
166 | std::move(FS)) {} |
167 | |
168 | void setBranchProbs(MachineFunction &F); |
169 | bool runOnFunction(MachineFunction &F); |
170 | bool doInitialization(Module &M); |
171 | bool isValid() const { return ProfileIsValid; } |
172 | |
173 | protected: |
174 | friend class SampleCoverageTracker; |
175 | |
176 | /// Hold the information of the basic block frequency. |
177 | MachineBlockFrequencyInfo *BFI; |
178 | |
179 | /// PassNum is the sequence number this pass is called, start from 1. |
180 | FSDiscriminatorPass P; |
181 | |
182 | // LowBit in the FS discriminator used by this instance. Note the number is |
183 | // 0-based. Base discrimnator use bit 0 to bit 11. |
184 | unsigned LowBit; |
185 | // HighwBit in the FS discriminator used by this instance. Note the number |
186 | // is 0-based. |
187 | unsigned HighBit; |
188 | |
189 | bool ProfileIsValid = true; |
190 | ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override { |
191 | if (FunctionSamples::ProfileIsProbeBased) |
192 | return getProbeWeight(Inst: MI); |
193 | if (ImprovedFSDiscriminator && MI.isMetaInstruction()) |
194 | return std::error_code(); |
195 | return getInstWeightImpl(Inst: MI); |
196 | } |
197 | }; |
198 | |
199 | template <> |
200 | void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo( |
201 | MachineFunction &F) {} |
202 | |
203 | void MIRProfileLoader::setBranchProbs(MachineFunction &F) { |
204 | LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n" ); |
205 | for (auto &BI : F) { |
206 | MachineBasicBlock *BB = &BI; |
207 | if (BB->succ_size() < 2) |
208 | continue; |
209 | const MachineBasicBlock *EC = EquivalenceClass[BB]; |
210 | uint64_t BBWeight = BlockWeights[EC]; |
211 | uint64_t SumEdgeWeight = 0; |
212 | for (MachineBasicBlock *Succ : BB->successors()) { |
213 | Edge E = std::make_pair(x&: BB, y&: Succ); |
214 | SumEdgeWeight += EdgeWeights[E]; |
215 | } |
216 | |
217 | if (BBWeight != SumEdgeWeight) { |
218 | LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight=" |
219 | << BBWeight << " SumEdgeWeight= " << SumEdgeWeight |
220 | << "\n" ); |
221 | BBWeight = SumEdgeWeight; |
222 | } |
223 | if (BBWeight == 0) { |
224 | LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n" ); |
225 | continue; |
226 | } |
227 | |
228 | #ifndef NDEBUG |
229 | uint64_t BBWeightOrig = BBWeight; |
230 | #endif |
231 | uint32_t MaxWeight = std::numeric_limits<uint32_t>::max(); |
232 | uint32_t Factor = 1; |
233 | if (BBWeight > MaxWeight) { |
234 | Factor = BBWeight / MaxWeight + 1; |
235 | BBWeight /= Factor; |
236 | LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n" ); |
237 | } |
238 | |
239 | for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(), |
240 | SE = BB->succ_end(); |
241 | SI != SE; ++SI) { |
242 | MachineBasicBlock *Succ = *SI; |
243 | Edge E = std::make_pair(x&: BB, y&: Succ); |
244 | uint64_t EdgeWeight = EdgeWeights[E]; |
245 | EdgeWeight /= Factor; |
246 | |
247 | assert(BBWeight >= EdgeWeight && |
248 | "BBweight is larger than EdgeWeight -- should not happen.\n" ); |
249 | |
250 | BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(Src: BB, Dst: SI); |
251 | BranchProbability NewProb(EdgeWeight, BBWeight); |
252 | if (OldProb == NewProb) |
253 | continue; |
254 | BB->setSuccProbability(I: SI, Prob: NewProb); |
255 | #ifndef NDEBUG |
256 | if (!ShowFSBranchProb) |
257 | continue; |
258 | bool Show = false; |
259 | BranchProbability Diff; |
260 | if (OldProb > NewProb) |
261 | Diff = OldProb - NewProb; |
262 | else |
263 | Diff = NewProb - OldProb; |
264 | Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100)); |
265 | Show &= (BBWeightOrig >= FSProfileDebugBWThreshold); |
266 | |
267 | auto DIL = BB->findBranchDebugLoc(); |
268 | auto SuccDIL = Succ->findBranchDebugLoc(); |
269 | if (Show) { |
270 | dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> " |
271 | << Succ->getNumber() << "): " ; |
272 | if (DIL) |
273 | dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":" |
274 | << DIL->getColumn(); |
275 | if (SuccDIL) |
276 | dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine() |
277 | << ":" << SuccDIL->getColumn(); |
278 | dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb |
279 | << "\n" ; |
280 | } |
281 | #endif |
282 | } |
283 | } |
284 | } |
285 | |
286 | bool MIRProfileLoader::doInitialization(Module &M) { |
287 | auto &Ctx = M.getContext(); |
288 | |
289 | auto ReaderOrErr = sampleprof::SampleProfileReader::create( |
290 | Filename, C&: Ctx, FS&: *FS, P, RemapFilename: RemappingFilename); |
291 | if (std::error_code EC = ReaderOrErr.getError()) { |
292 | std::string Msg = "Could not open profile: " + EC.message(); |
293 | Ctx.diagnose(DI: DiagnosticInfoSampleProfile(Filename, Msg)); |
294 | return false; |
295 | } |
296 | |
297 | Reader = std::move(ReaderOrErr.get()); |
298 | Reader->setModule(&M); |
299 | ProfileIsValid = (Reader->read() == sampleprof_error::success); |
300 | |
301 | // Load pseudo probe descriptors for probe-based function samples. |
302 | if (Reader->profileIsProbeBased()) { |
303 | ProbeManager = std::make_unique<PseudoProbeManager>(args&: M); |
304 | if (!ProbeManager->moduleIsProbed(M)) { |
305 | return false; |
306 | } |
307 | } |
308 | |
309 | return true; |
310 | } |
311 | |
312 | bool MIRProfileLoader::runOnFunction(MachineFunction &MF) { |
313 | // Do not load non-FS profiles. A line or probe can get a zero-valued |
314 | // discriminator at certain pass which could result in accidentally loading |
315 | // the corresponding base counter in the non-FS profile, while a non-zero |
316 | // discriminator would end up getting zero samples. This could in turn undo |
317 | // the sample distribution effort done by previous BFI maintenance and the |
318 | // probe distribution factor work for pseudo probes. |
319 | if (!Reader->profileIsFS()) |
320 | return false; |
321 | |
322 | Function &Func = MF.getFunction(); |
323 | clearFunctionData(ResetDT: false); |
324 | Samples = Reader->getSamplesFor(F: Func); |
325 | if (!Samples || Samples->empty()) |
326 | return false; |
327 | |
328 | if (FunctionSamples::ProfileIsProbeBased) { |
329 | if (!ProbeManager->profileIsValid(F: MF.getFunction(), Samples: *Samples)) |
330 | return false; |
331 | } else { |
332 | if (getFunctionLoc(Func&: MF) == 0) |
333 | return false; |
334 | } |
335 | |
336 | DenseSet<GlobalValue::GUID> InlinedGUIDs; |
337 | bool Changed = computeAndPropagateWeights(F&: MF, InlinedGUIDs); |
338 | |
339 | // Set the new BPI, BFI. |
340 | setBranchProbs(MF); |
341 | |
342 | return Changed; |
343 | } |
344 | |
345 | } // namespace llvm |
346 | |
347 | MIRProfileLoaderPass::MIRProfileLoaderPass( |
348 | std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P, |
349 | IntrusiveRefCntPtr<vfs::FileSystem> FS) |
350 | : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) { |
351 | LowBit = getFSPassBitBegin(P); |
352 | HighBit = getFSPassBitEnd(P); |
353 | |
354 | auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem(); |
355 | MIRSampleLoader = std::make_unique<MIRProfileLoader>( |
356 | args&: FileName, args&: RemappingFileName, args: std::move(VFS)); |
357 | assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit" ); |
358 | } |
359 | |
360 | bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) { |
361 | if (!MIRSampleLoader->isValid()) |
362 | return false; |
363 | |
364 | LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: " |
365 | << MF.getFunction().getName() << "\n" ); |
366 | MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI(); |
367 | MIRSampleLoader->setInitVals( |
368 | MDT: &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(), |
369 | MPDT: &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree(), |
370 | MLI: &getAnalysis<MachineLoopInfoWrapperPass>().getLI(), MBFI, |
371 | MORE: &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE()); |
372 | |
373 | MF.RenumberBlocks(); |
374 | if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None && |
375 | (ViewBlockFreqFuncName.empty() || |
376 | MF.getFunction().getName() == ViewBlockFreqFuncName)) { |
377 | MBFI->view(Name: "MIR_Prof_loader_b." + MF.getName(), isSimple: false); |
378 | } |
379 | |
380 | bool Changed = MIRSampleLoader->runOnFunction(MF); |
381 | if (Changed) |
382 | MBFI->calculate(F: MF, MBPI: *MBFI->getMBPI(), |
383 | MLI: *&getAnalysis<MachineLoopInfoWrapperPass>().getLI()); |
384 | |
385 | if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None && |
386 | (ViewBlockFreqFuncName.empty() || |
387 | MF.getFunction().getName() == ViewBlockFreqFuncName)) { |
388 | MBFI->view(Name: "MIR_prof_loader_a." + MF.getName(), isSimple: false); |
389 | } |
390 | |
391 | return Changed; |
392 | } |
393 | |
394 | bool MIRProfileLoaderPass::doInitialization(Module &M) { |
395 | LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName() |
396 | << "\n" ); |
397 | |
398 | MIRSampleLoader->setFSPass(P); |
399 | return MIRSampleLoader->doInitialization(M); |
400 | } |
401 | |
/// Declares the analyses this pass needs. The order of the addRequired* calls
/// below is the order of the Required list and is kept as-is.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  // Reports everything as preserved. The pass only rewrites successor
  // probabilities, and runOnMachineFunction recomputes MBFI itself when the
  // loader reports a change.
  AU.setPreservesAll();
  // MBFI, the (post-)dominator trees, loop info and ORE are all handed to
  // the loader via setInitVals() in runOnMachineFunction.
  AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
  // Loop info is also passed to MBFI->calculate(); presumably required
  // transitively because MBFI keeps referring to it afterwards -- confirm.
  AU.addRequiredTransitive<MachineLoopInfoWrapperPass>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
411 | |