1//===- ProfileVerify.cpp - Verify profile info for testing ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Utils/ProfileVerify.h"
10#include "llvm/ADT/DynamicAPInt.h"
11#include "llvm/ADT/STLExtras.h"
12#include "llvm/ADT/SmallString.h"
13#include "llvm/Analysis/BranchProbabilityInfo.h"
14#include "llvm/IR/Analysis.h"
15#include "llvm/IR/Constants.h"
16#include "llvm/IR/Dominators.h"
17#include "llvm/IR/Function.h"
18#include "llvm/IR/GlobalValue.h"
19#include "llvm/IR/GlobalVariable.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/LLVMContext.h"
22#include "llvm/IR/MDBuilder.h"
23#include "llvm/IR/Module.h"
24#include "llvm/IR/PassManager.h"
25#include "llvm/IR/ProfDataUtils.h"
26#include "llvm/Support/BranchProbability.h"
27#include "llvm/Support/Casting.h"
28#include "llvm/Support/CommandLine.h"
29
30using namespace llvm;
31static cl::opt<int64_t>
32 DefaultFunctionEntryCount("profcheck-default-function-entry-count",
33 cl::init(Val: 1000));
34static cl::opt<bool>
35 AnnotateSelect("profcheck-annotate-select", cl::init(Val: true),
36 cl::desc("Also inject (if missing) and verify MD_prof for "
37 "`select` instructions"));
38static cl::opt<bool>
39 WeightsForTest("profcheck-weights-for-test", cl::init(Val: false),
40 cl::desc("Generate weights with small values for tests."));
41
42static cl::opt<uint32_t> SelectTrueWeight(
43 "profcheck-default-select-true-weight", cl::init(Val: 2U),
44 cl::desc("When annotating `select` instructions, this value will be used "
45 "for the first ('true') case."));
46static cl::opt<uint32_t> SelectFalseWeight(
47 "profcheck-default-select-false-weight", cl::init(Val: 3U),
48 cl::desc("When annotating `select` instructions, this value will be used "
49 "for the second ('false') case."));
50namespace {
51class ProfileInjector {
52 Function &F;
53 FunctionAnalysisManager &FAM;
54
55public:
56 static const Instruction *
57 getTerminatorBenefitingFromMDProf(const BasicBlock &BB) {
58 if (succ_size(BB: &BB) < 2)
59 return nullptr;
60 auto *Term = BB.getTerminator();
61 return (isa<BranchInst>(Val: Term) || isa<SwitchInst>(Val: Term) ||
62 isa<IndirectBrInst>(Val: Term) || isa<CallBrInst>(Val: Term))
63 ? Term
64 : nullptr;
65 }
66
67 static Instruction *getTerminatorBenefitingFromMDProf(BasicBlock &BB) {
68 return const_cast<Instruction *>(
69 getTerminatorBenefitingFromMDProf(BB: const_cast<const BasicBlock &>(BB)));
70 }
71
72 ProfileInjector(Function &F, FunctionAnalysisManager &FAM) : F(F), FAM(FAM) {}
73 bool inject();
74};
75
76bool isAsmOnly(const Function &F) {
77 if (!F.hasFnAttribute(Kind: Attribute::AttrKind::Naked))
78 return false;
79 for (const auto &BB : F)
80 for (const auto &I : drop_end(RangeOrContainer: BB.instructionsWithoutDebug())) {
81 const auto *CB = dyn_cast<CallBase>(Val: &I);
82 if (!CB || !CB->isInlineAsm())
83 return false;
84 }
85 return true;
86}
87
88void emitProfileError(StringRef Msg, Function &F) {
89 F.getContext().emitError(ErrorStr: "Profile verification failed for function '" +
90 F.getName() + "': " + Msg);
91}
92
93} // namespace
94
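// Injects profile metadata where it is missing: a function entry count, !prof
// on multi-successor terminators and, when -profcheck-annotate-select is
// enabled, !prof on `select` instructions whose condition is not a vector.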
97bool ProfileInjector::inject() {
98 // skip purely asm functions
99 if (isAsmOnly(F))
100 return false;
101 // Get whatever branch probability info can be derived from the given IR -
102 // whether it has or not metadata. The main intention for this pass is to
103 // ensure that other passes don't drop or "forget" to update MD_prof. We do
104 // this as a mode in which lit tests would run. We want to avoid changing the
105 // behavior of those tests. A pass may use BPI (or BFI, which is computed from
106 // BPI). If no metadata is present, BPI is guesstimated by
107 // BranchProbabilityAnalysis. The injector (this pass) only persists whatever
108 // information the analysis provides, in other words, the pass being tested
109 // will get the same BPI it does if the injector wasn't running.
110 auto &BPI = FAM.getResult<BranchProbabilityAnalysis>(IR&: F);
111
112 // Inject a function count if there's none. It's reasonable for a pass to
113 // want to clear the MD_prof of a function with zero entry count. If the
114 // original profile (iFDO or AFDO) is empty for a function, it's simpler to
115 // require assigning it the 0-entry count explicitly than to mark every branch
116 // as cold (we do want some explicit information in the spirit of what this
117 // verifier wants to achieve - make dropping / corrupting MD_prof
118 // unit-testable)
119 if (!F.getEntryCount(/*AllowSynthetic=*/true))
120 F.setEntryCount(Count: DefaultFunctionEntryCount);
121 // If there is an entry count that's 0, then don't bother injecting. We won't
122 // verify these either.
123 if (F.getEntryCount(/*AllowSynthetic=*/true)->getCount() == 0)
124 return false;
125 bool Changed = false;
126 // Cycle through the weights list. If we didn't, tests with more than (say)
127 // one conditional branch would have the same !prof metadata on all of them,
128 // and numerically that may make for a poor unit test.
129 uint32_t WeightsForTestOffset = 0;
130 for (auto &BB : F) {
131 if (AnnotateSelect) {
132 for (auto &I : BB) {
133 if (auto *SI = dyn_cast<SelectInst>(Val: &I)) {
134 if (SI->getCondition()->getType()->isVectorTy())
135 continue;
136 if (I.getMetadata(KindID: LLVMContext::MD_prof))
137 continue;
138 setBranchWeights(I, Weights: {SelectTrueWeight, SelectFalseWeight},
139 /*IsExpected=*/false);
140 }
141 }
142 }
143 auto *Term = getTerminatorBenefitingFromMDProf(BB);
144 if (!Term || Term->getMetadata(KindID: LLVMContext::MD_prof))
145 continue;
146 SmallVector<BranchProbability> Probs;
147
148 SmallVector<uint32_t> Weights;
149 Weights.reserve(N: Term->getNumSuccessors());
150 if (WeightsForTest) {
151 static const std::array Primes{3, 5, 7, 11, 13, 17, 19, 23, 29, 31,
152 37, 41, 43, 47, 53, 59, 61, 67, 71};
153 for (uint32_t I = 0, E = Term->getNumSuccessors(); I < E; ++I)
154 Weights.emplace_back(
155 Args: Primes[(WeightsForTestOffset + I) % Primes.size()]);
156 ++WeightsForTestOffset;
157 } else {
158 Probs.reserve(N: Term->getNumSuccessors());
159 for (auto I = 0U, E = Term->getNumSuccessors(); I < E; ++I)
160 Probs.emplace_back(Args: BPI.getEdgeProbability(Src: &BB, Dst: Term->getSuccessor(Idx: I)));
161
162 assert(llvm::find_if(Probs,
163 [](const BranchProbability &P) {
164 return P.isUnknown();
165 }) == Probs.end() &&
166 "All branch probabilities should be valid");
167 const auto *FirstZeroDenominator =
168 find_if(Range&: Probs, P: [](const BranchProbability &P) {
169 return P.getDenominator() == 0;
170 });
171 (void)FirstZeroDenominator;
172 assert(FirstZeroDenominator == Probs.end());
173 const auto *FirstNonZeroNumerator = find_if(
174 Range&: Probs, P: [](const BranchProbability &P) { return !P.isZero(); });
175 assert(FirstNonZeroNumerator != Probs.end());
176 DynamicAPInt LCM(Probs[0].getDenominator());
177 DynamicAPInt GCD(FirstNonZeroNumerator->getNumerator());
178 for (const auto &Prob : drop_begin(RangeOrContainer&: Probs)) {
179 if (!Prob.getNumerator())
180 continue;
181 LCM = llvm::lcm(A: LCM, B: DynamicAPInt(Prob.getDenominator()));
182 GCD = llvm::gcd(A: GCD, B: DynamicAPInt(Prob.getNumerator()));
183 }
184 for (const auto &Prob : Probs) {
185 DynamicAPInt W =
186 (Prob.getNumerator() * LCM / GCD) / Prob.getDenominator();
187 Weights.emplace_back(Args: static_cast<uint32_t>((int64_t)W));
188 }
189 }
190 setBranchWeights(I&: *Term, Weights, /*IsExpected=*/false);
191 Changed = true;
192 }
193 return Changed;
194}
195
196PreservedAnalyses ProfileInjectorPass::run(Function &F,
197 FunctionAnalysisManager &FAM) {
198 ProfileInjector PI(F, FAM);
199 if (!PI.inject())
200 return PreservedAnalyses::all();
201
202 return PreservedAnalyses::none();
203}
204
205PreservedAnalyses ProfileVerifierPass::run(Module &M,
206 ModuleAnalysisManager &MAM) {
207 auto PopulateIgnoreList = [&](StringRef GVName) {
208 if (const auto *CT = M.getGlobalVariable(Name: GVName))
209 if (const auto *CA =
210 dyn_cast_if_present<ConstantArray>(Val: CT->getInitializer()))
211 for (const auto &Elt : CA->operands())
212 if (const auto *CS = dyn_cast<ConstantStruct>(Val: Elt))
213 if (CS->getNumOperands() >= 2 && CS->getOperand(i_nocapture: 1))
214 if (const auto *F = dyn_cast<Function>(
215 Val: CS->getOperand(i_nocapture: 1)->stripPointerCasts()))
216 IgnoreList.insert(V: F);
217 };
218 PopulateIgnoreList("llvm.global_ctors");
219 PopulateIgnoreList("llvm.global_dtors");
220
221 // expose the function-level run as public through a wrapper, so we can use
222 // pass manager mechanisms dealing with declarations and with composing the
223 // returned PreservedAnalyses values.
224 struct Wrapper : PassInfoMixin<Wrapper> {
225 ProfileVerifierPass &PVP;
226 PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM) {
227 return PVP.run(F, FAM);
228 }
229 explicit Wrapper(ProfileVerifierPass &PVP) : PVP(PVP) {}
230 };
231
232 return createModuleToFunctionPassAdaptor(Pass: Wrapper(*this)).run(M, AM&: MAM);
233}
234
235PreservedAnalyses ProfileVerifierPass::run(Function &F,
236 FunctionAnalysisManager &FAM) {
237 // skip purely asm functions
238 if (isAsmOnly(F))
239 return PreservedAnalyses::all();
240 if (IgnoreList.contains(V: &F))
241 return PreservedAnalyses::all();
242
243 const auto EntryCount = F.getEntryCount(/*AllowSynthetic=*/true);
244 if (!EntryCount) {
245 auto *MD = F.getMetadata(KindID: LLVMContext::MD_prof);
246 if (!MD || !isExplicitlyUnknownProfileMetadata(MD: *MD)) {
247 emitProfileError(Msg: "function entry count missing (set to 0 if cold)", F);
248 return PreservedAnalyses::all();
249 }
250 } else if (EntryCount->getCount() == 0) {
251 return PreservedAnalyses::all();
252 }
253 for (const auto &BB : F) {
254 if (AnnotateSelect) {
255 for (const auto &I : BB)
256 if (auto *SI = dyn_cast<SelectInst>(Val: &I)) {
257 if (SI->getCondition()->getType()->isVectorTy())
258 continue;
259 if (I.getMetadata(KindID: LLVMContext::MD_prof))
260 continue;
261 emitProfileError(Msg: "select annotation missing", F);
262 }
263 }
264 if (const auto *Term =
265 ProfileInjector::getTerminatorBenefitingFromMDProf(BB))
266 if (!Term->getMetadata(KindID: LLVMContext::MD_prof))
267 emitProfileError(Msg: "branch annotation missing", F);
268 }
269 return PreservedAnalyses::all();
270}
271