1//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for working with Profiling Metadata.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/IR/ProfDataUtils.h"
14
15#include "llvm/ADT/STLExtras.h"
16#include "llvm/ADT/STLFunctionalExtras.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/IR/Constants.h"
19#include "llvm/IR/Function.h"
20#include "llvm/IR/Instructions.h"
21#include "llvm/IR/LLVMContext.h"
22#include "llvm/IR/MDBuilder.h"
23#include "llvm/IR/Metadata.h"
24#include "llvm/Support/CommandLine.h"
25
26using namespace llvm;
27
namespace llvm {
// Command-line flag defined in another translation unit. When it is set,
// applyProfMetadataIfEnabled() in this file does not invoke its callback.
extern cl::opt<bool> ProfcheckDisableMetadataFixes;
} // namespace llvm
31
// MD_prof nodes have the following layout:
//
// In general:
// { String name, Array of integer constants }
//
// In terms of Types:
// { MDString, [iN, iN, ...]}
//
// Concretely, for branch weights:
// { "branch_weights", [i32 1, i32 10000]}
// Branch weights may also carry an optional origin string right after the
// name, e.g. { "branch_weights", "expected", [i32 1, i32 10000]}, and value
// profiles ("VP") store their counts as i64.
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.
45
// Smallest operand count isBranchWeightMD() accepts: the "branch_weights"
// label followed by at least two more operands.
static constexpr unsigned MinBWOps{3};

// Smallest operand count isValueProfileMD() accepts for a "VP" node.
static constexpr unsigned MinVPOps{5};
51
52// We may want to add support for other MD_prof types, so provide an abstraction
53// for checking the metadata type.
54static bool isTargetMD(const MDNode *ProfData, const char *Name,
55 unsigned MinOps) {
56 // TODO: This routine may be simplified if MD_prof used an enum instead of a
57 // string to differentiate the types of MD_prof nodes.
58 if (!ProfData || !Name || MinOps < 2)
59 return false;
60
61 unsigned NOps = ProfData->getNumOperands();
62 if (NOps < MinOps)
63 return false;
64
65 auto *ProfDataName = dyn_cast<MDString>(Val: ProfData->getOperand(I: 0));
66 if (!ProfDataName)
67 return false;
68
69 return ProfDataName->getString() == Name;
70}
71
72template <typename T,
73 typename = typename std::enable_if<std::is_arithmetic_v<T>>>
74static void extractFromBranchWeightMD(const MDNode *ProfileData,
75 SmallVectorImpl<T> &Weights) {
76 assert(isBranchWeightMD(ProfileData) && "wrong metadata");
77
78 unsigned NOps = ProfileData->getNumOperands();
79 unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
80 assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
81 Weights.resize(NOps - WeightsIdx);
82
83 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
84 ConstantInt *Weight =
85 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
86 assert(Weight && "Malformed branch_weight in MD_prof node");
87 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
88 "Too many bits for MD_prof branch_weight");
89 Weights[Idx - WeightsIdx] = Weight->getZExtValue();
90 }
91}
92
93/// Push the weights right to fit in uint32_t.
94SmallVector<uint32_t> llvm::fitWeights(ArrayRef<uint64_t> Weights) {
95 SmallVector<uint32_t> Ret;
96 Ret.reserve(N: Weights.size());
97 uint64_t Max = *llvm::max_element(Range&: Weights);
98 if (Max > UINT_MAX) {
99 unsigned Offset = 32 - llvm::countl_zero(Val: Max);
100 for (const uint64_t &Value : Weights)
101 Ret.push_back(Elt: static_cast<uint32_t>(Value >> Offset));
102 } else {
103 append_range(C&: Ret, R&: Weights);
104 }
105 return Ret;
106}
107
108static cl::opt<bool> ElideAllZeroBranchWeights("elide-all-zero-branch-weights",
109#if defined(LLVM_ENABLE_PROFCHECK)
110 cl::init(false)
111#else
112 cl::init(Val: true)
113#endif
114);
// Strings stored as operand 0 of MD_prof nodes to identify their kind.
const char *MDProfLabels::BranchWeights = "branch_weights";
// Optional operand 1 of a branch_weights node recording where the weights came
// from; see hasBranchWeightOrigin().
const char *MDProfLabels::ExpectedBranchWeights = "expected";
const char *MDProfLabels::ValueProfile = "VP";
const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
const char *MDProfLabels::SyntheticFunctionEntryCount =
    "synthetic_function_entry_count";
// Operand 0 of the two-operand {marker, pass name} nodes built by the
// setExplicitlyUnknown*() helpers and checked by
// isExplicitlyUnknownProfileMetadata().
const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown";
// Loop-metadata name string; not referenced elsewhere in this file.
const char *llvm::LLVMLoopEstimatedTripCount = "llvm.loop.estimated_trip_count";
123
124bool llvm::hasProfMD(const Instruction &I) {
125 return I.hasMetadata(KindID: LLVMContext::MD_prof);
126}
127
128bool llvm::isBranchWeightMD(const MDNode *ProfileData) {
129 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::BranchWeights, MinOps: MinBWOps);
130}
131
132bool llvm::isValueProfileMD(const MDNode *ProfileData) {
133 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::ValueProfile, MinOps: MinVPOps);
134}
135
136bool llvm::hasBranchWeightMD(const Instruction &I) {
137 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
138 return isBranchWeightMD(ProfileData);
139}
140
141static bool hasCountTypeMD(const Instruction &I) {
142 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
143 // Value profiles record count-type information.
144 if (isValueProfileMD(ProfileData))
145 return true;
146 // Conservatively assume non CallBase instruction only get taken/not-taken
147 // branch probability, so not interpret them as count.
148 return isa<CallBase>(Val: I) && !isBranchWeightMD(ProfileData);
149}
150
151bool llvm::hasValidBranchWeightMD(const Instruction &I) {
152 return getValidBranchWeightMDNode(I);
153}
154
155bool llvm::hasBranchWeightOrigin(const Instruction &I) {
156 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
157 return hasBranchWeightOrigin(ProfileData);
158}
159
160bool llvm::hasBranchWeightOrigin(const MDNode *ProfileData) {
161 if (!isBranchWeightMD(ProfileData))
162 return false;
163 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 1));
164 // NOTE: if we ever have more types of branch weight provenance,
165 // we need to check the string value is "expected". For now, we
166 // supply a more generic API, and avoid the spurious comparisons.
167 assert(ProfDataName == nullptr ||
168 ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights);
169 return ProfDataName != nullptr;
170}
171
172unsigned llvm::getBranchWeightOffset(const MDNode *ProfileData) {
173 return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
174}
175
176unsigned llvm::getNumBranchWeights(const MDNode &ProfileData) {
177 return ProfileData.getNumOperands() - getBranchWeightOffset(ProfileData: &ProfileData);
178}
179
180MDNode *llvm::getBranchWeightMDNode(const Instruction &I) {
181 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
182 if (!isBranchWeightMD(ProfileData))
183 return nullptr;
184 return ProfileData;
185}
186
187MDNode *llvm::getValidBranchWeightMDNode(const Instruction &I) {
188 auto *ProfileData = getBranchWeightMDNode(I);
189 if (ProfileData && getNumBranchWeights(ProfileData: *ProfileData) == I.getNumSuccessors())
190 return ProfileData;
191 return nullptr;
192}
193
194void llvm::extractFromBranchWeightMD32(const MDNode *ProfileData,
195 SmallVectorImpl<uint32_t> &Weights) {
196 extractFromBranchWeightMD(ProfileData, Weights);
197}
198
199void llvm::extractFromBranchWeightMD64(const MDNode *ProfileData,
200 SmallVectorImpl<uint64_t> &Weights) {
201 extractFromBranchWeightMD(ProfileData, Weights);
202}
203
204bool llvm::extractBranchWeights(const MDNode *ProfileData,
205 SmallVectorImpl<uint32_t> &Weights) {
206 if (!isBranchWeightMD(ProfileData))
207 return false;
208 extractFromBranchWeightMD(ProfileData, Weights);
209 return true;
210}
211
212bool llvm::extractBranchWeights(const Instruction &I,
213 SmallVectorImpl<uint32_t> &Weights) {
214 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
215 return extractBranchWeights(ProfileData, Weights);
216}
217
218bool llvm::extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
219 uint64_t &FalseVal) {
220 assert((isa<CondBrInst, SelectInst>(I)) &&
221 "Looking for branch weights on something besides CondBr or Select");
222
223 SmallVector<uint32_t, 2> Weights;
224 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
225 if (!extractBranchWeights(ProfileData, Weights))
226 return false;
227
228 if (Weights.size() > 2)
229 return false;
230
231 TrueVal = Weights[0];
232 FalseVal = Weights[1];
233 return true;
234}
235
236bool llvm::extractProfTotalWeight(const MDNode *ProfileData,
237 uint64_t &TotalVal) {
238 TotalVal = 0;
239 if (!ProfileData)
240 return false;
241
242 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
243 if (!ProfDataName)
244 return false;
245
246 if (ProfDataName->getString() == MDProfLabels::BranchWeights) {
247 unsigned Offset = getBranchWeightOffset(ProfileData);
248 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
249 auto *V = mdconst::extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
250 TotalVal += V->getValue().getZExtValue();
251 }
252 return true;
253 }
254
255 if (ProfDataName->getString() == MDProfLabels::ValueProfile &&
256 ProfileData->getNumOperands() > 3) {
257 TotalVal = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 2))
258 ->getValue()
259 .getZExtValue();
260 return true;
261 }
262 return false;
263}
264
265bool llvm::extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
266 return extractProfTotalWeight(ProfileData: I.getMetadata(KindID: LLVMContext::MD_prof), TotalVal);
267}
268
269void llvm::setExplicitlyUnknownBranchWeights(Instruction &I,
270 StringRef PassName) {
271 MDBuilder MDB(I.getContext());
272 I.setMetadata(
273 KindID: LLVMContext::MD_prof,
274 Node: MDNode::get(Context&: I.getContext(),
275 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
276 MDB.createString(Str: PassName)}));
277}
278
279void llvm::setExplicitlyUnknownBranchWeightsIfProfiled(Instruction &I,
280 StringRef PassName,
281 const Function *F) {
282 F = F ? F : I.getFunction();
283 assert(F && "Either pass a instruction attached to a Function, or explicitly "
284 "pass the Function that it will be attached to");
285 if (std::optional<uint64_t> EC = F->getEntryCount(); EC && *EC > 0)
286 setExplicitlyUnknownBranchWeights(I, PassName);
287}
288
289MDNode *llvm::getExplicitlyUnknownBranchWeightsIfProfiled(Function &F,
290 StringRef PassName) {
291 if (std::optional<uint64_t> EC = F.getEntryCount(); !EC || *EC == 0)
292 return nullptr;
293 MDBuilder MDB(F.getContext());
294 return MDNode::get(
295 Context&: F.getContext(),
296 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
297 MDB.createString(Str: PassName)});
298}
299
300void llvm::setExplicitlyUnknownFunctionEntryCount(Function &F,
301 StringRef PassName) {
302 MDBuilder MDB(F.getContext());
303 F.setMetadata(
304 KindID: LLVMContext::MD_prof,
305 Node: MDNode::get(Context&: F.getContext(),
306 MDs: {MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker),
307 MDB.createString(Str: PassName)}));
308}
309
310bool llvm::isExplicitlyUnknownProfileMetadata(const MDNode &MD) {
311 if (MD.getNumOperands() != 2)
312 return false;
313 return MD.getOperand(I: 0).equalsStr(Str: MDProfLabels::UnknownBranchWeightsMarker);
314}
315
316bool llvm::hasExplicitlyUnknownBranchWeights(const Instruction &I) {
317 auto *MD = I.getMetadata(KindID: LLVMContext::MD_prof);
318 if (!MD)
319 return false;
320 return isExplicitlyUnknownProfileMetadata(MD: *MD);
321}
322
323void llvm::setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
324 bool IsExpected, bool ElideAllZero) {
325 if ((ElideAllZeroBranchWeights && ElideAllZero) &&
326 llvm::all_of(Range&: Weights, P: equal_to(Arg: 0))) {
327 I.setMetadata(KindID: LLVMContext::MD_prof, Node: nullptr);
328 return;
329 }
330
331 MDBuilder MDB(I.getContext());
332 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
333 I.setMetadata(KindID: LLVMContext::MD_prof, Node: BranchWeights);
334}
335
336void llvm::setFittedBranchWeights(Instruction &I, ArrayRef<uint64_t> Weights,
337 bool IsExpected, bool ElideAllZero) {
338 setBranchWeights(I, Weights: fitWeights(Weights), IsExpected, ElideAllZero);
339}
340
341SmallVector<uint32_t>
342llvm::downscaleWeights(ArrayRef<uint64_t> Weights,
343 std::optional<uint64_t> KnownMaxCount) {
344 uint64_t MaxCount = KnownMaxCount.has_value() ? KnownMaxCount.value()
345 : *llvm::max_element(Range&: Weights);
346 assert(MaxCount > 0 && "Bad max count");
347 uint64_t Scale = calculateCountScale(MaxCount);
348 SmallVector<uint32_t> DownscaledWeights;
349 for (const auto &ECI : Weights)
350 DownscaledWeights.push_back(Elt: scaleBranchCount(Count: ECI, Scale));
351 return DownscaledWeights;
352}
353
354void llvm::scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
355 assert(T != 0 && "Caller should guarantee");
356 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
357 if (ProfileData == nullptr)
358 return;
359
360 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
361 if (!ProfDataName ||
362 (ProfDataName->getString() != MDProfLabels::BranchWeights &&
363 ProfDataName->getString() != MDProfLabels::ValueProfile))
364 return;
365
366 if (!hasCountTypeMD(I))
367 return;
368
369 LLVMContext &C = I.getContext();
370
371 MDBuilder MDB(C);
372 SmallVector<Metadata *, 3> Vals;
373 Vals.push_back(Elt: ProfileData->getOperand(I: 0));
374 APInt APS(128, S), APT(128, T);
375 if (ProfDataName->getString() == MDProfLabels::BranchWeights &&
376 ProfileData->getNumOperands() > 0) {
377 // Using APInt::div may be expensive, but most cases should fit 64 bits.
378 APInt Val(128,
379 mdconst::dyn_extract<ConstantInt>(
380 MD: ProfileData->getOperand(I: getBranchWeightOffset(ProfileData)))
381 ->getValue()
382 .getZExtValue());
383 Val *= APS;
384 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
385 Ty: Type::getInt32Ty(C), V: Val.udiv(RHS: APT).getLimitedValue(UINT32_MAX))));
386 } else if (ProfDataName->getString() == MDProfLabels::ValueProfile)
387 for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) {
388 // The first value is the key of the value profile, which will not change.
389 Vals.push_back(Elt: ProfileData->getOperand(I: Idx));
390 uint64_t Count =
391 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx + 1))
392 ->getValue()
393 .getZExtValue();
394 // Don't scale the magic number.
395 if (Count == NOMORE_ICP_MAGICNUM) {
396 Vals.push_back(Elt: ProfileData->getOperand(I: Idx + 1));
397 continue;
398 }
399 // Using APInt::div may be expensive, but most cases should fit 64 bits.
400 APInt Val(128, Count);
401 Val *= APS;
402 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
403 Ty: Type::getInt64Ty(C), V: Val.udiv(RHS: APT).getLimitedValue())));
404 }
405 I.setMetadata(KindID: LLVMContext::MD_prof, Node: MDNode::get(Context&: C, MDs: Vals));
406}
407
408void llvm::applyProfMetadataIfEnabled(
409 Value *V, llvm::function_ref<void(Instruction *)> setMetadataCallback) {
410 if (!ProfcheckDisableMetadataFixes) {
411 if (Instruction *Inst = dyn_cast<Instruction>(Val: V)) {
412 setMetadataCallback(Inst);
413 }
414 }
415}
416