1 | //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for working with Profiling Metadata. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "llvm/IR/ProfDataUtils.h" |
14 | |
15 | #include "llvm/ADT/SmallVector.h" |
16 | #include "llvm/IR/Constants.h" |
17 | #include "llvm/IR/Function.h" |
18 | #include "llvm/IR/Instructions.h" |
19 | #include "llvm/IR/LLVMContext.h" |
20 | #include "llvm/IR/MDBuilder.h" |
21 | #include "llvm/IR/Metadata.h" |
22 | |
23 | using namespace llvm; |
24 | |
25 | namespace { |
26 | |
// MD_prof nodes have the following layout:
//
// In general:
//   { String name, operand, operand, ... }
//
// In terms of Types (operands are stored flat on the node, not as an array):
//   { MDString, ConstantInt, ConstantInt, ... }
//
// Concretely, for branch weights:
//   { "branch_weights", i32 1, i32 10000 }
//
// Branch weights may also carry an optional origin string directly after the
// name, e.g. { "branch_weights", "expected", i32 1, i32 10000 }. In that case
// the weights start at operand 2 instead of operand 1.
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.
40 | |
// Fewest operands a "branch_weights" MD_prof node may have: the type-name
// string plus at least two more operands (two weights, or an origin string
// followed by one weight).
constexpr unsigned MinBWOps{3};

// Fewest operands a "VP" (value profile) MD_prof node may have: the type-name
// string followed by at least two (key, count) operand pairs.
constexpr unsigned MinVPOps{5};
46 | |
47 | // We may want to add support for other MD_prof types, so provide an abstraction |
48 | // for checking the metadata type. |
49 | bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { |
50 | // TODO: This routine may be simplified if MD_prof used an enum instead of a |
51 | // string to differentiate the types of MD_prof nodes. |
52 | if (!ProfData || !Name || MinOps < 2) |
53 | return false; |
54 | |
55 | unsigned NOps = ProfData->getNumOperands(); |
56 | if (NOps < MinOps) |
57 | return false; |
58 | |
59 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfData->getOperand(I: 0)); |
60 | if (!ProfDataName) |
61 | return false; |
62 | |
63 | return ProfDataName->getString() == Name; |
64 | } |
65 | |
66 | template <typename T, |
67 | typename = typename std::enable_if<std::is_arithmetic_v<T>>> |
68 | static void (const MDNode *ProfileData, |
69 | SmallVectorImpl<T> &Weights) { |
70 | assert(isBranchWeightMD(ProfileData) && "wrong metadata" ); |
71 | |
72 | unsigned NOps = ProfileData->getNumOperands(); |
73 | unsigned WeightsIdx = getBranchWeightOffset(ProfileData); |
74 | assert(WeightsIdx < NOps && "Weights Index must be less than NOps." ); |
75 | Weights.resize(NOps - WeightsIdx); |
76 | |
77 | for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { |
78 | ConstantInt *Weight = |
79 | mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx)); |
80 | assert(Weight && "Malformed branch_weight in MD_prof node" ); |
81 | assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && |
82 | "Too many bits for MD_prof branch_weight" ); |
83 | Weights[Idx - WeightsIdx] = Weight->getZExtValue(); |
84 | } |
85 | } |
86 | |
87 | } // namespace |
88 | |
89 | namespace llvm { |
90 | |
91 | const char *MDProfLabels::BranchWeights = "branch_weights" ; |
92 | const char *MDProfLabels::ExpectedBranchWeights = "expected" ; |
93 | const char *MDProfLabels::ValueProfile = "VP" ; |
94 | const char *MDProfLabels::FunctionEntryCount = "function_entry_count" ; |
95 | const char *MDProfLabels::SyntheticFunctionEntryCount = |
96 | "synthetic_function_entry_count" ; |
97 | const char *MDProfLabels::UnknownBranchWeightsMarker = "unknown" ; |
98 | |
99 | bool hasProfMD(const Instruction &I) { |
100 | return I.hasMetadata(KindID: LLVMContext::MD_prof); |
101 | } |
102 | |
103 | bool isBranchWeightMD(const MDNode *ProfileData) { |
104 | return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::BranchWeights, MinOps: MinBWOps); |
105 | } |
106 | |
107 | bool isValueProfileMD(const MDNode *ProfileData) { |
108 | return isTargetMD(ProfData: ProfileData, Name: MDProfLabels::ValueProfile, MinOps: MinVPOps); |
109 | } |
110 | |
111 | bool hasBranchWeightMD(const Instruction &I) { |
112 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
113 | return isBranchWeightMD(ProfileData); |
114 | } |
115 | |
116 | static bool hasCountTypeMD(const Instruction &I) { |
117 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
118 | // Value profiles record count-type information. |
119 | if (isValueProfileMD(ProfileData)) |
120 | return true; |
121 | // Conservatively assume non CallBase instruction only get taken/not-taken |
122 | // branch probability, so not interpret them as count. |
123 | return isa<CallBase>(Val: I) && !isBranchWeightMD(ProfileData); |
124 | } |
125 | |
126 | bool hasValidBranchWeightMD(const Instruction &I) { |
127 | return getValidBranchWeightMDNode(I); |
128 | } |
129 | |
130 | bool hasBranchWeightOrigin(const Instruction &I) { |
131 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
132 | return hasBranchWeightOrigin(ProfileData); |
133 | } |
134 | |
135 | bool hasBranchWeightOrigin(const MDNode *ProfileData) { |
136 | if (!isBranchWeightMD(ProfileData)) |
137 | return false; |
138 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 1)); |
139 | // NOTE: if we ever have more types of branch weight provenance, |
140 | // we need to check the string value is "expected". For now, we |
141 | // supply a more generic API, and avoid the spurious comparisons. |
142 | assert(ProfDataName == nullptr || |
143 | ProfDataName->getString() == MDProfLabels::ExpectedBranchWeights); |
144 | return ProfDataName != nullptr; |
145 | } |
146 | |
147 | unsigned getBranchWeightOffset(const MDNode *ProfileData) { |
148 | return hasBranchWeightOrigin(ProfileData) ? 2 : 1; |
149 | } |
150 | |
151 | unsigned getNumBranchWeights(const MDNode &ProfileData) { |
152 | return ProfileData.getNumOperands() - getBranchWeightOffset(ProfileData: &ProfileData); |
153 | } |
154 | |
155 | MDNode *getBranchWeightMDNode(const Instruction &I) { |
156 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
157 | if (!isBranchWeightMD(ProfileData)) |
158 | return nullptr; |
159 | return ProfileData; |
160 | } |
161 | |
162 | MDNode *getValidBranchWeightMDNode(const Instruction &I) { |
163 | auto *ProfileData = getBranchWeightMDNode(I); |
164 | if (ProfileData && getNumBranchWeights(ProfileData: *ProfileData) == I.getNumSuccessors()) |
165 | return ProfileData; |
166 | return nullptr; |
167 | } |
168 | |
169 | void (const MDNode *ProfileData, |
170 | SmallVectorImpl<uint32_t> &Weights) { |
171 | extractFromBranchWeightMD(ProfileData, Weights); |
172 | } |
173 | |
174 | void (const MDNode *ProfileData, |
175 | SmallVectorImpl<uint64_t> &Weights) { |
176 | extractFromBranchWeightMD(ProfileData, Weights); |
177 | } |
178 | |
179 | bool (const MDNode *ProfileData, |
180 | SmallVectorImpl<uint32_t> &Weights) { |
181 | if (!isBranchWeightMD(ProfileData)) |
182 | return false; |
183 | extractFromBranchWeightMD(ProfileData, Weights); |
184 | return true; |
185 | } |
186 | |
187 | bool (const Instruction &I, |
188 | SmallVectorImpl<uint32_t> &Weights) { |
189 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
190 | return extractBranchWeights(ProfileData, Weights); |
191 | } |
192 | |
193 | bool (const Instruction &I, uint64_t &TrueVal, |
194 | uint64_t &FalseVal) { |
195 | assert((I.getOpcode() == Instruction::Br || |
196 | I.getOpcode() == Instruction::Select) && |
197 | "Looking for branch weights on something besides branch, select, or " |
198 | "switch" ); |
199 | |
200 | SmallVector<uint32_t, 2> Weights; |
201 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
202 | if (!extractBranchWeights(ProfileData, Weights)) |
203 | return false; |
204 | |
205 | if (Weights.size() > 2) |
206 | return false; |
207 | |
208 | TrueVal = Weights[0]; |
209 | FalseVal = Weights[1]; |
210 | return true; |
211 | } |
212 | |
213 | bool (const MDNode *ProfileData, uint64_t &TotalVal) { |
214 | TotalVal = 0; |
215 | if (!ProfileData) |
216 | return false; |
217 | |
218 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0)); |
219 | if (!ProfDataName) |
220 | return false; |
221 | |
222 | if (ProfDataName->getString() == MDProfLabels::BranchWeights) { |
223 | unsigned Offset = getBranchWeightOffset(ProfileData); |
224 | for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { |
225 | auto *V = mdconst::extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx)); |
226 | TotalVal += V->getValue().getZExtValue(); |
227 | } |
228 | return true; |
229 | } |
230 | |
231 | if (ProfDataName->getString() == MDProfLabels::ValueProfile && |
232 | ProfileData->getNumOperands() > 3) { |
233 | TotalVal = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 2)) |
234 | ->getValue() |
235 | .getZExtValue(); |
236 | return true; |
237 | } |
238 | return false; |
239 | } |
240 | |
241 | bool (const Instruction &I, uint64_t &TotalVal) { |
242 | return extractProfTotalWeight(ProfileData: I.getMetadata(KindID: LLVMContext::MD_prof), TotalVal); |
243 | } |
244 | |
245 | void setExplicitlyUnknownBranchWeights(Instruction &I) { |
246 | MDBuilder MDB(I.getContext()); |
247 | I.setMetadata( |
248 | KindID: LLVMContext::MD_prof, |
249 | Node: MDNode::get(Context&: I.getContext(), |
250 | MDs: MDB.createString(Str: MDProfLabels::UnknownBranchWeightsMarker))); |
251 | } |
252 | |
253 | bool isExplicitlyUnknownBranchWeightsMetadata(const MDNode &MD) { |
254 | if (MD.getNumOperands() != 1) |
255 | return false; |
256 | return MD.getOperand(I: 0).equalsStr(Str: MDProfLabels::UnknownBranchWeightsMarker); |
257 | } |
258 | |
259 | bool hasExplicitlyUnknownBranchWeights(const Instruction &I) { |
260 | auto *MD = I.getMetadata(KindID: LLVMContext::MD_prof); |
261 | if (!MD) |
262 | return false; |
263 | return isExplicitlyUnknownBranchWeightsMetadata(MD: *MD); |
264 | } |
265 | |
266 | void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, |
267 | bool IsExpected) { |
268 | MDBuilder MDB(I.getContext()); |
269 | MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); |
270 | I.setMetadata(KindID: LLVMContext::MD_prof, Node: BranchWeights); |
271 | } |
272 | |
273 | void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { |
274 | assert(T != 0 && "Caller should guarantee" ); |
275 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
276 | if (ProfileData == nullptr) |
277 | return; |
278 | |
279 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0)); |
280 | if (!ProfDataName || |
281 | (ProfDataName->getString() != MDProfLabels::BranchWeights && |
282 | ProfDataName->getString() != MDProfLabels::ValueProfile)) |
283 | return; |
284 | |
285 | if (!hasCountTypeMD(I)) |
286 | return; |
287 | |
288 | LLVMContext &C = I.getContext(); |
289 | |
290 | MDBuilder MDB(C); |
291 | SmallVector<Metadata *, 3> Vals; |
292 | Vals.push_back(Elt: ProfileData->getOperand(I: 0)); |
293 | APInt APS(128, S), APT(128, T); |
294 | if (ProfDataName->getString() == MDProfLabels::BranchWeights && |
295 | ProfileData->getNumOperands() > 0) { |
296 | // Using APInt::div may be expensive, but most cases should fit 64 bits. |
297 | APInt Val(128, |
298 | mdconst::dyn_extract<ConstantInt>( |
299 | MD: ProfileData->getOperand(I: getBranchWeightOffset(ProfileData))) |
300 | ->getValue() |
301 | .getZExtValue()); |
302 | Val *= APS; |
303 | Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get( |
304 | Ty: Type::getInt32Ty(C), V: Val.udiv(RHS: APT).getLimitedValue(UINT32_MAX)))); |
305 | } else if (ProfDataName->getString() == MDProfLabels::ValueProfile) |
306 | for (unsigned Idx = 1; Idx < ProfileData->getNumOperands(); Idx += 2) { |
307 | // The first value is the key of the value profile, which will not change. |
308 | Vals.push_back(Elt: ProfileData->getOperand(I: Idx)); |
309 | uint64_t Count = |
310 | mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx + 1)) |
311 | ->getValue() |
312 | .getZExtValue(); |
313 | // Don't scale the magic number. |
314 | if (Count == NOMORE_ICP_MAGICNUM) { |
315 | Vals.push_back(Elt: ProfileData->getOperand(I: Idx + 1)); |
316 | continue; |
317 | } |
318 | // Using APInt::div may be expensive, but most cases should fit 64 bits. |
319 | APInt Val(128, Count); |
320 | Val *= APS; |
321 | Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get( |
322 | Ty: Type::getInt64Ty(C), V: Val.udiv(RHS: APT).getLimitedValue()))); |
323 | } |
324 | I.setMetadata(KindID: LLVMContext::MD_prof, Node: MDNode::get(Context&: C, MDs: Vals)); |
325 | } |
326 | |
327 | } // namespace llvm |
328 | |