1//===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file provides the implementation of the MIRSampleProfile loader, mainly
10// for flow sensitive SampleFDO.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/CodeGen/MIRSampleProfile.h"
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/DenseSet.h"
17#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
19#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
20#include "llvm/CodeGen/MachineDominators.h"
21#include "llvm/CodeGen/MachineInstr.h"
22#include "llvm/CodeGen/MachineLoopInfo.h"
23#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
24#include "llvm/CodeGen/MachinePostDominators.h"
25#include "llvm/CodeGen/Passes.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/PseudoProbe.h"
28#include "llvm/InitializePasses.h"
29#include "llvm/Support/CommandLine.h"
30#include "llvm/Support/Debug.h"
31#include "llvm/Support/VirtualFileSystem.h"
32#include "llvm/Support/raw_ostream.h"
33#include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
34#include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
35#include <optional>
36
37using namespace llvm;
38using namespace sampleprof;
39using namespace llvm::sampleprofutil;
40using ProfileCount = Function::ProfileCount;
41
42#define DEBUG_TYPE "fs-profile-loader"
43
44static cl::opt<bool> ShowFSBranchProb(
45 "show-fs-branchprob", cl::Hidden, cl::init(Val: false),
46 cl::desc("Print setting flow sensitive branch probabilities"));
47static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
48 "fs-profile-debug-prob-diff-threshold", cl::init(Val: 10),
49 cl::desc("Only show debug message if the branch probility is greater than "
50 "this value (in percentage)."));
51
52static cl::opt<unsigned> FSProfileDebugBWThreshold(
53 "fs-profile-debug-bw-threshold", cl::init(Val: 10000),
54 cl::desc("Only show debug message if the source branch weight is greater "
55 " than this value."));
56
57static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
58 cl::init(Val: false),
59 cl::desc("View BFI before MIR loader"));
60static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
61 cl::init(Val: false),
62 cl::desc("View BFI after MIR loader"));
63
64namespace llvm {
65extern cl::opt<bool> ImprovedFSDiscriminator;
66}
// Legacy pass-manager ID for MIRProfileLoaderPass.
char MIRProfileLoaderPass::ID = 0;

// Register the pass under DEBUG_TYPE ("fs-profile-loader") together with the
// analyses it depends on; this list mirrors the requirements declared in
// MIRProfileLoaderPass::getAnalysisUsage().
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Exported reference to the pass ID so other code can name this pass without
// needing the class definition.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
81
82FunctionPass *
83llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
84 FSDiscriminatorPass P,
85 IntrusiveRefCntPtr<vfs::FileSystem> FS) {
86 return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
87}
88
89namespace llvm {
90
91// Internal option used to control BFI display only after MBP pass.
92// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
93// -view-block-layout-with-bfi={none | fraction | integer | count}
94extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;
95
96// Command line option to specify the name of the function for CFG dump
97// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
98extern cl::opt<std::string> ViewBlockFreqFuncName;
99
100std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
101 if (MI.isPseudoProbe()) {
102 PseudoProbe Probe;
103 Probe.Id = MI.getOperand(i: 1).getImm();
104 Probe.Type = MI.getOperand(i: 2).getImm();
105 Probe.Attr = MI.getOperand(i: 3).getImm();
106 Probe.Factor = 1;
107 DILocation *DebugLoc = MI.getDebugLoc();
108 Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
109 return Probe;
110 }
111
112 // Ignore callsite probes since they do not have FS discriminators.
113 return std::nullopt;
114}
115
116namespace afdo_detail {
117template <> struct IRTraits<MachineBasicBlock> {
118 using InstructionT = MachineInstr;
119 using BasicBlockT = MachineBasicBlock;
120 using FunctionT = MachineFunction;
121 using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
122 using LoopT = MachineLoop;
123 using LoopInfoPtrT = MachineLoopInfo *;
124 using DominatorTreePtrT = MachineDominatorTree *;
125 using PostDominatorTreePtrT = MachinePostDominatorTree *;
126 using PostDominatorTreeT = MachinePostDominatorTree;
127 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
128 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
129 using PredRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
130 using SuccRangeT = iterator_range<std::vector<MachineBasicBlock *>::iterator>;
131 static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
132 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
133 return GraphTraits<const MachineFunction *>::getEntryNode(F);
134 }
135 static PredRangeT getPredecessors(MachineBasicBlock *BB) {
136 return BB->predecessors();
137 }
138 static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
139 return BB->successors();
140 }
141};
142} // namespace afdo_detail
143
144class MIRProfileLoader final
145 : public SampleProfileLoaderBaseImpl<MachineFunction> {
146public:
147 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
148 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
149 MachineOptimizationRemarkEmitter *MORE) {
150 DT = MDT;
151 PDT = MPDT;
152 LI = MLI;
153 BFI = MBFI;
154 ORE = MORE;
155 }
156 void setFSPass(FSDiscriminatorPass Pass) {
157 P = Pass;
158 LowBit = getFSPassBitBegin(P);
159 HighBit = getFSPassBitEnd(P);
160 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
161 }
162
163 MIRProfileLoader(StringRef Name, StringRef RemapName,
164 IntrusiveRefCntPtr<vfs::FileSystem> FS)
165 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
166 std::move(FS)) {}
167
168 void setBranchProbs(MachineFunction &F);
169 bool runOnFunction(MachineFunction &F);
170 bool doInitialization(Module &M);
171 bool isValid() const { return ProfileIsValid; }
172
173protected:
174 friend class SampleCoverageTracker;
175
176 /// Hold the information of the basic block frequency.
177 MachineBlockFrequencyInfo *BFI;
178
179 /// PassNum is the sequence number this pass is called, start from 1.
180 FSDiscriminatorPass P;
181
182 // LowBit in the FS discriminator used by this instance. Note the number is
183 // 0-based. Base discrimnator use bit 0 to bit 11.
184 unsigned LowBit;
185 // HighwBit in the FS discriminator used by this instance. Note the number
186 // is 0-based.
187 unsigned HighBit;
188
189 bool ProfileIsValid = true;
190 ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override {
191 if (FunctionSamples::ProfileIsProbeBased)
192 return getProbeWeight(Inst: MI);
193 if (ImprovedFSDiscriminator && MI.isMetaInstruction())
194 return std::error_code();
195 return getInstWeightImpl(Inst: MI);
196 }
197};
198
199template <>
200void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo(
201 MachineFunction &F) {}
202
203void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
204 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
205 for (auto &BI : F) {
206 MachineBasicBlock *BB = &BI;
207 if (BB->succ_size() < 2)
208 continue;
209 const MachineBasicBlock *EC = EquivalenceClass[BB];
210 uint64_t BBWeight = BlockWeights[EC];
211 uint64_t SumEdgeWeight = 0;
212 for (MachineBasicBlock *Succ : BB->successors()) {
213 Edge E = std::make_pair(x&: BB, y&: Succ);
214 SumEdgeWeight += EdgeWeights[E];
215 }
216
217 if (BBWeight != SumEdgeWeight) {
218 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
219 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
220 << "\n");
221 BBWeight = SumEdgeWeight;
222 }
223 if (BBWeight == 0) {
224 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
225 continue;
226 }
227
228#ifndef NDEBUG
229 uint64_t BBWeightOrig = BBWeight;
230#endif
231 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
232 uint32_t Factor = 1;
233 if (BBWeight > MaxWeight) {
234 Factor = BBWeight / MaxWeight + 1;
235 BBWeight /= Factor;
236 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
237 }
238
239 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
240 SE = BB->succ_end();
241 SI != SE; ++SI) {
242 MachineBasicBlock *Succ = *SI;
243 Edge E = std::make_pair(x&: BB, y&: Succ);
244 uint64_t EdgeWeight = EdgeWeights[E];
245 EdgeWeight /= Factor;
246
247 assert(BBWeight >= EdgeWeight &&
248 "BBweight is larger than EdgeWeight -- should not happen.\n");
249
250 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(Src: BB, Dst: SI);
251 BranchProbability NewProb(EdgeWeight, BBWeight);
252 if (OldProb == NewProb)
253 continue;
254 BB->setSuccProbability(I: SI, Prob: NewProb);
255#ifndef NDEBUG
256 if (!ShowFSBranchProb)
257 continue;
258 bool Show = false;
259 BranchProbability Diff;
260 if (OldProb > NewProb)
261 Diff = OldProb - NewProb;
262 else
263 Diff = NewProb - OldProb;
264 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
265 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
266
267 auto DIL = BB->findBranchDebugLoc();
268 auto SuccDIL = Succ->findBranchDebugLoc();
269 if (Show) {
270 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
271 << Succ->getNumber() << "): ";
272 if (DIL)
273 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
274 << DIL->getColumn();
275 if (SuccDIL)
276 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
277 << ":" << SuccDIL->getColumn();
278 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb
279 << "\n";
280 }
281#endif
282 }
283 }
284}
285
286bool MIRProfileLoader::doInitialization(Module &M) {
287 auto &Ctx = M.getContext();
288
289 auto ReaderOrErr = sampleprof::SampleProfileReader::create(
290 Filename, C&: Ctx, FS&: *FS, P, RemapFilename: RemappingFilename);
291 if (std::error_code EC = ReaderOrErr.getError()) {
292 std::string Msg = "Could not open profile: " + EC.message();
293 Ctx.diagnose(DI: DiagnosticInfoSampleProfile(Filename, Msg));
294 return false;
295 }
296
297 Reader = std::move(ReaderOrErr.get());
298 Reader->setModule(&M);
299 ProfileIsValid = (Reader->read() == sampleprof_error::success);
300
301 // Load pseudo probe descriptors for probe-based function samples.
302 if (Reader->profileIsProbeBased()) {
303 ProbeManager = std::make_unique<PseudoProbeManager>(args&: M);
304 if (!ProbeManager->moduleIsProbed(M)) {
305 return false;
306 }
307 }
308
309 return true;
310}
311
312bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
313 // Do not load non-FS profiles. A line or probe can get a zero-valued
314 // discriminator at certain pass which could result in accidentally loading
315 // the corresponding base counter in the non-FS profile, while a non-zero
316 // discriminator would end up getting zero samples. This could in turn undo
317 // the sample distribution effort done by previous BFI maintenance and the
318 // probe distribution factor work for pseudo probes.
319 if (!Reader->profileIsFS())
320 return false;
321
322 Function &Func = MF.getFunction();
323 clearFunctionData(ResetDT: false);
324 Samples = Reader->getSamplesFor(F: Func);
325 if (!Samples || Samples->empty())
326 return false;
327
328 if (FunctionSamples::ProfileIsProbeBased) {
329 if (!ProbeManager->profileIsValid(F: MF.getFunction(), Samples: *Samples))
330 return false;
331 } else {
332 if (getFunctionLoc(Func&: MF) == 0)
333 return false;
334 }
335
336 DenseSet<GlobalValue::GUID> InlinedGUIDs;
337 bool Changed = computeAndPropagateWeights(F&: MF, InlinedGUIDs);
338
339 // Set the new BPI, BFI.
340 setBranchProbs(MF);
341
342 return Changed;
343}
344
345} // namespace llvm
346
347MIRProfileLoaderPass::MIRProfileLoaderPass(
348 std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
349 IntrusiveRefCntPtr<vfs::FileSystem> FS)
350 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
351 LowBit = getFSPassBitBegin(P);
352 HighBit = getFSPassBitEnd(P);
353
354 auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
355 MIRSampleLoader = std::make_unique<MIRProfileLoader>(
356 args&: FileName, args&: RemappingFileName, args: std::move(VFS));
357 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
358}
359
360bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
361 if (!MIRSampleLoader->isValid())
362 return false;
363
364 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
365 << MF.getFunction().getName() << "\n");
366 MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
367 MIRSampleLoader->setInitVals(
368 MDT: &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(),
369 MPDT: &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree(),
370 MLI: &getAnalysis<MachineLoopInfoWrapperPass>().getLI(), MBFI,
371 MORE: &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
372
373 MF.RenumberBlocks();
374 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
375 (ViewBlockFreqFuncName.empty() ||
376 MF.getFunction().getName() == ViewBlockFreqFuncName)) {
377 MBFI->view(Name: "MIR_Prof_loader_b." + MF.getName(), isSimple: false);
378 }
379
380 bool Changed = MIRSampleLoader->runOnFunction(MF);
381 if (Changed)
382 MBFI->calculate(F: MF, MBPI: *MBFI->getMBPI(),
383 MLI: *&getAnalysis<MachineLoopInfoWrapperPass>().getLI());
384
385 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
386 (ViewBlockFreqFuncName.empty() ||
387 MF.getFunction().getName() == ViewBlockFreqFuncName)) {
388 MBFI->view(Name: "MIR_prof_loader_a." + MF.getName(), isSimple: false);
389 }
390
391 return Changed;
392}
393
394bool MIRProfileLoaderPass::doInitialization(Module &M) {
395 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
396 << "\n");
397
398 MIRSampleLoader->setFSPass(P);
399 return MIRSampleLoader->doInitialization(M);
400}
401
/// Declare the analyses this pass requires. Everything is reported as
/// preserved; runOnMachineFunction recomputes MBFI itself after changing
/// branch probabilities.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
  // NOTE(review): presumably transitive because MBFI is recalculated against
  // this MachineLoopInfo at the end of runOnMachineFunction; confirm.
  AU.addRequiredTransitive<MachineLoopInfoWrapperPass>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
411