1 | //===- ProfDataUtils.cpp - Utility functions for MD_prof Metadata ---------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements utilities for working with Profiling Metadata. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "llvm/IR/ProfDataUtils.h" |
14 | #include "llvm/ADT/SmallVector.h" |
15 | #include "llvm/ADT/Twine.h" |
16 | #include "llvm/IR/Constants.h" |
17 | #include "llvm/IR/Function.h" |
18 | #include "llvm/IR/Instructions.h" |
19 | #include "llvm/IR/LLVMContext.h" |
20 | #include "llvm/IR/MDBuilder.h" |
21 | #include "llvm/IR/Metadata.h" |
22 | #include "llvm/IR/ProfDataUtils.h" |
23 | #include "llvm/Support/BranchProbability.h" |
24 | #include "llvm/Support/CommandLine.h" |
25 | |
26 | using namespace llvm; |
27 | |
28 | namespace { |
29 | |
// MD_prof nodes have the following layout
//
// In general:
// { String name, Array of i32 }
//
// In terms of Types:
// { MDString, [i32, i32, ...]}
//
// Concretely for Branch Weights
// { "branch_weights", [i32 1, i32 10000]}
//
// A branch_weights node may additionally carry a provenance string (currently
// only "expected") as operand 1, in which case the weights start at operand 2:
// { "branch_weights", "expected", [i32 1, i32 10000]}
//
// We maintain some constants here to ensure that we access the branch weights
// correctly, and can change the behavior in the future if the layout changes.
43 | |
// Minimum operand count (name operand included) an MD_prof node must have to
// be treated as branch-weight metadata.
constexpr unsigned MinBWOps{3};

// Minimum operand count (name operand included) an MD_prof node must have to
// be treated as value-profile metadata.
constexpr unsigned MinVPOps{5};
49 | |
50 | // We may want to add support for other MD_prof types, so provide an abstraction |
51 | // for checking the metadata type. |
52 | bool isTargetMD(const MDNode *ProfData, const char *Name, unsigned MinOps) { |
53 | // TODO: This routine may be simplified if MD_prof used an enum instead of a |
54 | // string to differentiate the types of MD_prof nodes. |
55 | if (!ProfData || !Name || MinOps < 2) |
56 | return false; |
57 | |
58 | unsigned NOps = ProfData->getNumOperands(); |
59 | if (NOps < MinOps) |
60 | return false; |
61 | |
62 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfData->getOperand(I: 0)); |
63 | if (!ProfDataName) |
64 | return false; |
65 | |
66 | return ProfDataName->getString() == Name; |
67 | } |
68 | |
69 | template <typename T, |
70 | typename = typename std::enable_if<std::is_arithmetic_v<T>>> |
71 | static void (const MDNode *ProfileData, |
72 | SmallVectorImpl<T> &Weights) { |
73 | assert(isBranchWeightMD(ProfileData) && "wrong metadata" ); |
74 | |
75 | unsigned NOps = ProfileData->getNumOperands(); |
76 | unsigned WeightsIdx = getBranchWeightOffset(ProfileData); |
77 | assert(WeightsIdx < NOps && "Weights Index must be less than NOps." ); |
78 | Weights.resize(NOps - WeightsIdx); |
79 | |
80 | for (unsigned Idx = WeightsIdx, E = NOps; Idx != E; ++Idx) { |
81 | ConstantInt *Weight = |
82 | mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx)); |
83 | assert(Weight && "Malformed branch_weight in MD_prof node" ); |
84 | assert(Weight->getValue().getActiveBits() <= (sizeof(T) * 8) && |
85 | "Too many bits for MD_prof branch_weight" ); |
86 | Weights[Idx - WeightsIdx] = Weight->getZExtValue(); |
87 | } |
88 | } |
89 | |
90 | } // namespace |
91 | |
92 | namespace llvm { |
93 | |
94 | bool hasProfMD(const Instruction &I) { |
95 | return I.hasMetadata(KindID: LLVMContext::MD_prof); |
96 | } |
97 | |
98 | bool isBranchWeightMD(const MDNode *ProfileData) { |
99 | return isTargetMD(ProfData: ProfileData, Name: "branch_weights" , MinOps: MinBWOps); |
100 | } |
101 | |
102 | bool isValueProfileMD(const MDNode *ProfileData) { |
103 | return isTargetMD(ProfData: ProfileData, Name: "VP" , MinOps: MinVPOps); |
104 | } |
105 | |
106 | bool hasBranchWeightMD(const Instruction &I) { |
107 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
108 | return isBranchWeightMD(ProfileData); |
109 | } |
110 | |
111 | bool hasCountTypeMD(const Instruction &I) { |
112 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
113 | // Value profiles record count-type information. |
114 | if (isValueProfileMD(ProfileData)) |
115 | return true; |
116 | // Conservatively assume non CallBase instruction only get taken/not-taken |
117 | // branch probability, so not interpret them as count. |
118 | return isa<CallBase>(Val: I) && !isBranchWeightMD(ProfileData); |
119 | } |
120 | |
121 | bool hasValidBranchWeightMD(const Instruction &I) { |
122 | return getValidBranchWeightMDNode(I); |
123 | } |
124 | |
125 | bool hasBranchWeightOrigin(const Instruction &I) { |
126 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
127 | return hasBranchWeightOrigin(ProfileData); |
128 | } |
129 | |
130 | bool hasBranchWeightOrigin(const MDNode *ProfileData) { |
131 | if (!isBranchWeightMD(ProfileData)) |
132 | return false; |
133 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 1)); |
134 | // NOTE: if we ever have more types of branch weight provenance, |
135 | // we need to check the string value is "expected". For now, we |
136 | // supply a more generic API, and avoid the spurious comparisons. |
137 | assert(ProfDataName == nullptr || ProfDataName->getString() == "expected" ); |
138 | return ProfDataName != nullptr; |
139 | } |
140 | |
141 | unsigned getBranchWeightOffset(const MDNode *ProfileData) { |
142 | return hasBranchWeightOrigin(ProfileData) ? 2 : 1; |
143 | } |
144 | |
145 | unsigned getNumBranchWeights(const MDNode &ProfileData) { |
146 | return ProfileData.getNumOperands() - getBranchWeightOffset(ProfileData: &ProfileData); |
147 | } |
148 | |
149 | MDNode *getBranchWeightMDNode(const Instruction &I) { |
150 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
151 | if (!isBranchWeightMD(ProfileData)) |
152 | return nullptr; |
153 | return ProfileData; |
154 | } |
155 | |
156 | MDNode *getValidBranchWeightMDNode(const Instruction &I) { |
157 | auto *ProfileData = getBranchWeightMDNode(I); |
158 | if (ProfileData && getNumBranchWeights(ProfileData: *ProfileData) == I.getNumSuccessors()) |
159 | return ProfileData; |
160 | return nullptr; |
161 | } |
162 | |
163 | void (const MDNode *ProfileData, |
164 | SmallVectorImpl<uint32_t> &Weights) { |
165 | extractFromBranchWeightMD(ProfileData, Weights); |
166 | } |
167 | |
168 | void (const MDNode *ProfileData, |
169 | SmallVectorImpl<uint64_t> &Weights) { |
170 | extractFromBranchWeightMD(ProfileData, Weights); |
171 | } |
172 | |
173 | bool (const MDNode *ProfileData, |
174 | SmallVectorImpl<uint32_t> &Weights) { |
175 | if (!isBranchWeightMD(ProfileData)) |
176 | return false; |
177 | extractFromBranchWeightMD(ProfileData, Weights); |
178 | return true; |
179 | } |
180 | |
181 | bool (const Instruction &I, |
182 | SmallVectorImpl<uint32_t> &Weights) { |
183 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
184 | return extractBranchWeights(ProfileData, Weights); |
185 | } |
186 | |
187 | bool (const Instruction &I, uint64_t &TrueVal, |
188 | uint64_t &FalseVal) { |
189 | assert((I.getOpcode() == Instruction::Br || |
190 | I.getOpcode() == Instruction::Select) && |
191 | "Looking for branch weights on something besides branch, select, or " |
192 | "switch" ); |
193 | |
194 | SmallVector<uint32_t, 2> Weights; |
195 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
196 | if (!extractBranchWeights(ProfileData, Weights)) |
197 | return false; |
198 | |
199 | if (Weights.size() > 2) |
200 | return false; |
201 | |
202 | TrueVal = Weights[0]; |
203 | FalseVal = Weights[1]; |
204 | return true; |
205 | } |
206 | |
207 | bool (const MDNode *ProfileData, uint64_t &TotalVal) { |
208 | TotalVal = 0; |
209 | if (!ProfileData) |
210 | return false; |
211 | |
212 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0)); |
213 | if (!ProfDataName) |
214 | return false; |
215 | |
216 | if (ProfDataName->getString() == "branch_weights" ) { |
217 | unsigned Offset = getBranchWeightOffset(ProfileData); |
218 | for (unsigned Idx = Offset; Idx < ProfileData->getNumOperands(); ++Idx) { |
219 | auto *V = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: Idx)); |
220 | assert(V && "Malformed branch_weight in MD_prof node" ); |
221 | TotalVal += V->getValue().getZExtValue(); |
222 | } |
223 | return true; |
224 | } |
225 | |
226 | if (ProfDataName->getString() == "VP" && ProfileData->getNumOperands() > 3) { |
227 | TotalVal = mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: 2)) |
228 | ->getValue() |
229 | .getZExtValue(); |
230 | return true; |
231 | } |
232 | return false; |
233 | } |
234 | |
235 | bool (const Instruction &I, uint64_t &TotalVal) { |
236 | return extractProfTotalWeight(ProfileData: I.getMetadata(KindID: LLVMContext::MD_prof), TotalVal); |
237 | } |
238 | |
239 | void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights, |
240 | bool IsExpected) { |
241 | MDBuilder MDB(I.getContext()); |
242 | MDNode *BranchWeights = MDB.createBranchWeights(Weights, IsExpected); |
243 | I.setMetadata(KindID: LLVMContext::MD_prof, Node: BranchWeights); |
244 | } |
245 | |
246 | void scaleProfData(Instruction &I, uint64_t S, uint64_t T) { |
247 | assert(T != 0 && "Caller should guarantee" ); |
248 | auto *ProfileData = I.getMetadata(KindID: LLVMContext::MD_prof); |
249 | if (ProfileData == nullptr) |
250 | return; |
251 | |
252 | auto *ProfDataName = dyn_cast<MDString>(Val: ProfileData->getOperand(I: 0)); |
253 | if (!ProfDataName || (ProfDataName->getString() != "branch_weights" && |
254 | ProfDataName->getString() != "VP" )) |
255 | return; |
256 | |
257 | if (!hasCountTypeMD(I)) |
258 | return; |
259 | |
260 | LLVMContext &C = I.getContext(); |
261 | |
262 | MDBuilder MDB(C); |
263 | SmallVector<Metadata *, 3> Vals; |
264 | Vals.push_back(Elt: ProfileData->getOperand(I: 0)); |
265 | APInt APS(128, S), APT(128, T); |
266 | if (ProfDataName->getString() == "branch_weights" && |
267 | ProfileData->getNumOperands() > 0) { |
268 | // Using APInt::div may be expensive, but most cases should fit 64 bits. |
269 | APInt Val(128, |
270 | mdconst::dyn_extract<ConstantInt>( |
271 | MD: ProfileData->getOperand(I: getBranchWeightOffset(ProfileData))) |
272 | ->getValue() |
273 | .getZExtValue()); |
274 | Val *= APS; |
275 | Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get( |
276 | Ty: Type::getInt32Ty(C), V: Val.udiv(RHS: APT).getLimitedValue(UINT32_MAX)))); |
277 | } else if (ProfDataName->getString() == "VP" ) |
278 | for (unsigned i = 1; i < ProfileData->getNumOperands(); i += 2) { |
279 | // The first value is the key of the value profile, which will not change. |
280 | Vals.push_back(Elt: ProfileData->getOperand(I: i)); |
281 | uint64_t Count = |
282 | mdconst::dyn_extract<ConstantInt>(MD: ProfileData->getOperand(I: i + 1)) |
283 | ->getValue() |
284 | .getZExtValue(); |
285 | // Don't scale the magic number. |
286 | if (Count == NOMORE_ICP_MAGICNUM) { |
287 | Vals.push_back(Elt: ProfileData->getOperand(I: i + 1)); |
288 | continue; |
289 | } |
290 | // Using APInt::div may be expensive, but most cases should fit 64 bits. |
291 | APInt Val(128, Count); |
292 | Val *= APS; |
293 | Vals.push_back(Elt: MDB.createConstant(C: ConstantInt::get( |
294 | Ty: Type::getInt64Ty(C), V: Val.udiv(RHS: APT).getLimitedValue()))); |
295 | } |
296 | I.setMetadata(KindID: LLVMContext::MD_prof, Node: MDNode::get(Context&: C, MDs: Vals)); |
297 | } |
298 | |
299 | } // namespace llvm |
300 | |