1//===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements utilities for working with Profiling Metadata.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/IR/ProfDataUtils.h"
14
15#include "llvm/ADT/SmallVector.h"
16#include "llvm/IR/Constants.h"
17#include "llvm/IR/Function.h"
18#include "llvm/IR/Instructions.h"
19#include "llvm/IR/LLVMContext.h"
20#include "llvm/IR/MDBuilder.h"
21#include "llvm/IR/Metadata.h"
22
23using namespace llvm;
24
25namespace {
26
27// MD_prof nodes have the following layout
28//
29// In general:
30// { String name, Array of i32 }
31//
32// In terms of Types:
33// { MDString, [i32, i32, ...]}
34//
35// Concretely for Branch Weights
36// { "branch_weights", [i32 1, i32 10000]}
37//
38// We maintain some constants here to ensure that we access the branch weights
39// correctly, and can change the behavior in the future if the layout changes
40
// Minimum operand count for an MD_prof node to be treated as branch weights:
// the name string plus at least two further operands. Passed as MinOps to
// isTargetMD.
constexpr unsigned MinBWOps = 3;

// Minimum operand count for an MD_prof node to be treated as a value profile.
// Passed as MinOps to isTargetMD.
constexpr unsigned MinVPOps = 5;
46
47// We may want to add support for other MD_prof types, so provide an abstraction
48// for checking the metadata type.
49bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) {
50 // TODO: This routine may be simplified if MD_prof used an enum instead of a
51 // string to differentiate the types of MD_prof nodes.
52 if (!ProfData || !Name || MinOps < 2)
53 return false;
54
55 unsigned NOps = ProfData->getNumOperands();
56 if (NOps < MinOps)
57 return false;
58
59 auto *ProfDataName = dyn_cast<MDString>(Val: ProfData->getOperand(I: 0));
60 if (!ProfDataName)
61 return false;
62
63 return ProfDataName->getString() == Name;
64}
65
66template <typename T,
67 typename = typename std::enable_if<std::is_arithmetic_v<T>>>
68static void extractFromBranchWeightMD(const MDNode *ProfileData,
69 SmallVectorImpl<T> &Weights) {
70 assert(isBranchWeightMD(ProfileData) && "wrong metadata");
71
72 unsigned NOps = ProfileData->getNumOperands();
73 unsigned WeightsIdx = getBranchWeightOffset(ProfileData);
74 assert(WeightsIdx < NOps && "Weights Index must be less than NOps.");
75 Weights.resize(NOps - WeightsIdx);
76
77 for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) {
78 ConstantInt *Weight =
79 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
80 assert(Weight && "Malformed branch_weight in MD_prof node");
81 assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) &&
82 "Too many bits for MD_prof branch_weight");
83 Weights[Idx - WeightsIdx] = Weight->getZExtValue();
84 }
85}
86
87} // namespace
88
89namespace llvm {
90
91const char *MDProfLabels::BranchWeights = "branch_weights";
92const char *MDProfLabels::ExpectedBranchWeights = "expected";
93const char *MDProfLabels::ValueProfile = "VP";
94const char *MDProfLabels::FunctionEntryCount = "function_entry_count";
95const char *MDProfLabels::SyntheticFunctionEntryCount =
96 "synthetic_function_entry_count";
97const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown";
98
99bool hasProfMD(const Instruction &I) {
100 return I.hasMetadata(KindID: LLVMContext::MD_prof);
101}
102
103bool isBranchWeightMD(const MDNode *ProfileData) {
104 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::BranchWeights, MinOps: MinBWOps);
105}
106
107bool isValueProfileMD(const MDNode *ProfileData) {
108 return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::ValueProfile, MinOps: MinVPOps);
109}
110
111bool hasBranchWeightMD(const Instruction &I) {
112 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
113 return isBranchWeightMD(ProfileData);
114}
115
116static bool hasCountTypeMD(const Instruction &I) {
117 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
118 // Value profiles record count-type information.
119 if (isValueProfileMD(ProfileData))
120 return true;
121 // Conservatively assume non CallBase instruction only get taken/not-taken
122 // branch probability, so not interpret them as count.
123 return isa<CallBase>(Val: I) && !isBranchWeightMD(ProfileData);
124}
125
126bool hasValidBranchWeightMD(const Instruction &I) {
127 return getValidBranchWeightMDNode(I);
128}
129
130bool hasBranchWeightOrigin(const Instruction &I) {
131 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
132 return hasBranchWeightOrigin(ProfileData);
133}
134
135bool hasBranchWeightOrigin(const MDNode *ProfileData) {
136 if (!isBranchWeightMD(ProfileData))
137 return false;
138 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 1));
139 // NOTE: if we ever have more types of branch weight provenance,
140 // we need to check the string value is "expected". For now, we
141 // supply a more generic API, and avoid the spurious comparisons.
142 assert(ProfDataName == nullptr ||
143 ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights);
144 return ProfDataName != nullptr;
145}
146
147unsigned getBranchWeightOffset(const MDNode *ProfileData) {
148 return hasBranchWeightOrigin(ProfileData) ? 2 : 1;
149}
150
151unsigned getNumBranchWeights(const MDNode &ProfileData) {
152 return ProfileData.getNumOperands() - getBranchWeightOffset(ProfileData: &ProfileData);
153}
154
155MDNode *getBranchWeightMDNode(const Instruction &I) {
156 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
157 if (!isBranchWeightMD(ProfileData))
158 return nullptr;
159 return ProfileData;
160}
161
162MDNode *getValidBranchWeightMDNode(const Instruction &I) {
163 auto *ProfileData = getBranchWeightMDNode(I);
164 if (ProfileData && getNumBranchWeights(ProfileData: *ProfileData) == I.getNumSuccessors())
165 return ProfileData;
166 return nullptr;
167}
168
169void extractFromBranchWeightMD32(const MDNode *ProfileData,
170 SmallVectorImpl<uint32_t> &Weights) {
171 extractFromBranchWeightMD(ProfileData, Weights);
172}
173
174void extractFromBranchWeightMD64(const MDNode *ProfileData,
175 SmallVectorImpl<uint64_t> &Weights) {
176 extractFromBranchWeightMD(ProfileData, Weights);
177}
178
179bool extractBranchWeights(const MDNode *ProfileData,
180 SmallVectorImpl<uint32_t> &Weights) {
181 if (!isBranchWeightMD(ProfileData))
182 return false;
183 extractFromBranchWeightMD(ProfileData, Weights);
184 return true;
185}
186
187bool extractBranchWeights(const Instruction &I,
188 SmallVectorImpl<uint32_t> &Weights) {
189 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
190 return extractBranchWeights(ProfileData, Weights);
191}
192
193bool extractBranchWeights(const Instruction &I, uint64_t &TrueVal,
194 uint64_t &FalseVal) {
195 assert((I.getOpcode() == Instruction::Br ||
196 I.getOpcode() == Instruction::Select) &&
197 "Looking for branch weights on something besides branch, select, or "
198 "switch");
199
200 SmallVector<uint32_t, 2> Weights;
201 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
202 if (!extractBranchWeights(ProfileData, Weights))
203 return false;
204
205 if (Weights.size() > 2)
206 return false;
207
208 TrueVal = Weights[0];
209 FalseVal = Weights[1];
210 return true;
211}
212
213bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalVal) {
214 TotalVal = 0;
215 if (!ProfileData)
216 return false;
217
218 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
219 if (!ProfDataName)
220 return false;
221
222 if (ProfDataName->getString() == MDProfLabels::BranchWeights) {
223 unsigned Offset = getBranchWeightOffset(ProfileData);
224 for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) {
225 auto *V = mdconst::extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx));
226 TotalVal += V->getValue().getZExtValue();
227 }
228 return true;
229 }
230
231 if (ProfDataName->getString() == MDProfLabels::ValueProfile &&
232 ProfileData->getNumOperands() > 3) {
233 TotalVal = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 2))
234 ->getValue()
235 .getZExtValue();
236 return true;
237 }
238 return false;
239}
240
241bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
242 return extractProfTotalWeight(ProfileData: I.getMetadata(KindID: LLVMContext::MD_prof), TotalVal);
243}
244
245void setExplicitlyUnknownBranchWeights(Instruction &I) {
246 MDBuilder MDB(I.getContext());
247 I.setMetadata(
248 KindID: LLVMContext::MD_prof,
249 Node: MDNode::get(Context&: I.getContext(),
250 MDs: MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker)));
251}
252
253bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD) {
254 if (MD.getNumOperands() != 1)
255 return false;
256 return MD.getOperand(I: 0).equalsStr(Str: MDProfLabels::UnknownBranchWeightsMarker);
257}
258
259bool hasExplicitlyUnknownBranchWeights(const Instruction &I) {
260 auto *MD = I.getMetadata(KindID: LLVMContext::MD_prof);
261 if (!MD)
262 return false;
263 return isExplicitlyUnknownBranchWeightsMetadata(MD: *MD);
264}
265
266void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights,
267 bool IsExpected) {
268 MDBuilder MDB(I.getContext());
269 MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected);
270 I.setMetadata(KindID: LLVMContext::MD_prof, Node: BranchWeights);
271}
272
273void scaleProfData(Instruction &I, uint64_t S, uint64_t T) {
274 assert(T != 0 && "Caller should guarantee");
275 auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof);
276 if (ProfileData == nullptr)
277 return;
278
279 auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0));
280 if (!ProfDataName ||
281 (ProfDataName->getString() != MDProfLabels::BranchWeights &&
282 ProfDataName->getString() != MDProfLabels::ValueProfile))
283 return;
284
285 if (!hasCountTypeMD(I))
286 return;
287
288 LLVMContext &C = I.getContext();
289
290 MDBuilder MDB(C);
291 SmallVector<Metadata *, 3> Vals;
292 Vals.push_back(Elt: ProfileData->getOperand(I: 0));
293 APInt APS(128, S), APT(128, T);
294 if (ProfDataName->getString() == MDProfLabels::BranchWeights &&
295 ProfileData->getNumOperands() > 0) {
296 // Using APInt::div may be expensive, but most cases should fit 64 bits.
297 APInt Val(128,
298 mdconst::dyn_extract<ConstantInt>(
299 MD: ProfileData->getOperand(I: getBranchWeightOffset(ProfileData)))
300 ->getValue()
301 .getZExtValue());
302 Val *= APS;
303 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
304 Ty: Type::getInt32Ty(C), V: Val.udiv(RHS: APT).getLimitedValue(UINT32_MAX))));
305 } else if (ProfDataName->getString() == MDProfLabels::ValueProfile)
306 for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) {
307 // The first value is the key of the value profile, which will not change.
308 Vals.push_back(Elt: ProfileData->getOperand(I: Idx));
309 uint64_t Count =
310 mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx + 1))
311 ->getValue()
312 .getZExtValue();
313 // Don't scale the magic number.
314 if (Count == NOMORE_ICP_MAGICNUM) {
315 Vals.push_back(Elt: ProfileData->getOperand(I: Idx + 1));
316 continue;
317 }
318 // Using APInt::div may be expensive, but most cases should fit 64 bits.
319 APInt Val(128, Count);
320 Val *= APS;
321 Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get(
322 Ty: Type::getInt64Ty(C), V: Val.udiv(RHS: APT).getLimitedValue())));
323 }
324 I.setMetadata(KindID: LLVMContext::MD_prof, Node: MDNode::get(Context&: C, MDs: Vals));
325}
326
327} // namespace llvm
328