1//===-------- MIRSampleProfile.cpp: MIRSampleFDO (For FSAFDO) -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file provides the implementation of the MIRSampleProfile loader, mainly
10// for flow sensitive SampleFDO.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/CodeGen/MIRSampleProfile.h"
15#include "llvm/ADT/DenseMap.h"
16#include "llvm/ADT/DenseSet.h"
17#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
18#include "llvm/CodeGen/MIRFSDiscriminatorOptions.h"
19#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
20#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
21#include "llvm/CodeGen/MachineDominators.h"
22#include "llvm/CodeGen/MachineInstr.h"
23#include "llvm/CodeGen/MachineLoopInfo.h"
24#include "llvm/CodeGen/MachineOptimizationRemarkEmitter.h"
25#include "llvm/CodeGen/MachinePostDominators.h"
26#include "llvm/CodeGen/Passes.h"
27#include "llvm/IR/Function.h"
28#include "llvm/IR/PseudoProbe.h"
29#include "llvm/InitializePasses.h"
30#include "llvm/Support/CommandLine.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/VirtualFileSystem.h"
33#include "llvm/Support/raw_ostream.h"
34#include "llvm/Transforms/Utils/SampleProfileLoaderBaseImpl.h"
35#include "llvm/Transforms/Utils/SampleProfileLoaderBaseUtil.h"
36#include <optional>
37
38using namespace llvm;
39using namespace sampleprof;
40using namespace llvm::sampleprofutil;
41using ProfileCount = Function::ProfileCount;
42
43#define DEBUG_TYPE "fs-profile-loader"
44
45static cl::opt<bool> ShowFSBranchProb(
46 "show-fs-branchprob", cl::Hidden, cl::init(Val: false),
47 cl::desc("Print setting flow sensitive branch probabilities"));
48static cl::opt<unsigned> FSProfileDebugProbDiffThreshold(
49 "fs-profile-debug-prob-diff-threshold", cl::init(Val: 10),
50 cl::desc(
51 "Only show debug message if the branch probability is greater than "
52 "this value (in percentage)."));
53
54static cl::opt<unsigned> FSProfileDebugBWThreshold(
55 "fs-profile-debug-bw-threshold", cl::init(Val: 10000),
56 cl::desc("Only show debug message if the source branch weight is greater "
57 " than this value."));
58
59static cl::opt<bool> ViewBFIBefore("fs-viewbfi-before", cl::Hidden,
60 cl::init(Val: false),
61 cl::desc("View BFI before MIR loader"));
62static cl::opt<bool> ViewBFIAfter("fs-viewbfi-after", cl::Hidden,
63 cl::init(Val: false),
64 cl::desc("View BFI after MIR loader"));
65
// Legacy pass-manager identity of the MIR profile loader pass.
char MIRProfileLoaderPass::ID = 0;

// Register the pass under DEBUG_TYPE ("fs-profile-loader"). The listed
// dependencies are the same machine analyses requested in getAnalysisUsage.
INITIALIZE_PASS_BEGIN(MIRProfileLoaderPass, DEBUG_TYPE,
                      "Load MIR Sample Profile",
                      /* cfg = */ false, /* is_analysis = */ false)
INITIALIZE_PASS_DEPENDENCY(MachineBlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachinePostDominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineOptimizationRemarkEmitterPass)
INITIALIZE_PASS_END(MIRProfileLoaderPass, DEBUG_TYPE, "Load MIR Sample Profile",
                    /* cfg = */ false, /* is_analysis = */ false)

// Exported reference to the pass ID so other code can name this pass by ID.
char &llvm::MIRProfileLoaderPassID = MIRProfileLoaderPass::ID;
80
81FunctionPass *
82llvm::createMIRProfileLoaderPass(std::string File, std::string RemappingFile,
83 FSDiscriminatorPass P,
84 IntrusiveRefCntPtr<vfs::FileSystem> FS) {
85 return new MIRProfileLoaderPass(File, RemappingFile, P, std::move(FS));
86}
87
88namespace llvm {
89
// Internal option used to control BFI display only after MBP pass.
// Defined in CodeGen/MachineBlockFrequencyInfo.cpp:
// -view-block-layout-with-bfi={none | fraction | integer | count}
// Here it gates the -fs-viewbfi-before / -fs-viewbfi-after dumps: nothing is
// shown while it is GVDT_None.
extern cl::opt<GVDAGType> ViewBlockLayoutWithBFI;

// Command line option to specify the name of the function for CFG dump
// Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name=
// An empty value means the BFI dumps apply to every function.
extern cl::opt<std::string> ViewBlockFreqFuncName;
98
99std::optional<PseudoProbe> extractProbe(const MachineInstr &MI) {
100 if (MI.isPseudoProbe()) {
101 PseudoProbe Probe;
102 Probe.Id = MI.getOperand(i: 1).getImm();
103 Probe.Type = MI.getOperand(i: 2).getImm();
104 Probe.Attr = MI.getOperand(i: 3).getImm();
105 Probe.Factor = 1;
106 DILocation *DebugLoc = MI.getDebugLoc();
107 Probe.Discriminator = DebugLoc ? DebugLoc->getDiscriminator() : 0;
108 return Probe;
109 }
110
111 // Ignore callsite probes since they do not have FS discriminators.
112 return std::nullopt;
113}
114
115namespace afdo_detail {
116template <> struct IRTraits<MachineBasicBlock> {
117 using InstructionT = MachineInstr;
118 using BasicBlockT = MachineBasicBlock;
119 using FunctionT = MachineFunction;
120 using BlockFrequencyInfoT = MachineBlockFrequencyInfo;
121 using LoopT = MachineLoop;
122 using LoopInfoPtrT = MachineLoopInfo *;
123 using DominatorTreePtrT = MachineDominatorTree *;
124 using PostDominatorTreePtrT = MachinePostDominatorTree *;
125 using PostDominatorTreeT = MachinePostDominatorTree;
126 using OptRemarkEmitterT = MachineOptimizationRemarkEmitter;
127 using OptRemarkAnalysisT = MachineOptimizationRemarkAnalysis;
128 using PredRangeT =
129 iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>;
130 using SuccRangeT =
131 iterator_range<SmallVectorImpl<MachineBasicBlock *>::iterator>;
132 static Function &getFunction(MachineFunction &F) { return F.getFunction(); }
133 static const MachineBasicBlock *getEntryBB(const MachineFunction *F) {
134 return GraphTraits<const MachineFunction *>::getEntryNode(F);
135 }
136 static PredRangeT getPredecessors(MachineBasicBlock *BB) {
137 return BB->predecessors();
138 }
139 static SuccRangeT getSuccessors(MachineBasicBlock *BB) {
140 return BB->successors();
141 }
142};
143} // namespace afdo_detail
144
145class MIRProfileLoader final
146 : public SampleProfileLoaderBaseImpl<MachineFunction> {
147public:
148 void setInitVals(MachineDominatorTree *MDT, MachinePostDominatorTree *MPDT,
149 MachineLoopInfo *MLI, MachineBlockFrequencyInfo *MBFI,
150 MachineOptimizationRemarkEmitter *MORE) {
151 DT = MDT;
152 PDT = MPDT;
153 LI = MLI;
154 BFI = MBFI;
155 ORE = MORE;
156 }
157 void setFSPass(FSDiscriminatorPass Pass) {
158 P = Pass;
159 LowBit = getFSPassBitBegin(P);
160 HighBit = getFSPassBitEnd(P);
161 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
162 }
163
164 MIRProfileLoader(StringRef Name, StringRef RemapName,
165 IntrusiveRefCntPtr<vfs::FileSystem> FS)
166 : SampleProfileLoaderBaseImpl(std::string(Name), std::string(RemapName),
167 std::move(FS)) {}
168
169 void setBranchProbs(MachineFunction &F);
170 bool runOnFunction(MachineFunction &F);
171 bool doInitialization(Module &M);
172 bool isValid() const { return ProfileIsValid; }
173
174protected:
175 friend class SampleCoverageTracker;
176
177 /// Hold the information of the basic block frequency.
178 MachineBlockFrequencyInfo *BFI;
179
180 /// PassNum is the sequence number this pass is called, start from 1.
181 FSDiscriminatorPass P;
182
183 // LowBit in the FS discriminator used by this instance. Note the number is
184 // 0-based. Base discrimnator use bit 0 to bit 11.
185 unsigned LowBit;
186 // HighwBit in the FS discriminator used by this instance. Note the number
187 // is 0-based.
188 unsigned HighBit;
189
190 bool ProfileIsValid = true;
191 ErrorOr<uint64_t> getInstWeight(const MachineInstr &MI) override {
192 if (FunctionSamples::ProfileIsProbeBased)
193 return getProbeWeight(Inst: MI);
194 if (ImprovedFSDiscriminator && MI.isMetaInstruction())
195 return std::error_code();
196 return getInstWeightImpl(Inst: MI);
197 }
198};
199
// Intentionally a no-op for Machine IR. MIRProfileLoader::setInitVals assigns
// DT, PDT and LI from the machine analyses the pass already holds, so the base
// loader does not compute them here.
template <>
void SampleProfileLoaderBaseImpl<MachineFunction>::computeDominanceAndLoopInfo(
    MachineFunction &F) {}
203
204void MIRProfileLoader::setBranchProbs(MachineFunction &F) {
205 LLVM_DEBUG(dbgs() << "\nPropagation complete. Setting branch probs\n");
206 for (auto &BI : F) {
207 MachineBasicBlock *BB = &BI;
208 if (BB->succ_size() < 2)
209 continue;
210 const MachineBasicBlock *EC = EquivalenceClass[BB];
211 uint64_t BBWeight = BlockWeights[EC];
212 uint64_t SumEdgeWeight = 0;
213 for (MachineBasicBlock *Succ : BB->successors()) {
214 Edge E = std::make_pair(x&: BB, y&: Succ);
215 SumEdgeWeight += EdgeWeights[E];
216 }
217
218 if (BBWeight != SumEdgeWeight) {
219 LLVM_DEBUG(dbgs() << "BBweight is not equal to SumEdgeWeight: BBWWeight="
220 << BBWeight << " SumEdgeWeight= " << SumEdgeWeight
221 << "\n");
222 BBWeight = SumEdgeWeight;
223 }
224 if (BBWeight == 0) {
225 LLVM_DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n");
226 continue;
227 }
228
229#ifndef NDEBUG
230 uint64_t BBWeightOrig = BBWeight;
231#endif
232 uint32_t MaxWeight = std::numeric_limits<uint32_t>::max();
233 uint32_t Factor = 1;
234 if (BBWeight > MaxWeight) {
235 Factor = BBWeight / MaxWeight + 1;
236 BBWeight /= Factor;
237 LLVM_DEBUG(dbgs() << "Scaling weights by " << Factor << "\n");
238 }
239
240 for (MachineBasicBlock::succ_iterator SI = BB->succ_begin(),
241 SE = BB->succ_end();
242 SI != SE; ++SI) {
243 MachineBasicBlock *Succ = *SI;
244 Edge E = std::make_pair(x&: BB, y&: Succ);
245 uint64_t EdgeWeight = EdgeWeights[E];
246 EdgeWeight /= Factor;
247
248 assert(BBWeight >= EdgeWeight &&
249 "BBweight is larger than EdgeWeight -- should not happen.\n");
250
251 BranchProbability OldProb = BFI->getMBPI()->getEdgeProbability(Src: BB, Dst: SI);
252 BranchProbability NewProb(EdgeWeight, BBWeight);
253 if (OldProb == NewProb)
254 continue;
255 BB->setSuccProbability(I: SI, Prob: NewProb);
256#ifndef NDEBUG
257 if (!ShowFSBranchProb)
258 continue;
259 bool Show = false;
260 BranchProbability Diff;
261 if (OldProb > NewProb)
262 Diff = OldProb - NewProb;
263 else
264 Diff = NewProb - OldProb;
265 Show = (Diff >= BranchProbability(FSProfileDebugProbDiffThreshold, 100));
266 Show &= (BBWeightOrig >= FSProfileDebugBWThreshold);
267
268 auto DIL = BB->findBranchDebugLoc();
269 auto SuccDIL = Succ->findBranchDebugLoc();
270 if (Show) {
271 dbgs() << "Set branch fs prob: MBB (" << BB->getNumber() << " -> "
272 << Succ->getNumber() << "): ";
273 if (DIL)
274 dbgs() << DIL->getFilename() << ":" << DIL->getLine() << ":"
275 << DIL->getColumn();
276 if (SuccDIL)
277 dbgs() << "-->" << SuccDIL->getFilename() << ":" << SuccDIL->getLine()
278 << ":" << SuccDIL->getColumn();
279 dbgs() << " W=" << BBWeightOrig << " " << OldProb << " --> " << NewProb
280 << "\n";
281 }
282#endif
283 }
284 }
285}
286
287bool MIRProfileLoader::doInitialization(Module &M) {
288 auto &Ctx = M.getContext();
289
290 auto ReaderOrErr = sampleprof::SampleProfileReader::create(
291 Filename, C&: Ctx, FS&: *FS, P, RemapFilename: RemappingFilename);
292 if (std::error_code EC = ReaderOrErr.getError()) {
293 std::string Msg = "Could not open profile: " + EC.message();
294 Ctx.diagnose(DI: DiagnosticInfoSampleProfile(Filename, Msg));
295 return false;
296 }
297
298 Reader = std::move(ReaderOrErr.get());
299 Reader->setModule(&M);
300 ProfileIsValid = (Reader->read() == sampleprof_error::success);
301
302 // Load pseudo probe descriptors for probe-based function samples.
303 if (Reader->profileIsProbeBased()) {
304 ProbeManager = std::make_unique<PseudoProbeManager>(args&: M);
305 if (!ProbeManager->moduleIsProbed(M)) {
306 return false;
307 }
308 }
309
310 return true;
311}
312
313bool MIRProfileLoader::runOnFunction(MachineFunction &MF) {
314 // Do not load non-FS profiles. A line or probe can get a zero-valued
315 // discriminator at certain pass which could result in accidentally loading
316 // the corresponding base counter in the non-FS profile, while a non-zero
317 // discriminator would end up getting zero samples. This could in turn undo
318 // the sample distribution effort done by previous BFI maintenance and the
319 // probe distribution factor work for pseudo probes.
320 if (!Reader->profileIsFS())
321 return false;
322
323 Function &Func = MF.getFunction();
324 clearFunctionData(ResetDT: false);
325 Samples = Reader->getSamplesFor(F: Func);
326 if (!Samples || Samples->empty())
327 return false;
328
329 if (FunctionSamples::ProfileIsProbeBased) {
330 if (!ProbeManager->profileIsValid(F: MF.getFunction(), Samples: *Samples))
331 return false;
332 } else {
333 if (getFunctionLoc(F&: MF) == 0)
334 return false;
335 }
336
337 DenseSet<GlobalValue::GUID> InlinedGUIDs;
338 bool Changed = computeAndPropagateWeights(F&: MF, InlinedGUIDs);
339
340 // Set the new BPI, BFI.
341 setBranchProbs(MF);
342
343 return Changed;
344}
345
346} // namespace llvm
347
348MIRProfileLoaderPass::MIRProfileLoaderPass(
349 std::string FileName, std::string RemappingFileName, FSDiscriminatorPass P,
350 IntrusiveRefCntPtr<vfs::FileSystem> FS)
351 : MachineFunctionPass(ID), ProfileFileName(FileName), P(P) {
352 LowBit = getFSPassBitBegin(P);
353 HighBit = getFSPassBitEnd(P);
354
355 auto VFS = FS ? std::move(FS) : vfs::getRealFileSystem();
356 MIRSampleLoader = std::make_unique<MIRProfileLoader>(
357 args&: FileName, args&: RemappingFileName, args: std::move(VFS));
358 assert(LowBit < HighBit && "HighBit needs to be greater than Lowbit");
359}
360
361bool MIRProfileLoaderPass::runOnMachineFunction(MachineFunction &MF) {
362 if (!MIRSampleLoader->isValid())
363 return false;
364
365 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Func: "
366 << MF.getFunction().getName() << "\n");
367 MBFI = &getAnalysis<MachineBlockFrequencyInfoWrapperPass>().getMBFI();
368 auto *MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
369 auto *MPDT =
370 &getAnalysis<MachinePostDominatorTreeWrapperPass>().getPostDomTree();
371
372 MF.RenumberBlocks();
373 MDT->updateBlockNumbers();
374 MPDT->updateBlockNumbers();
375
376 MIRSampleLoader->setInitVals(
377 MDT, MPDT, MLI: &getAnalysis<MachineLoopInfoWrapperPass>().getLI(), MBFI,
378 MORE: &getAnalysis<MachineOptimizationRemarkEmitterPass>().getORE());
379
380 if (ViewBFIBefore && ViewBlockLayoutWithBFI != GVDT_None &&
381 (ViewBlockFreqFuncName.empty() ||
382 MF.getFunction().getName() == ViewBlockFreqFuncName)) {
383 MBFI->view(Name: "MIR_Prof_loader_b." + MF.getName(), isSimple: false);
384 }
385
386 bool Changed = MIRSampleLoader->runOnFunction(MF);
387 if (Changed)
388 MBFI->calculate(F: MF, MBPI: *MBFI->getMBPI(),
389 MLI: *&getAnalysis<MachineLoopInfoWrapperPass>().getLI());
390
391 if (ViewBFIAfter && ViewBlockLayoutWithBFI != GVDT_None &&
392 (ViewBlockFreqFuncName.empty() ||
393 MF.getFunction().getName() == ViewBlockFreqFuncName)) {
394 MBFI->view(Name: "MIR_prof_loader_a." + MF.getName(), isSimple: false);
395 }
396
397 return Changed;
398}
399
400bool MIRProfileLoaderPass::doInitialization(Module &M) {
401 LLVM_DEBUG(dbgs() << "MIRProfileLoader pass working on Module " << M.getName()
402 << "\n");
403
404 MIRSampleLoader->setFSPass(P);
405 return MIRSampleLoader->doInitialization(M);
406}
407
// Declare the machine analyses that runOnMachineFunction fetches through
// getAnalysis<>. The pass claims to preserve everything; when the loader
// changes branch probabilities, runOnMachineFunction recalculates MBFI itself.
// MachineLoopInfo is used both by the loader (setInitVals) and for that
// MBFI recalculation, and is requested via addRequiredTransitive.
void MIRProfileLoaderPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<MachineBlockFrequencyInfoWrapperPass>();
  AU.addRequired<MachineDominatorTreeWrapperPass>();
  AU.addRequired<MachinePostDominatorTreeWrapperPass>();
  AU.addRequiredTransitive<MachineLoopInfoWrapperPass>();
  AU.addRequired<MachineOptimizationRemarkEmitterPass>();
  MachineFunctionPass::getAnalysisUsage(AU);
}
417