1 | //===- BranchProbability.h - Branch Probability Wrapper ---------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Definition of BranchProbability shared by IR and Machine Instructions. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef LLVM_SUPPORT_BRANCHPROBABILITY_H |
14 | #define LLVM_SUPPORT_BRANCHPROBABILITY_H |
15 | |
16 | #include "llvm/Support/DataTypes.h" |
17 | #include <algorithm> |
18 | #include <cassert> |
19 | #include <iterator> |
20 | #include <numeric> |
21 | |
22 | namespace llvm { |
23 | |
class raw_ostream;

// This class represents Branch Probability as a non-negative fraction that is
// no greater than 1. It uses a fixed-point-like implementation, in which the
// denominator is always a constant value (here we use 1<<31 for maximum
// precision). The numerator UINT32_MAX is reserved as a sentinel meaning
// "unknown"; unknown probabilities must not take part in arithmetic or
// ordering comparisons (the operators below assert this).
class BranchProbability {
  // Numerator
  uint32_t N;

  // Denominator, which is a constant value.
  static constexpr uint32_t D = 1u << 31;
  // Sentinel numerator marking an unknown probability.
  static constexpr uint32_t UnknownN = UINT32_MAX;

  // Construct a BranchProbability with only numerator assuming the denominator
  // is 1<<31. For internal use only.
  explicit BranchProbability(uint32_t n) : N(n) {}

public:
  // A default-constructed probability is unknown.
  BranchProbability() : N(UnknownN) {}
  BranchProbability(uint32_t Numerator, uint32_t Denominator);

  bool isZero() const { return N == 0; }
  bool isUnknown() const { return N == UnknownN; }

  static BranchProbability getZero() { return BranchProbability(0); }
  static BranchProbability getOne() { return BranchProbability(D); }
  static BranchProbability getUnknown() { return BranchProbability(UnknownN); }
  // Create a BranchProbability object with the given numerator and 1<<31
  // as denominator.
  static BranchProbability getRaw(uint32_t N) { return BranchProbability(N); }
  // Create a BranchProbability object from 64-bit integers.
  static BranchProbability getBranchProbability(uint64_t Numerator,
                                                uint64_t Denominator);

  // Normalize given probabilities so that the sum of them becomes
  // approximately one.
  template <class ProbabilityIter>
  static void normalizeProbabilities(ProbabilityIter Begin,
                                     ProbabilityIter End);

  uint32_t getNumerator() const { return N; }
  static uint32_t getDenominator() { return D; }

  // Return (1 - Probability). Only meaningful for a known probability:
  // for the unknown sentinel, D - UnknownN wraps around.
  BranchProbability getCompl() const { return BranchProbability(D - N); }

  raw_ostream &print(raw_ostream &OS) const;

  void dump() const;

  /// Scale a large integer.
  ///
  /// Scales \c Num. Guarantees full precision. Returns the floor of the
  /// result.
  ///
  /// \return \c Num times \c this.
  uint64_t scale(uint64_t Num) const;

  /// Scale a large integer by the inverse.
  ///
  /// Scales \c Num by the inverse of \c this. Guarantees full precision.
  /// Returns the floor of the result.
  ///
  /// \return \c Num divided by \c this.
  uint64_t scaleByInverse(uint64_t Num) const;

  BranchProbability &operator+=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of overflow.
    N = (uint64_t(N) + RHS.N > D) ? D : N + RHS.N;
    return *this;
  }

  BranchProbability &operator-=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate the result in case of underflow.
    N = N < RHS.N ? 0 : N - RHS.N;
    return *this;
  }

  BranchProbability &operator*=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Fixed-point product, rounded to nearest; never exceeds D for valid
    // inputs.
    N = (static_cast<uint64_t>(N) * RHS.N + D / 2) / D;
    return *this;
  }

  BranchProbability &operator*=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    // Saturate at one; N * RHS is only evaluated when it fits below D.
    N = (uint64_t(N) * RHS > D) ? D : N * RHS;
    return *this;
  }

  BranchProbability &operator/=(BranchProbability RHS) {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS.N != 0 && "The divider cannot be zero.");
    // Fixed-point quotient, rounded to nearest. A ratio above one cannot be
    // represented: without saturation the 64-bit quotient would be truncated
    // to 32 bits (e.g. 1/2 / 1/4 == 2^32 would wrap to zero) or could land on
    // the UnknownN sentinel. Saturate at one, like operator+= does.
    uint64_t Quotient = (static_cast<uint64_t>(N) * D + RHS.N / 2) / RHS.N;
    N = Quotient > D ? D : static_cast<uint32_t>(Quotient);
    return *this;
  }

  BranchProbability &operator/=(uint32_t RHS) {
    assert(N != UnknownN &&
           "Unknown probability cannot participate in arithmetics.");
    assert(RHS > 0 && "The divider cannot be zero.");
    N /= RHS;
    return *this;
  }

  BranchProbability operator+(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob += RHS;
    return Prob;
  }

  BranchProbability operator-(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob -= RHS;
    return Prob;
  }

  BranchProbability operator*(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator*(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob *= RHS;
    return Prob;
  }

  BranchProbability operator/(BranchProbability RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  BranchProbability operator/(uint32_t RHS) const {
    BranchProbability Prob(*this);
    Prob /= RHS;
    return Prob;
  }

  // Equality compares raw numerators, so unknown == unknown is allowed.
  bool operator==(BranchProbability RHS) const { return N == RHS.N; }
  bool operator!=(BranchProbability RHS) const { return !(*this == RHS); }

  bool operator<(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return N < RHS.N;
  }

  bool operator>(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return RHS < *this;
  }

  bool operator<=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(RHS < *this);
  }

  bool operator>=(BranchProbability RHS) const {
    assert(N != UnknownN && RHS.N != UnknownN &&
           "Unknown probability cannot participate in comparisons.");
    return !(*this < RHS);
  }
};
199 | |
200 | inline raw_ostream &operator<<(raw_ostream &OS, BranchProbability Prob) { |
201 | return Prob.print(OS); |
202 | } |
203 | |
204 | template <class ProbabilityIter> |
205 | void BranchProbability::normalizeProbabilities(ProbabilityIter Begin, |
206 | ProbabilityIter End) { |
207 | if (Begin == End) |
208 | return; |
209 | |
210 | unsigned UnknownProbCount = 0; |
211 | uint64_t Sum = std::accumulate(Begin, End, uint64_t(0), |
212 | [&](uint64_t S, const BranchProbability &BP) { |
213 | if (!BP.isUnknown()) |
214 | return S + BP.N; |
215 | UnknownProbCount++; |
216 | return S; |
217 | }); |
218 | |
219 | if (UnknownProbCount > 0) { |
220 | BranchProbability ProbForUnknown = BranchProbability::getZero(); |
221 | // If the sum of all known probabilities is less than one, evenly distribute |
222 | // the complement of sum to unknown probabilities. Otherwise, set unknown |
223 | // probabilities to zeros and continue to normalize known probabilities. |
224 | if (Sum < BranchProbability::getDenominator()) |
225 | ProbForUnknown = BranchProbability::getRaw( |
226 | N: (BranchProbability::getDenominator() - Sum) / UnknownProbCount); |
227 | |
228 | std::replace_if(Begin, End, |
229 | [](const BranchProbability &BP) { return BP.isUnknown(); }, |
230 | ProbForUnknown); |
231 | |
232 | if (Sum <= BranchProbability::getDenominator()) |
233 | return; |
234 | } |
235 | |
236 | if (Sum == 0) { |
237 | BranchProbability BP(1, std::distance(Begin, End)); |
238 | std::fill(Begin, End, BP); |
239 | return; |
240 | } |
241 | |
242 | for (auto I = Begin; I != End; ++I) |
243 | I->N = (I->N * uint64_t(D) + Sum / 2) / Sum; |
244 | } |
245 | |
246 | } |
247 | |
248 | #endif |
249 | |