1//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for working with Profiling Metadata.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/IR/ProfDataUtils.h"
14
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/STLFunctionalExtras.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/LLVMContext.h"
22#include "llvm/IR/MDBuilder.h"
23#include "llvm/IR/Metadata.h"
24#include "llvm/Support/CommandLine.h"
25
26using namespace llvm;
27
28namespace llvm {
29extern cl::opt<bool> ProfcheckDisableMetadataFixes;
30}
31
32// MD_prof nodes have the following layout
33//
34// In general:
35// { String name, Array of i32 }
36//
37// In terms of Types:
38// { MDString, [i32, i32, ...]}
39//
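// Concretely for Branch Weights
// { "branch_weights", [i32 1, i32 10000]}
//
// A branch weight node may additionally carry an origin string (currently
// only "expected") as operand 1; the weights then start at operand 2.
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.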
45
// Lower bound on the operand count of an MD_prof node holding branch weights.
static constexpr unsigned MinBWOps = 3;

// Lower bound on the operand count of an MD_prof node holding a value profile.
static constexpr unsigned MinVPOps = 5;
51
52// We may want to add support for other MD_prof types, so provide an abstraction
53// for checking the metadata type.
54static bool isTargetMD(const MDNode *ProfData, const char *Name,
55 unsigned MinOps) {
56 // TODO: This routine may be simplified if MD_prof used an enum instead of a
57 // string to differentiate the types of MD_prof nodes.
58 if (!ProfData || !Name || MinOps < 2)
59 return false;
60
61 unsigned NOps = ProfData->getNumOperands();
62 if (NOps < MinOps)
63 return false;
64
65 auto *ProfDataName = dyn_cast<MDString>(Val: ProfData->getOperand(I: 0));
66 if (!ProfDataName)
67 return false;
68
69 return ProfDataName->getString() == Name;
70}
71
72template <typename T,
73 typename = typename std::enable_if<std::is_arithmetic_v<T>>>
74static void extractFromBranchWeightMD(const MDNode *ProfileData,
75 SmallVectorImpl<T> &Weights) {
76 assert(isBranchWeightMD(ProfileData) && "wrong metadata");
77
78 unsigned NOps = ProfileData->getNumOperands();
79 unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
80 assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
81 Weights.resize(NOps - WeightsIdx);
82
83 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
84 ConstantInt *Weight =
85 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
86 assert(Weight && "Malformed branch_weight in MD_prof node");
87 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
88 "Too many bits for MD_prof branch_weight");
89 Weights[Idx - WeightsIdx] = Weight->getZExtValue();
90 }
91}
92
93/// Push the weights right to fit in uint32_t.
94SmallVector<uint32_t> llvm::fitWeights(ArrayRef<uint64_t> Weights) {
95 SmallVector<uint32_t> Ret;
96 Ret.reserve(N: Weights.size());
97 uint64_t Max = *llvm::max_element(Range&: Weights);
98 if (Max > UINT_MAX) {
99 unsigned Offset = 32 - llvm::countl_zero(Val: Max);
100 for (const uint64_t &Value : Weights)
101 Ret.push_back(Elt: static_cast<uint32_t>(Value >> Offset));
102 } else {
103 append_range(C&: Ret, R&: Weights);
104 }
105 return Ret;
106}
107
108static cl::opt<bool> ElideAllZeroBranchWeights("elide-all-zero-branch-weights",
109#if defined(LLVM_ENABLE_PROFCHECK)
110 cl::init(false)
111#else
112 cl::init(Val: true)
113#endif
114);
115const char *MDProfLabels::BranchWeights = "branch_weights";
116const char *MDProfLabels::ExpectedBranchWeights = "expected";
117const char *MDProfLabels::ValueProfile = "VP";
118const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
119const char *MDProfLabels::SyntheticFunctionEntryCount =
120 "synthetic_function_entry_count";
121const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown";
122const char *llvm::LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count";
123
124bool llvm::hasProfMD(const Instruction &I) {
125 return I.hasMetadata(KindID: LLVMContext::MD_prof);
126}
127
128bool llvm::isBranchWeightMD(const MDNode *ProfileData) {
129 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::BranchWeights, MinOps: MinBWOps);
130}
131
132bool llvm::isValueProfileMD(const MDNode *ProfileData) {
133 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::ValueProfile, MinOps: MinVPOps);
134}
135
136bool llvm::hasBranchWeightMD(const Instruction &I) {
137 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
138 return isBranchWeightMD(ProfileData);
139}
140
141static bool hasCountTypeMD(const Instruction &I) {
142 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
143 // Value profiles record count-type information.
144 if (isValueProfileMD(ProfileData))
145 return true;
146 // Conservatively assume non CallBase instruction only get taken/not-taken
147 // branch probability, so not interpret them as count.
148 return isa<CallBase>(Val: I) && !isBranchWeightMD(ProfileData);
149}
150
151bool llvm::hasValidBranchWeightMD(const Instruction &I) {
152 return getValidBranchWeightMDNode(I);
153}
154
155bool llvm::hasBranchWeightOrigin(const Instruction &I) {
156 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
157 return hasBranchWeightOrigin(ProfileData);
158}
159
160bool llvm::hasBranchWeightOrigin(const MDNode *ProfileData) {
161 if (!isBranchWeightMD(ProfileData))
162 return false;
163 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 1));
164 // NOTE: if we ever have more types of branch weight provenance,
165 // we need to check the string value is "expected". For now, we
166 // supply a more generic API, and avoid the spurious comparisons.
167 assert(ProfDataName == nullptr ||
168 ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights);
169 return ProfDataName != nullptr;
170}
171
172unsigned llvm::getBranchWeightOffset(const MDNode *ProfileData) {
173 return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
174}
175
176unsigned llvm::getNumBranchWeights(const MDNode &ProfileData) {
177 return ProfileData.getNumOperands() - getBranchWeightOffset(ProfileData: &ProfileData);
178}
179
180MDNode *llvm::getBranchWeightMDNode(const Instruction &I) {
181 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
182 if (!isBranchWeightMD(ProfileData))
183 return nullptr;
184 return ProfileData;
185}
186
187MDNode *llvm::getValidBranchWeightMDNode(const Instruction &I) {
188 auto *ProfileData = getBranchWeightMDNode(I);
189 if (ProfileData && getNumBranchWeights(ProfileData: *ProfileData) == I.getNumSuccessors())
190 return ProfileData;
191 return nullptr;
192}
193
194void llvm::extractFromBranchWeightMD32(const MDNode *ProfileData,
195 SmallVectorImpl<uint32_t> &Weights) {
196 extractFromBranchWeightMD(ProfileData, Weights);
197}
198
199void llvm::extractFromBranchWeightMD64(const MDNode *ProfileData,
200 SmallVectorImpl<uint64_t> &Weights) {
201 extractFromBranchWeightMD(ProfileData, Weights);
202}
203
204bool llvm::extractBranchWeights(const MDNode *ProfileData,
205 SmallVectorImpl<uint32_t> &Weights) {
206 if (!isBranchWeightMD(ProfileData))
207 return false;
208 extractFromBranchWeightMD(ProfileData, Weights);
209 return true;
210}
211
212bool llvm::extractBranchWeights(const Instruction &I,
213 SmallVectorImpl<uint32_t> &Weights) {
214 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
215 return extractBranchWeights(ProfileData, Weights);
216}
217
218bool llvm::extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
219 uint64_t &FalseVal) {
220 assert((isa<CondBrInst, SelectInst>(I)) &&
221 "Looking for branch weights on something besides CondBr or Select");
222
223 SmallVector<uint32_t, 2> Weights;
224 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
225 if (!extractBranchWeights(ProfileData, Weights))
226 return false;
227
228 if (Weights.size() > 2)
229 return false;
230
231 TrueVal = Weights[0];
232 FalseVal = Weights[1];
233 return true;
234}
235
236bool llvm::extractProfTotalWeight(const MDNode *ProfileData,
237 uint64_t &TotalVal) {
238 TotalVal = 0;
239 if (!ProfileData)
240 return false;
241
242 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
243 if (!ProfDataName)
244 return false;
245
246 if (ProfDataName->getString() == MDProfLabels::BranchWeights) {
247 unsigned Offset = getBranchWeightOffset(ProfileData);
248 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
249 auto *V = mdconst::extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
250 TotalVal += V->getValue().getZExtValue();
251 }
252 return true;
253 }
254
255 if (ProfDataName->getString() == MDProfLabels::ValueProfile &&
256 ProfileData->getNumOperands() > 3) {
257 TotalVal = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 2))
258 ->getValue()
259 .getZExtValue();
260 return true;
261 }
262 return false;
263}
264
265bool llvm::extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
266 return extractProfTotalWeight(ProfileData: I.getMetadata(KindID: LLVMContext::MD_prof), TotalVal);
267}
268
269void llvm::setExplicitlyUnknownBranchWeights(Instruction &I,
270 StringRef PassName) {
271 MDBuilder MDB(I.getContext());
272 I.setMetadata(
273 KindID: LLVMContext::MD_prof,
274 Node: MDNode::get(Context&: I.getContext(),
275 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
276 MDB.createString(Str: PassName)}));
277}
278
279void llvm::setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I,
280 StringRef PassName,
281 const Function *F) {
282 F = F ? F : I.getFunction();
283 assert(F && "Either pass a instruction attached to a Function, or explicitly "
284 "pass the Function that it will be attached to");
285 if (std::optional<Function::ProfileCount> EC = F->getEntryCount();
286 EC && EC->getCount() > 0)
287 setExplicitlyUnknownBranchWeights(I, PassName);
288}
289
290MDNode *llvm::getExplicitlyUnknownBranchWeightsIfProfiled(Function &F,
291 StringRef PassName) {
292 if (std::optional<Function::ProfileCount> EC = F.getEntryCount();
293 !EC || EC->getCount() == 0)
294 return nullptr;
295 MDBuilder MDB(F.getContext());
296 return MDNode::get(
297 Context&: F.getContext(),
298 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
299 MDB.createString(Str: PassName)});
300}
301
302void llvm::setExplicitlyUnknownFunctionEntryCount(Function &F,
303 StringRef PassName) {
304 MDBuilder MDB(F.getContext());
305 F.setMetadata(
306 KindID: LLVMContext::MD_prof,
307 Node: MDNode::get(Context&: F.getContext(),
308 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
309 MDB.createString(Str: PassName)}));
310}
311
312bool llvm::isExplicitlyUnknownProfileMetadata(const MDNode &MD) {
313 if (MD.getNumOperands() != 2)
314 return false;
315 return MD.getOperand(I: 0).equalsStr(Str: MDProfLabels::UnknownBranchWeightsMarker);
316}
317
318bool llvm::hasExplicitlyUnknownBranchWeights(const Instruction &I) {
319 auto *MD = I.getMetadata(KindID: LLVMContext::MD_prof);
320 if (!MD)
321 return false;
322 return isExplicitlyUnknownProfileMetadata(MD: *MD);
323}
324
325void llvm::setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
326 bool IsExpected, bool ElideAllZero) {
327 if ((ElideAllZeroBranchWeights && ElideAllZero) &&
328 llvm::all_of(Range&: Weights, P: equal_to(Arg: 0))) {
329 I.setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
330 return;
331 }
332
333 MDBuilder MDB(I.getContext());
334 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
335 I.setMetadata(KindID: LLVMContext::MD_prof, Node: BranchWeights);
336}
337
338void llvm::setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
339 bool IsExpected, bool ElideAllZero) {
340 setBranchWeights(I, Weights: fitWeights(Weights), IsExpected, ElideAllZero);
341}
342
343SmallVector<uint32_t>
344llvm::downscaleWeights(ArrayRef<uint64_t> Weights,
345 std::optional<uint64_t> KnownMaxCount) {
346 uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
347 : *llvm::max_element(Range&: Weights);
348 assert(MaxCount > 0 && "Bad max count");
349 uint64_t Scale = calculateCountScale(MaxCount);
350 SmallVector<uint32_t> DownscaledWeights;
351 for (const auto &ECI : Weights)
352 DownscaledWeights.push_back(Elt: scaleBranchCount(Count: ECI, Scale));
353 return DownscaledWeights;
354}
355
356void llvm::scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
357 assert(T != 0 && "Caller should guarantee");
358 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
359 if (ProfileData == nullptr)
360 return;
361
362 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
363 if (!ProfDataName ||
364 (ProfDataName->getString() != MDProfLabels::BranchWeights &&
365 ProfDataName->getString() != MDProfLabels::ValueProfile))
366 return;
367
368 if (!hasCountTypeMD(I))
369 return;
370
371 LLVMContext &C = I.getContext();
372
373 MDBuilder MDB(C);
374 SmallVector<Metadata *, 3> Vals;
375 Vals.push_back(Elt: ProfileData->getOperand(I: 0));
376 APInt APS(128, S), APT(128, T);
377 if (ProfDataName->getString() == MDProfLabels::BranchWeights &&
378 ProfileData->getNumOperands() > 0) {
379 // Using APInt::div may be expensive, but most cases should fit 64 bits.
380 APInt Val(128,
381 mdconst::dyn_extract<ConstantInt>(
382 MD: ProfileData->getOperand(I: getBranchWeightOffset(ProfileData)))
383 ->getValue()
384 .getZExtValue());
385 Val *= APS;
386 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
387 Ty: Type::getInt32Ty(C), V: Val.udiv(RHS: APT).getLimitedValue(UINT32_MAX))));
388 } else if (ProfDataName->getString() == MDProfLabels::ValueProfile)
389 for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) {
390 // The first value is the key of the value profile, which will not change.
391 Vals.push_back(Elt: ProfileData->getOperand(I: Idx));
392 uint64_t Count =
393 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx + 1))
394 ->getValue()
395 .getZExtValue();
396 // Don't scale the magic number.
397 if (Count == NOMORE_ICP_MAGICNUM) {
398 Vals.push_back(Elt: ProfileData->getOperand(I: Idx + 1));
399 continue;
400 }
401 // Using APInt::div may be expensive, but most cases should fit 64 bits.
402 APInt Val(128, Count);
403 Val *= APS;
404 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
405 Ty: Type::getInt64Ty(C), V: Val.udiv(RHS: APT).getLimitedValue())));
406 }
407 I.setMetadata(KindID: LLVMContext::MD_prof, Node: MDNode::get(Context&: C, MDs: Vals));
408}
409
410void llvm::applyProfMetadataIfEnabled(
411 Value *V, llvm::function_ref<void(Instruction *)> setMetadataCallback) {
412 if (!ProfcheckDisableMetadataFixes) {
413 if (Instruction *Inst = dyn_cast<Instruction>(Val: V)) {
414 setMetadataCallback(Inst);
415 }
416 }
417}
418