1//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for working with Profiling Metadata.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/IR/ProfDataUtils.h"
14
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/STLFunctionalExtras.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/LLVMContext.h"
22#include "llvm/IR/MDBuilder.h"
23#include "llvm/IR/Metadata.h"
24#include "llvm/Support/CommandLine.h"
25
26using namespace llvm;
27
28namespace llvm {
29extern cl::opt<bool> ProfcheckDisableMetadataFixes;
30}
31
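// MD_prof nodes have the following layout.
//
// In general:
// { String name, Array of i32 }
//
// In terms of types:
// { MDString, [i32, i32, ...] }
//
// Concretely, for branch weights:
// { "branch_weights", [i32 1, i32 10000] }
//
// A branch-weight node may additionally carry a second MDString at operand 1
// (the weight origin, see hasBranchWeightOrigin()); in that case the weights
// start at operand 2 instead of operand 1 (see getBranchWeightOffset()).
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.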
45
/// Fewest operands an MD_prof node may have for isBranchWeightMD() to accept
/// it as a branch-weight node.
static constexpr unsigned MinBWOps = 3;
48
/// Fewest operands an MD_prof node may have for isValueProfileMD() to accept
/// it as a value-profile node.
static constexpr unsigned MinVPOps = 5;
51
52// We may want to add support for other MD_prof types, so provide an abstraction
53// for checking the metadata type.
54static bool isTargetMD(const MDNode *ProfData, const char *Name,
55 unsigned MinOps) {
56 // TODO: This routine may be simplified if MD_prof used an enum instead of a
57 // string to differentiate the types of MD_prof nodes.
58 if (!ProfData || !Name || MinOps < 2)
59 return false;
60
61 unsigned NOps = ProfData->getNumOperands();
62 if (NOps < MinOps)
63 return false;
64
65 auto *ProfDataName = dyn_cast<MDString>(Val: ProfData->getOperand(I: 0));
66 if (!ProfDataName)
67 return false;
68
69 return ProfDataName->getString() == Name;
70}
71
72template <typename T,
73 typename = typename std::enable_if<std::is_arithmetic_v<T>>>
74static void extractFromBranchWeightMD(const MDNode *ProfileData,
75 SmallVectorImpl<T> &Weights) {
76 assert(isBranchWeightMD(ProfileData) && "wrong metadata");
77
78 unsigned NOps = ProfileData->getNumOperands();
79 unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
80 assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
81 Weights.resize(NOps - WeightsIdx);
82
83 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
84 ConstantInt *Weight =
85 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
86 assert(Weight && "Malformed branch_weight in MD_prof node");
87 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
88 "Too many bits for MD_prof branch_weight");
89 Weights[Idx - WeightsIdx] = Weight->getZExtValue();
90 }
91}
92
93/// Push the weights right to fit in uint32_t.
94SmallVector<uint32_t> llvm::fitWeights(ArrayRef<uint64_t> Weights) {
95 SmallVector<uint32_t> Ret;
96 Ret.reserve(N: Weights.size());
97 uint64_t Max = *llvm::max_element(Range&: Weights);
98 if (Max > UINT_MAX) {
99 unsigned Offset = 32 - llvm::countl_zero(Val: Max);
100 for (const uint64_t &Value : Weights)
101 Ret.push_back(Elt: static_cast<uint32_t>(Value >> Offset));
102 } else {
103 append_range(C&: Ret, R&: Weights);
104 }
105 return Ret;
106}
107
108static cl::opt<bool> ElideAllZeroBranchWeights("elide-all-zero-branch-weights",
109#if defined(LLVM_ENABLE_PROFCHECK)
110 cl::init(false)
111#else
112 cl::init(Val: true)
113#endif
114);
115const char *MDProfLabels::BranchWeights = "branch_weights";
116const char *MDProfLabels::ExpectedBranchWeights = "expected";
117const char *MDProfLabels::ValueProfile = "VP";
118const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
119const char *MDProfLabels::SyntheticFunctionEntryCount =
120 "synthetic_function_entry_count";
121const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown";
122const char *llvm::LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count";
123
124bool llvm::hasProfMD(const Instruction &I) {
125 return I.hasMetadata(KindID: LLVMContext::MD_prof);
126}
127
128bool llvm::isBranchWeightMD(const MDNode *ProfileData) {
129 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::BranchWeights, MinOps: MinBWOps);
130}
131
132bool llvm::isValueProfileMD(const MDNode *ProfileData) {
133 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::ValueProfile, MinOps: MinVPOps);
134}
135
136bool llvm::hasBranchWeightMD(const Instruction &I) {
137 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
138 return isBranchWeightMD(ProfileData);
139}
140
141static bool hasCountTypeMD(const Instruction &I) {
142 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
143 // Value profiles record count-type information.
144 if (isValueProfileMD(ProfileData))
145 return true;
146 // Conservatively assume non CallBase instruction only get taken/not-taken
147 // branch probability, so not interpret them as count.
148 return isa<CallBase>(Val: I) && !isBranchWeightMD(ProfileData);
149}
150
151bool llvm::hasValidBranchWeightMD(const Instruction &I) {
152 return getValidBranchWeightMDNode(I);
153}
154
155bool llvm::hasBranchWeightOrigin(const Instruction &I) {
156 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
157 return hasBranchWeightOrigin(ProfileData);
158}
159
160bool llvm::hasBranchWeightOrigin(const MDNode *ProfileData) {
161 if (!isBranchWeightMD(ProfileData))
162 return false;
163 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 1));
164 // NOTE: if we ever have more types of branch weight provenance,
165 // we need to check the string value is "expected". For now, we
166 // supply a more generic API, and avoid the spurious comparisons.
167 assert(ProfDataName == nullptr ||
168 ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights);
169 return ProfDataName != nullptr;
170}
171
172unsigned llvm::getBranchWeightOffset(const MDNode *ProfileData) {
173 return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
174}
175
176unsigned llvm::getNumBranchWeights(const MDNode &ProfileData) {
177 return ProfileData.getNumOperands() - getBranchWeightOffset(ProfileData: &ProfileData);
178}
179
180MDNode *llvm::getBranchWeightMDNode(const Instruction &I) {
181 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
182 if (!isBranchWeightMD(ProfileData))
183 return nullptr;
184 return ProfileData;
185}
186
187MDNode *llvm::getValidBranchWeightMDNode(const Instruction &I) {
188 auto *ProfileData = getBranchWeightMDNode(I);
189 if (ProfileData && getNumBranchWeights(ProfileData: *ProfileData) == I.getNumSuccessors())
190 return ProfileData;
191 return nullptr;
192}
193
194void llvm::extractFromBranchWeightMD32(const MDNode *ProfileData,
195 SmallVectorImpl<uint32_t> &Weights) {
196 extractFromBranchWeightMD(ProfileData, Weights);
197}
198
199void llvm::extractFromBranchWeightMD64(const MDNode *ProfileData,
200 SmallVectorImpl<uint64_t> &Weights) {
201 extractFromBranchWeightMD(ProfileData, Weights);
202}
203
204bool llvm::extractBranchWeights(const MDNode *ProfileData,
205 SmallVectorImpl<uint32_t> &Weights) {
206 if (!isBranchWeightMD(ProfileData))
207 return false;
208 extractFromBranchWeightMD(ProfileData, Weights);
209 return true;
210}
211
212bool llvm::extractBranchWeights(const Instruction &I,
213 SmallVectorImpl<uint32_t> &Weights) {
214 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
215 return extractBranchWeights(ProfileData, Weights);
216}
217
218bool llvm::extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
219 uint64_t &FalseVal) {
220 assert((I.getOpcode() == Instruction::Br ||
221 I.getOpcode() == Instruction::Select) &&
222 "Looking for branch weights on something besides branch, select, or "
223 "switch");
224
225 SmallVector<uint32_t, 2> Weights;
226 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
227 if (!extractBranchWeights(ProfileData, Weights))
228 return false;
229
230 if (Weights.size() > 2)
231 return false;
232
233 TrueVal = Weights[0];
234 FalseVal = Weights[1];
235 return true;
236}
237
238bool llvm::extractProfTotalWeight(const MDNode *ProfileData,
239 uint64_t &TotalVal) {
240 TotalVal = 0;
241 if (!ProfileData)
242 return false;
243
244 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
245 if (!ProfDataName)
246 return false;
247
248 if (ProfDataName->getString() == MDProfLabels::BranchWeights) {
249 unsigned Offset = getBranchWeightOffset(ProfileData);
250 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
251 auto *V = mdconst::extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
252 TotalVal += V->getValue().getZExtValue();
253 }
254 return true;
255 }
256
257 if (ProfDataName->getString() == MDProfLabels::ValueProfile &&
258 ProfileData->getNumOperands() > 3) {
259 TotalVal = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 2))
260 ->getValue()
261 .getZExtValue();
262 return true;
263 }
264 return false;
265}
266
267bool llvm::extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
268 return extractProfTotalWeight(ProfileData: I.getMetadata(KindID: LLVMContext::MD_prof), TotalVal);
269}
270
271void llvm::setExplicitlyUnknownBranchWeights(Instruction &I,
272 StringRef PassName) {
273 MDBuilder MDB(I.getContext());
274 I.setMetadata(
275 KindID: LLVMContext::MD_prof,
276 Node: MDNode::get(Context&: I.getContext(),
277 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
278 MDB.createString(Str: PassName)}));
279}
280
281void llvm::setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I,
282 StringRef PassName,
283 const Function *F) {
284 F = F ? F : I.getFunction();
285 assert(F && "Either pass a instruction attached to a Function, or explicitly "
286 "pass the Function that it will be attached to");
287 if (std::optional<Function::ProfileCount> EC = F->getEntryCount();
288 EC && EC->getCount() > 0)
289 setExplicitlyUnknownBranchWeights(I, PassName);
290}
291
292void llvm::setExplicitlyUnknownFunctionEntryCount(Function &F,
293 StringRef PassName) {
294 MDBuilder MDB(F.getContext());
295 F.setMetadata(
296 KindID: LLVMContext::MD_prof,
297 Node: MDNode::get(Context&: F.getContext(),
298 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
299 MDB.createString(Str: PassName)}));
300}
301
302bool llvm::isExplicitlyUnknownProfileMetadata(const MDNode &MD) {
303 if (MD.getNumOperands() != 2)
304 return false;
305 return MD.getOperand(I: 0).equalsStr(Str: MDProfLabels::UnknownBranchWeightsMarker);
306}
307
308bool llvm::hasExplicitlyUnknownBranchWeights(const Instruction &I) {
309 auto *MD = I.getMetadata(KindID: LLVMContext::MD_prof);
310 if (!MD)
311 return false;
312 return isExplicitlyUnknownProfileMetadata(MD: *MD);
313}
314
315void llvm::setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
316 bool IsExpected, bool ElideAllZero) {
317 if ((ElideAllZeroBranchWeights && ElideAllZero) &&
318 llvm::all_of(Range&: Weights, P: equal_to(Arg: 0))) {
319 I.setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
320 return;
321 }
322
323 MDBuilder MDB(I.getContext());
324 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
325 I.setMetadata(KindID: LLVMContext::MD_prof, Node: BranchWeights);
326}
327
328void llvm::setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
329 bool IsExpected, bool ElideAllZero) {
330 setBranchWeights(I, Weights: fitWeights(Weights), IsExpected, ElideAllZero);
331}
332
333SmallVector<uint32_t>
334llvm::downscaleWeights(ArrayRef<uint64_t> Weights,
335 std::optional<uint64_t> KnownMaxCount) {
336 uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
337 : *llvm::max_element(Range&: Weights);
338 assert(MaxCount > 0 && "Bad max count");
339 uint64_t Scale = calculateCountScale(MaxCount);
340 SmallVector<uint32_t> DownscaledWeights;
341 for (const auto &ECI : Weights)
342 DownscaledWeights.push_back(Elt: scaleBranchCount(Count: ECI, Scale));
343 return DownscaledWeights;
344}
345
346void llvm::scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
347 assert(T != 0 && "Caller should guarantee");
348 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
349 if (ProfileData == nullptr)
350 return;
351
352 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
353 if (!ProfDataName ||
354 (ProfDataName->getString() != MDProfLabels::BranchWeights &&
355 ProfDataName->getString() != MDProfLabels::ValueProfile))
356 return;
357
358 if (!hasCountTypeMD(I))
359 return;
360
361 LLVMContext &C = I.getContext();
362
363 MDBuilder MDB(C);
364 SmallVector<Metadata *, 3> Vals;
365 Vals.push_back(Elt: ProfileData->getOperand(I: 0));
366 APInt APS(128, S), APT(128, T);
367 if (ProfDataName->getString() == MDProfLabels::BranchWeights &&
368 ProfileData->getNumOperands() > 0) {
369 // Using APInt::div may be expensive, but most cases should fit 64 bits.
370 APInt Val(128,
371 mdconst::dyn_extract<ConstantInt>(
372 MD: ProfileData->getOperand(I: getBranchWeightOffset(ProfileData)))
373 ->getValue()
374 .getZExtValue());
375 Val *= APS;
376 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
377 Ty: Type::getInt32Ty(C), V: Val.udiv(RHS: APT).getLimitedValue(UINT32_MAX))));
378 } else if (ProfDataName->getString() == MDProfLabels::ValueProfile)
379 for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) {
380 // The first value is the key of the value profile, which will not change.
381 Vals.push_back(Elt: ProfileData->getOperand(I: Idx));
382 uint64_t Count =
383 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx + 1))
384 ->getValue()
385 .getZExtValue();
386 // Don't scale the magic number.
387 if (Count == NOMORE_ICP_MAGICNUM) {
388 Vals.push_back(Elt: ProfileData->getOperand(I: Idx + 1));
389 continue;
390 }
391 // Using APInt::div may be expensive, but most cases should fit 64 bits.
392 APInt Val(128, Count);
393 Val *= APS;
394 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
395 Ty: Type::getInt64Ty(C), V: Val.udiv(RHS: APT).getLimitedValue())));
396 }
397 I.setMetadata(KindID: LLVMContext::MD_prof, Node: MDNode::get(Context&: C, MDs: Vals));
398}
399
400void llvm::applyProfMetadataIfEnabled(
401 Value *V, llvm::function_ref<void(Instruction *)> setMetadataCallback) {
402 if (!ProfcheckDisableMetadataFixes) {
403 if (Instruction *Inst = dyn_cast<Instruction>(Val: V)) {
404 setMetadataCallback(Inst);
405 }
406 }
407}
408