1 | //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains a class for representing known zeros and ones used by |
10 | // computeKnownBits. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Support/KnownBits.h" |
15 | #include "llvm/Support/Debug.h" |
16 | #include "llvm/Support/raw_ostream.h" |
17 | #include <cassert> |
18 | |
19 | using namespace llvm; |
20 | |
21 | static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS, |
22 | bool CarryZero, bool CarryOne) { |
23 | |
24 | APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero; |
25 | APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne; |
26 | |
27 | // Compute known bits of the carry. |
28 | APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); |
29 | APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One; |
30 | |
31 | // Compute set of known bits (where all three relevant bits are known). |
32 | APInt LHSKnownUnion = LHS.Zero | LHS.One; |
33 | APInt RHSKnownUnion = RHS.Zero | RHS.One; |
34 | APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne; |
35 | APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion; |
36 | |
37 | // Compute known bits of the result. |
38 | KnownBits KnownOut; |
39 | KnownOut.Zero = ~std::move(PossibleSumZero) & Known; |
40 | KnownOut.One = std::move(PossibleSumOne) & Known; |
41 | return KnownOut; |
42 | } |
43 | |
44 | KnownBits KnownBits::computeForAddCarry( |
45 | const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) { |
46 | assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit" ); |
47 | return ::computeForAddCarry( |
48 | LHS, RHS, CarryZero: Carry.Zero.getBoolValue(), CarryOne: Carry.One.getBoolValue()); |
49 | } |
50 | |
51 | KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW, |
52 | const KnownBits &LHS, |
53 | const KnownBits &RHS) { |
54 | unsigned BitWidth = LHS.getBitWidth(); |
55 | KnownBits KnownOut(BitWidth); |
56 | // This can be a relatively expensive helper, so optimistically save some |
57 | // work. |
58 | if (LHS.isUnknown() && RHS.isUnknown()) |
59 | return KnownOut; |
60 | |
61 | if (!LHS.isUnknown() && !RHS.isUnknown()) { |
62 | if (Add) { |
63 | // Sum = LHS + RHS + 0 |
64 | KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true, |
65 | /*CarryOne=*/false); |
66 | } else { |
67 | // Sum = LHS + ~RHS + 1 |
68 | KnownBits NotRHS = RHS; |
69 | std::swap(a&: NotRHS.Zero, b&: NotRHS.One); |
70 | KnownOut = ::computeForAddCarry(LHS, RHS: NotRHS, /*CarryZero=*/false, |
71 | /*CarryOne=*/true); |
72 | } |
73 | } |
74 | |
75 | // Handle add/sub given nsw and/or nuw. |
76 | if (NUW) { |
77 | if (Add) { |
78 | // (add nuw X, Y) |
79 | APInt MinVal = LHS.getMinValue().uadd_sat(RHS: RHS.getMinValue()); |
80 | // None of the adds can end up overflowing, so min consecutive highbits |
81 | // in minimum possible of X + Y must all remain set. |
82 | if (NSW) { |
83 | unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one(); |
84 | // If we have NSW as well, we also know we can't overflow the signbit so |
85 | // can start counting from 1 bit back. |
86 | KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
87 | } |
88 | KnownOut.One.setHighBits(MinVal.countl_one()); |
89 | } else { |
90 | // (sub nuw X, Y) |
91 | APInt MaxVal = LHS.getMaxValue().usub_sat(RHS: RHS.getMinValue()); |
92 | // None of the subs can overflow at any point, so any common high bits |
93 | // will subtract away and result in zeros. |
94 | if (NSW) { |
95 | // If we have NSW as well, we also know we can't overflow the signbit so |
96 | // can start counting from 1 bit back. |
97 | unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero(); |
98 | KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
99 | } |
100 | KnownOut.Zero.setHighBits(MaxVal.countl_zero()); |
101 | } |
102 | } |
103 | |
104 | if (NSW) { |
105 | APInt MinVal; |
106 | APInt MaxVal; |
107 | if (Add) { |
108 | // (add nsw X, Y) |
109 | MinVal = LHS.getSignedMinValue().sadd_sat(RHS: RHS.getSignedMinValue()); |
110 | MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS: RHS.getSignedMaxValue()); |
111 | } else { |
112 | // (sub nsw X, Y) |
113 | MinVal = LHS.getSignedMinValue().ssub_sat(RHS: RHS.getSignedMaxValue()); |
114 | MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS: RHS.getSignedMinValue()); |
115 | } |
116 | if (MinVal.isNonNegative()) { |
117 | // If min is non-negative, result will always be non-neg (can't overflow |
118 | // around). |
119 | unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one(); |
120 | KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
121 | KnownOut.Zero.setSignBit(); |
122 | } |
123 | if (MaxVal.isNegative()) { |
124 | // If max is negative, result will always be neg (can't overflow around). |
125 | unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero(); |
126 | KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
127 | KnownOut.One.setSignBit(); |
128 | } |
129 | } |
130 | |
131 | // Just return 0 if the nsw/nuw is violated and we have poison. |
132 | if (KnownOut.hasConflict()) |
133 | KnownOut.setAllZero(); |
134 | return KnownOut; |
135 | } |
136 | |
137 | KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS, |
138 | const KnownBits &Borrow) { |
139 | assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit" ); |
140 | |
141 | // LHS - RHS = LHS + ~RHS + 1 |
142 | // Carry 1 - Borrow in ::computeForAddCarry |
143 | std::swap(a&: RHS.Zero, b&: RHS.One); |
144 | return ::computeForAddCarry(LHS, RHS, |
145 | /*CarryZero=*/Borrow.One.getBoolValue(), |
146 | /*CarryOne=*/Borrow.Zero.getBoolValue()); |
147 | } |
148 | |
149 | KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const { |
150 | unsigned BitWidth = getBitWidth(); |
151 | assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth && |
152 | "Illegal sext-in-register" ); |
153 | |
154 | if (SrcBitWidth == BitWidth) |
155 | return *this; |
156 | |
157 | unsigned ExtBits = BitWidth - SrcBitWidth; |
158 | KnownBits Result; |
159 | Result.One = One << ExtBits; |
160 | Result.Zero = Zero << ExtBits; |
161 | Result.One.ashrInPlace(ShiftAmt: ExtBits); |
162 | Result.Zero.ashrInPlace(ShiftAmt: ExtBits); |
163 | return Result; |
164 | } |
165 | |
166 | KnownBits KnownBits::makeGE(const APInt &Val) const { |
167 | // Count the number of leading bit positions where our underlying value is |
168 | // known to be less than or equal to Val. |
169 | unsigned N = (Zero | Val).countl_one(); |
170 | |
171 | // For each of those bit positions, if Val has a 1 in that bit then our |
172 | // underlying value must also have a 1. |
173 | APInt MaskedVal(Val); |
174 | MaskedVal.clearLowBits(loBits: getBitWidth() - N); |
175 | return KnownBits(Zero, One | MaskedVal); |
176 | } |
177 | |
178 | KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) { |
179 | // If we can prove that LHS >= RHS then use LHS as the result. Likewise for |
180 | // RHS. Ideally our caller would already have spotted these cases and |
181 | // optimized away the umax operation, but we handle them here for |
182 | // completeness. |
183 | if (LHS.getMinValue().uge(RHS: RHS.getMaxValue())) |
184 | return LHS; |
185 | if (RHS.getMinValue().uge(RHS: LHS.getMaxValue())) |
186 | return RHS; |
187 | |
188 | // If the result of the umax is LHS then it must be greater than or equal to |
189 | // the minimum possible value of RHS. Likewise for RHS. Any known bits that |
190 | // are common to these two values are also known in the result. |
191 | KnownBits L = LHS.makeGE(Val: RHS.getMinValue()); |
192 | KnownBits R = RHS.makeGE(Val: LHS.getMinValue()); |
193 | return L.intersectWith(RHS: R); |
194 | } |
195 | |
196 | KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { |
197 | // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] |
198 | auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); }; |
199 | return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS))); |
200 | } |
201 | |
202 | KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { |
203 | // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0, 0xFFFFFFFF] |
204 | auto Flip = [](const KnownBits &Val) { |
205 | unsigned SignBitPosition = Val.getBitWidth() - 1; |
206 | APInt Zero = Val.Zero; |
207 | APInt One = Val.One; |
208 | Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]); |
209 | One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]); |
210 | return KnownBits(Zero, One); |
211 | }; |
212 | return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS))); |
213 | } |
214 | |
215 | KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { |
216 | // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] |
217 | auto Flip = [](const KnownBits &Val) { |
218 | unsigned SignBitPosition = Val.getBitWidth() - 1; |
219 | APInt Zero = Val.One; |
220 | APInt One = Val.Zero; |
221 | Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]); |
222 | One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]); |
223 | return KnownBits(Zero, One); |
224 | }; |
225 | return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS))); |
226 | } |
227 | |
228 | KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) { |
229 | // If we know which argument is larger, return (sub LHS, RHS) or |
230 | // (sub RHS, LHS) directly. |
231 | if (LHS.getMinValue().uge(RHS: RHS.getMaxValue())) |
232 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, |
233 | RHS); |
234 | if (RHS.getMinValue().uge(RHS: LHS.getMaxValue())) |
235 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS, |
236 | RHS: LHS); |
237 | |
238 | // By construction, the subtraction in abdu never has unsigned overflow. |
239 | // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS). |
240 | KnownBits Diff0 = |
241 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS); |
242 | KnownBits Diff1 = |
243 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS); |
244 | return Diff0.intersectWith(RHS: Diff1); |
245 | } |
246 | |
247 | KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) { |
248 | // If we know which argument is larger, return (sub LHS, RHS) or |
249 | // (sub RHS, LHS) directly. |
250 | if (LHS.getSignedMinValue().sge(RHS: RHS.getSignedMaxValue())) |
251 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, |
252 | RHS); |
253 | if (RHS.getSignedMinValue().sge(RHS: LHS.getSignedMaxValue())) |
254 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS, |
255 | RHS: LHS); |
256 | |
257 | // Shift both arguments from the signed range to the unsigned range, e.g. from |
258 | // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like |
259 | // abdu does. |
260 | // Note that we can't just use "sub nsw" instead because abds has signed |
261 | // inputs but an unsigned result, which makes the overflow conditions |
262 | // different. |
263 | unsigned SignBitPosition = LHS.getBitWidth() - 1; |
264 | for (auto Arg : {&LHS, &RHS}) { |
265 | bool Tmp = Arg->Zero[SignBitPosition]; |
266 | Arg->Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Arg->One[SignBitPosition]); |
267 | Arg->One.setBitVal(BitPosition: SignBitPosition, BitValue: Tmp); |
268 | } |
269 | |
270 | // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS). |
271 | KnownBits Diff0 = |
272 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS); |
273 | KnownBits Diff1 = |
274 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS); |
275 | return Diff0.intersectWith(RHS: Diff1); |
276 | } |
277 | |
278 | static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) { |
279 | if (isPowerOf2_32(Value: BitWidth)) |
280 | return MaxValue.extractBitsAsZExtValue(numBits: Log2_32(Value: BitWidth), bitPosition: 0); |
281 | // This is only an approximate upper bound. |
282 | return MaxValue.getLimitedValue(Limit: BitWidth - 1); |
283 | } |
284 | |
285 | KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW, |
286 | bool NSW, bool ShAmtNonZero) { |
287 | unsigned BitWidth = LHS.getBitWidth(); |
288 | auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { |
289 | KnownBits Known; |
290 | bool ShiftedOutZero, ShiftedOutOne; |
291 | Known.Zero = LHS.Zero.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutZero); |
292 | Known.Zero.setLowBits(ShiftAmt); |
293 | Known.One = LHS.One.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutOne); |
294 | |
295 | // All cases returning poison have been handled by MaxShiftAmount already. |
296 | if (NSW) { |
297 | if (NUW && ShiftAmt != 0) |
298 | // NUW means we can assume anything shifted out was a zero. |
299 | ShiftedOutZero = true; |
300 | |
301 | if (ShiftedOutZero) |
302 | Known.makeNonNegative(); |
303 | else if (ShiftedOutOne) |
304 | Known.makeNegative(); |
305 | } |
306 | return Known; |
307 | }; |
308 | |
309 | // Fast path for a common case when LHS is completely unknown. |
310 | KnownBits Known(BitWidth); |
311 | unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth); |
312 | if (MinShiftAmount == 0 && ShAmtNonZero) |
313 | MinShiftAmount = 1; |
314 | if (LHS.isUnknown()) { |
315 | Known.Zero.setLowBits(MinShiftAmount); |
316 | if (NUW && NSW && MinShiftAmount != 0) |
317 | Known.makeNonNegative(); |
318 | return Known; |
319 | } |
320 | |
321 | // Determine maximum shift amount, taking NUW/NSW flags into account. |
322 | APInt MaxValue = RHS.getMaxValue(); |
323 | unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); |
324 | if (NUW && NSW) |
325 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros() - 1); |
326 | if (NUW) |
327 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros()); |
328 | if (NSW) |
329 | MaxShiftAmount = std::min( |
330 | a: MaxShiftAmount, |
331 | b: std::max(a: LHS.countMaxLeadingZeros(), b: LHS.countMaxLeadingOnes()) - 1); |
332 | |
333 | // Fast path for common case where the shift amount is unknown. |
334 | if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 && |
335 | isPowerOf2_32(Value: BitWidth)) { |
336 | Known.Zero.setLowBits(LHS.countMinTrailingZeros()); |
337 | if (LHS.isAllOnes()) |
338 | Known.One.setSignBit(); |
339 | if (NSW) { |
340 | if (LHS.isNonNegative()) |
341 | Known.makeNonNegative(); |
342 | if (LHS.isNegative()) |
343 | Known.makeNegative(); |
344 | } |
345 | return Known; |
346 | } |
347 | |
348 | // Find the common bits from all possible shifts. |
349 | unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue(); |
350 | unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue(); |
351 | Known.Zero.setAllBits(); |
352 | Known.One.setAllBits(); |
353 | for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; |
354 | ++ShiftAmt) { |
355 | // Skip if the shift amount is impossible. |
356 | if ((ShiftAmtZeroMask & ShiftAmt) != 0 || |
357 | (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) |
358 | continue; |
359 | Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt)); |
360 | if (Known.isUnknown()) |
361 | break; |
362 | } |
363 | |
364 | // All shift amounts may result in poison. |
365 | if (Known.hasConflict()) |
366 | Known.setAllZero(); |
367 | return Known; |
368 | } |
369 | |
370 | KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS, |
371 | bool ShAmtNonZero, bool Exact) { |
372 | unsigned BitWidth = LHS.getBitWidth(); |
373 | auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { |
374 | KnownBits Known = LHS; |
375 | Known.Zero.lshrInPlace(ShiftAmt); |
376 | Known.One.lshrInPlace(ShiftAmt); |
377 | // High bits are known zero. |
378 | Known.Zero.setHighBits(ShiftAmt); |
379 | return Known; |
380 | }; |
381 | |
382 | // Fast path for a common case when LHS is completely unknown. |
383 | KnownBits Known(BitWidth); |
384 | unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth); |
385 | if (MinShiftAmount == 0 && ShAmtNonZero) |
386 | MinShiftAmount = 1; |
387 | if (LHS.isUnknown()) { |
388 | Known.Zero.setHighBits(MinShiftAmount); |
389 | return Known; |
390 | } |
391 | |
392 | // Find the common bits from all possible shifts. |
393 | APInt MaxValue = RHS.getMaxValue(); |
394 | unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); |
395 | |
396 | // If exact, bound MaxShiftAmount to first known 1 in LHS. |
397 | if (Exact) { |
398 | unsigned FirstOne = LHS.countMaxTrailingZeros(); |
399 | if (FirstOne < MinShiftAmount) { |
400 | // Always poison. Return zero because we don't like returning conflict. |
401 | Known.setAllZero(); |
402 | return Known; |
403 | } |
404 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne); |
405 | } |
406 | |
407 | unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue(); |
408 | unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue(); |
409 | Known.Zero.setAllBits(); |
410 | Known.One.setAllBits(); |
411 | for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; |
412 | ++ShiftAmt) { |
413 | // Skip if the shift amount is impossible. |
414 | if ((ShiftAmtZeroMask & ShiftAmt) != 0 || |
415 | (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) |
416 | continue; |
417 | Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt)); |
418 | if (Known.isUnknown()) |
419 | break; |
420 | } |
421 | |
422 | // All shift amounts may result in poison. |
423 | if (Known.hasConflict()) |
424 | Known.setAllZero(); |
425 | return Known; |
426 | } |
427 | |
428 | KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS, |
429 | bool ShAmtNonZero, bool Exact) { |
430 | unsigned BitWidth = LHS.getBitWidth(); |
431 | auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { |
432 | KnownBits Known = LHS; |
433 | Known.Zero.ashrInPlace(ShiftAmt); |
434 | Known.One.ashrInPlace(ShiftAmt); |
435 | return Known; |
436 | }; |
437 | |
438 | // Fast path for a common case when LHS is completely unknown. |
439 | KnownBits Known(BitWidth); |
440 | unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth); |
441 | if (MinShiftAmount == 0 && ShAmtNonZero) |
442 | MinShiftAmount = 1; |
443 | if (LHS.isUnknown()) { |
444 | if (MinShiftAmount == BitWidth) { |
445 | // Always poison. Return zero because we don't like returning conflict. |
446 | Known.setAllZero(); |
447 | return Known; |
448 | } |
449 | return Known; |
450 | } |
451 | |
452 | // Find the common bits from all possible shifts. |
453 | APInt MaxValue = RHS.getMaxValue(); |
454 | unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); |
455 | |
456 | // If exact, bound MaxShiftAmount to first known 1 in LHS. |
457 | if (Exact) { |
458 | unsigned FirstOne = LHS.countMaxTrailingZeros(); |
459 | if (FirstOne < MinShiftAmount) { |
460 | // Always poison. Return zero because we don't like returning conflict. |
461 | Known.setAllZero(); |
462 | return Known; |
463 | } |
464 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne); |
465 | } |
466 | |
467 | unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue(); |
468 | unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue(); |
469 | Known.Zero.setAllBits(); |
470 | Known.One.setAllBits(); |
471 | for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; |
472 | ++ShiftAmt) { |
473 | // Skip if the shift amount is impossible. |
474 | if ((ShiftAmtZeroMask & ShiftAmt) != 0 || |
475 | (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) |
476 | continue; |
477 | Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt)); |
478 | if (Known.isUnknown()) |
479 | break; |
480 | } |
481 | |
482 | // All shift amounts may result in poison. |
483 | if (Known.hasConflict()) |
484 | Known.setAllZero(); |
485 | return Known; |
486 | } |
487 | |
488 | std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) { |
489 | if (LHS.isConstant() && RHS.isConstant()) |
490 | return std::optional<bool>(LHS.getConstant() == RHS.getConstant()); |
491 | if (LHS.One.intersects(RHS: RHS.Zero) || RHS.One.intersects(RHS: LHS.Zero)) |
492 | return std::optional<bool>(false); |
493 | return std::nullopt; |
494 | } |
495 | |
496 | std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) { |
497 | if (std::optional<bool> KnownEQ = eq(LHS, RHS)) |
498 | return std::optional<bool>(!*KnownEQ); |
499 | return std::nullopt; |
500 | } |
501 | |
502 | std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) { |
503 | // LHS >u RHS -> false if umax(LHS) <= umax(RHS) |
504 | if (LHS.getMaxValue().ule(RHS: RHS.getMinValue())) |
505 | return std::optional<bool>(false); |
506 | // LHS >u RHS -> true if umin(LHS) > umax(RHS) |
507 | if (LHS.getMinValue().ugt(RHS: RHS.getMaxValue())) |
508 | return std::optional<bool>(true); |
509 | return std::nullopt; |
510 | } |
511 | |
512 | std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) { |
513 | if (std::optional<bool> IsUGT = ugt(LHS: RHS, RHS: LHS)) |
514 | return std::optional<bool>(!*IsUGT); |
515 | return std::nullopt; |
516 | } |
517 | |
518 | std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) { |
519 | return ugt(LHS: RHS, RHS: LHS); |
520 | } |
521 | |
522 | std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) { |
523 | return uge(LHS: RHS, RHS: LHS); |
524 | } |
525 | |
526 | std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) { |
527 | // LHS >s RHS -> false if smax(LHS) <= smax(RHS) |
528 | if (LHS.getSignedMaxValue().sle(RHS: RHS.getSignedMinValue())) |
529 | return std::optional<bool>(false); |
530 | // LHS >s RHS -> true if smin(LHS) > smax(RHS) |
531 | if (LHS.getSignedMinValue().sgt(RHS: RHS.getSignedMaxValue())) |
532 | return std::optional<bool>(true); |
533 | return std::nullopt; |
534 | } |
535 | |
536 | std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) { |
537 | if (std::optional<bool> KnownSGT = sgt(LHS: RHS, RHS: LHS)) |
538 | return std::optional<bool>(!*KnownSGT); |
539 | return std::nullopt; |
540 | } |
541 | |
542 | std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) { |
543 | return sgt(LHS: RHS, RHS: LHS); |
544 | } |
545 | |
546 | std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) { |
547 | return sge(LHS: RHS, RHS: LHS); |
548 | } |
549 | |
550 | KnownBits KnownBits::abs(bool IntMinIsPoison) const { |
551 | // If the source's MSB is zero then we know the rest of the bits already. |
552 | if (isNonNegative()) |
553 | return *this; |
554 | |
555 | // Absolute value preserves trailing zero count. |
556 | KnownBits KnownAbs(getBitWidth()); |
557 | |
558 | // If the input is negative, then abs(x) == -x. |
559 | if (isNegative()) { |
560 | KnownBits Tmp = *this; |
561 | // Special case for IntMinIsPoison. We know the sign bit is set and we know |
562 | // all the rest of the bits except one to be zero. Since we have |
563 | // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is |
564 | // INT_MIN. |
565 | if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth()) |
566 | Tmp.One.setBit(countMinTrailingZeros()); |
567 | |
568 | KnownAbs = computeForAddSub( |
569 | /*Add*/ false, NSW: IntMinIsPoison, /*NUW=*/false, |
570 | LHS: KnownBits::makeConstant(C: APInt(getBitWidth(), 0)), RHS: Tmp); |
571 | |
572 | // One more special case for IntMinIsPoison. If we don't know any ones other |
573 | // than the signbit, we know for certain that all the unknowns can't be |
574 | // zero. So if we know high zero bits, but have unknown low bits, we know |
575 | // for certain those high-zero bits will end up as one. This is because, |
576 | // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up |
577 | // to the high bits. If we know a known INT_MIN input skip this. The result |
578 | // is poison anyways. |
579 | if (IntMinIsPoison && Tmp.countMinPopulation() == 1 && |
580 | Tmp.countMaxPopulation() != 1) { |
581 | Tmp.One.clearSignBit(); |
582 | Tmp.Zero.setSignBit(); |
583 | KnownAbs.One.setBits(loBit: getBitWidth() - Tmp.countMinLeadingZeros(), |
584 | hiBit: getBitWidth() - 1); |
585 | } |
586 | |
587 | } else { |
588 | unsigned MaxTZ = countMaxTrailingZeros(); |
589 | unsigned MinTZ = countMinTrailingZeros(); |
590 | |
591 | KnownAbs.Zero.setLowBits(MinTZ); |
592 | // If we know the lowest set 1, then preserve it. |
593 | if (MaxTZ == MinTZ && MaxTZ < getBitWidth()) |
594 | KnownAbs.One.setBit(MaxTZ); |
595 | |
596 | // We only know that the absolute values's MSB will be zero if INT_MIN is |
597 | // poison, or there is a set bit that isn't the sign bit (otherwise it could |
598 | // be INT_MIN). |
599 | if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) { |
600 | KnownAbs.One.clearSignBit(); |
601 | KnownAbs.Zero.setSignBit(); |
602 | } |
603 | } |
604 | |
605 | return KnownAbs; |
606 | } |
607 | |
608 | static KnownBits computeForSatAddSub(bool Add, bool Signed, |
609 | const KnownBits &LHS, |
610 | const KnownBits &RHS) { |
611 | // We don't see NSW even for sadd/ssub as we want to check if the result has |
612 | // signed overflow. |
613 | KnownBits Res = |
614 | KnownBits::computeForAddSub(Add, /*NSW=*/false, /*NUW=*/false, LHS, RHS); |
615 | unsigned BitWidth = Res.getBitWidth(); |
616 | auto SignBitKnown = [&](const KnownBits &K) { |
617 | return K.Zero[BitWidth - 1] || K.One[BitWidth - 1]; |
618 | }; |
619 | std::optional<bool> Overflow; |
620 | |
621 | if (Signed) { |
622 | // If we can actually detect overflow do so. Otherwise leave Overflow as |
623 | // nullopt (we assume it may have happened). |
624 | if (SignBitKnown(LHS) && SignBitKnown(RHS) && SignBitKnown(Res)) { |
625 | if (Add) { |
626 | // sadd.sat |
627 | Overflow = (LHS.isNonNegative() == RHS.isNonNegative() && |
628 | Res.isNonNegative() != LHS.isNonNegative()); |
629 | } else { |
630 | // ssub.sat |
631 | Overflow = (LHS.isNonNegative() != RHS.isNonNegative() && |
632 | Res.isNonNegative() != LHS.isNonNegative()); |
633 | } |
634 | } |
635 | } else if (Add) { |
636 | // uadd.sat |
637 | bool Of; |
638 | (void)LHS.getMaxValue().uadd_ov(RHS: RHS.getMaxValue(), Overflow&: Of); |
639 | if (!Of) { |
640 | Overflow = false; |
641 | } else { |
642 | (void)LHS.getMinValue().uadd_ov(RHS: RHS.getMinValue(), Overflow&: Of); |
643 | if (Of) |
644 | Overflow = true; |
645 | } |
646 | } else { |
647 | // usub.sat |
648 | bool Of; |
649 | (void)LHS.getMinValue().usub_ov(RHS: RHS.getMaxValue(), Overflow&: Of); |
650 | if (!Of) { |
651 | Overflow = false; |
652 | } else { |
653 | (void)LHS.getMaxValue().usub_ov(RHS: RHS.getMinValue(), Overflow&: Of); |
654 | if (Of) |
655 | Overflow = true; |
656 | } |
657 | } |
658 | |
659 | if (Signed) { |
660 | if (Add) { |
661 | if (LHS.isNonNegative() && RHS.isNonNegative()) { |
662 | // Pos + Pos -> Pos |
663 | Res.One.clearSignBit(); |
664 | Res.Zero.setSignBit(); |
665 | } |
666 | if (LHS.isNegative() && RHS.isNegative()) { |
667 | // Neg + Neg -> Neg |
668 | Res.One.setSignBit(); |
669 | Res.Zero.clearSignBit(); |
670 | } |
671 | } else { |
672 | if (LHS.isNegative() && RHS.isNonNegative()) { |
673 | // Neg - Pos -> Neg |
674 | Res.One.setSignBit(); |
675 | Res.Zero.clearSignBit(); |
676 | } else if (LHS.isNonNegative() && RHS.isNegative()) { |
677 | // Pos - Neg -> Pos |
678 | Res.One.clearSignBit(); |
679 | Res.Zero.setSignBit(); |
680 | } |
681 | } |
682 | } else { |
683 | // Add: Leading ones of either operand are preserved. |
684 | // Sub: Leading zeros of LHS and leading ones of RHS are preserved |
685 | // as leading zeros in the result. |
686 | unsigned LeadingKnown; |
687 | if (Add) |
688 | LeadingKnown = |
689 | std::max(a: LHS.countMinLeadingOnes(), b: RHS.countMinLeadingOnes()); |
690 | else |
691 | LeadingKnown = |
692 | std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinLeadingOnes()); |
693 | |
694 | // We select between the operation result and all-ones/zero |
695 | // respectively, so we can preserve known ones/zeros. |
696 | APInt Mask = APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: LeadingKnown); |
697 | if (Add) { |
698 | Res.One |= Mask; |
699 | Res.Zero &= ~Mask; |
700 | } else { |
701 | Res.Zero |= Mask; |
702 | Res.One &= ~Mask; |
703 | } |
704 | } |
705 | |
706 | if (Overflow) { |
707 | // We know whether or not we overflowed. |
708 | if (!(*Overflow)) { |
709 | // No overflow. |
710 | return Res; |
711 | } |
712 | |
713 | // We overflowed |
714 | APInt C; |
715 | if (Signed) { |
716 | // sadd.sat / ssub.sat |
717 | assert(SignBitKnown(LHS) && |
718 | "We somehow know overflow without knowing input sign" ); |
719 | C = LHS.isNegative() ? APInt::getSignedMinValue(numBits: BitWidth) |
720 | : APInt::getSignedMaxValue(numBits: BitWidth); |
721 | } else if (Add) { |
722 | // uadd.sat |
723 | C = APInt::getMaxValue(numBits: BitWidth); |
724 | } else { |
725 | // uadd.sat |
726 | C = APInt::getMinValue(numBits: BitWidth); |
727 | } |
728 | |
729 | Res.One = C; |
730 | Res.Zero = ~C; |
731 | return Res; |
732 | } |
733 | |
734 | // We don't know if we overflowed. |
735 | if (Signed) { |
736 | // sadd.sat/ssub.sat |
737 | // We can keep our information about the sign bits. |
738 | Res.Zero.clearLowBits(loBits: BitWidth - 1); |
739 | Res.One.clearLowBits(loBits: BitWidth - 1); |
740 | } else if (Add) { |
741 | // uadd.sat |
742 | // We need to clear all the known zeros as we can only use the leading ones. |
743 | Res.Zero.clearAllBits(); |
744 | } else { |
745 | // usub.sat |
746 | // We need to clear all the known ones as we can only use the leading zero. |
747 | Res.One.clearAllBits(); |
748 | } |
749 | |
750 | return Res; |
751 | } |
752 | |
753 | KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) { |
754 | return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS); |
755 | } |
756 | KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) { |
757 | return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS); |
758 | } |
759 | KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) { |
760 | return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS); |
761 | } |
762 | KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) { |
763 | return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS); |
764 | } |
765 | |
766 | static KnownBits avgCompute(KnownBits LHS, KnownBits RHS, bool IsCeil, |
767 | bool IsSigned) { |
768 | unsigned BitWidth = LHS.getBitWidth(); |
769 | LHS = IsSigned ? LHS.sext(BitWidth: BitWidth + 1) : LHS.zext(BitWidth: BitWidth + 1); |
770 | RHS = IsSigned ? RHS.sext(BitWidth: BitWidth + 1) : RHS.zext(BitWidth: BitWidth + 1); |
771 | LHS = |
772 | computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil); |
773 | LHS = LHS.extractBits(NumBits: BitWidth, BitPosition: 1); |
774 | return LHS; |
775 | } |
776 | |
777 | KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) { |
778 | return avgCompute(LHS, RHS, /* IsCeil */ false, |
779 | /* IsSigned */ true); |
780 | } |
781 | |
782 | KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) { |
783 | return avgCompute(LHS, RHS, /* IsCeil */ false, |
784 | /* IsSigned */ false); |
785 | } |
786 | |
787 | KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) { |
788 | return avgCompute(LHS, RHS, /* IsCeil */ true, |
789 | /* IsSigned */ true); |
790 | } |
791 | |
792 | KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) { |
793 | return avgCompute(LHS, RHS, /* IsCeil */ true, |
794 | /* IsSigned */ false); |
795 | } |
796 | |
797 | KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS, |
798 | bool NoUndefSelfMultiply) { |
799 | unsigned BitWidth = LHS.getBitWidth(); |
800 | assert(BitWidth == RHS.getBitWidth() && "Operand mismatch" ); |
801 | assert((!NoUndefSelfMultiply || LHS == RHS) && |
802 | "Self multiplication knownbits mismatch" ); |
803 | |
804 | // Compute the high known-0 bits by multiplying the unsigned max of each side. |
805 | // Conservatively, M active bits * N active bits results in M + N bits in the |
806 | // result. But if we know a value is a power-of-2 for example, then this |
807 | // computes one more leading zero. |
808 | // TODO: This could be generalized to number of sign bits (negative numbers). |
809 | APInt UMaxLHS = LHS.getMaxValue(); |
810 | APInt UMaxRHS = RHS.getMaxValue(); |
811 | |
812 | // For leading zeros in the result to be valid, the unsigned max product must |
813 | // fit in the bitwidth (it must not overflow). |
814 | bool HasOverflow; |
815 | APInt UMaxResult = UMaxLHS.umul_ov(RHS: UMaxRHS, Overflow&: HasOverflow); |
816 | unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero(); |
817 | |
818 | // The result of the bottom bits of an integer multiply can be |
819 | // inferred by looking at the bottom bits of both operands and |
820 | // multiplying them together. |
821 | // We can infer at least the minimum number of known trailing bits |
822 | // of both operands. Depending on number of trailing zeros, we can |
823 | // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming |
824 | // a and b are divisible by m and n respectively. |
825 | // We then calculate how many of those bits are inferrable and set |
826 | // the output. For example, the i8 mul: |
827 | // a = XXXX1100 (12) |
828 | // b = XXXX1110 (14) |
829 | // We know the bottom 3 bits are zero since the first can be divided by |
830 | // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). |
831 | // Applying the multiplication to the trimmed arguments gets: |
832 | // XX11 (3) |
833 | // X111 (7) |
834 | // ------- |
835 | // XX11 |
836 | // XX11 |
837 | // XX11 |
838 | // XX11 |
839 | // ------- |
840 | // XXXXX01 |
841 | // Which allows us to infer the 2 LSBs. Since we're multiplying the result |
842 | // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. |
843 | // The proof for this can be described as: |
844 | // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && |
845 | // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + |
846 | // umin(countTrailingZeros(C2), C6) + |
847 | // umin(C5 - umin(countTrailingZeros(C1), C5), |
848 | // C6 - umin(countTrailingZeros(C2), C6)))) - 1) |
849 | // %aa = shl i8 %a, C5 |
850 | // %bb = shl i8 %b, C6 |
851 | // %aaa = or i8 %aa, C1 |
852 | // %bbb = or i8 %bb, C2 |
853 | // %mul = mul i8 %aaa, %bbb |
854 | // %mask = and i8 %mul, C7 |
855 | // => |
856 | // %mask = i8 ((C1*C2)&C7) |
857 | // Where C5, C6 describe the known bits of %a, %b |
858 | // C1, C2 describe the known bottom bits of %a, %b. |
859 | // C7 describes the mask of the known bits of the result. |
860 | const APInt &Bottom0 = LHS.One; |
861 | const APInt &Bottom1 = RHS.One; |
862 | |
863 | // How many times we'd be able to divide each argument by 2 (shr by 1). |
864 | // This gives us the number of trailing zeros on the multiplication result. |
865 | unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one(); |
866 | unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one(); |
867 | unsigned TrailZero0 = LHS.countMinTrailingZeros(); |
868 | unsigned TrailZero1 = RHS.countMinTrailingZeros(); |
869 | unsigned TrailZ = TrailZero0 + TrailZero1; |
870 | |
871 | // Figure out the fewest known-bits operand. |
872 | unsigned SmallestOperand = |
873 | std::min(a: TrailBitsKnown0 - TrailZero0, b: TrailBitsKnown1 - TrailZero1); |
874 | unsigned ResultBitsKnown = std::min(a: SmallestOperand + TrailZ, b: BitWidth); |
875 | |
876 | APInt BottomKnown = |
877 | Bottom0.getLoBits(numBits: TrailBitsKnown0) * Bottom1.getLoBits(numBits: TrailBitsKnown1); |
878 | |
879 | KnownBits Res(BitWidth); |
880 | Res.Zero.setHighBits(LeadZ); |
881 | Res.Zero |= (~BottomKnown).getLoBits(numBits: ResultBitsKnown); |
882 | Res.One = BottomKnown.getLoBits(numBits: ResultBitsKnown); |
883 | |
884 | // If we're self-multiplying then bit[1] is guaranteed to be zero. |
885 | if (NoUndefSelfMultiply && BitWidth > 1) { |
886 | assert(Res.One[1] == 0 && |
887 | "Self-multiplication failed Quadratic Reciprocity!" ); |
888 | Res.Zero.setBit(1); |
889 | } |
890 | |
891 | return Res; |
892 | } |
893 | |
894 | KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) { |
895 | unsigned BitWidth = LHS.getBitWidth(); |
896 | assert(BitWidth == RHS.getBitWidth() && "Operand mismatch" ); |
897 | KnownBits WideLHS = LHS.sext(BitWidth: 2 * BitWidth); |
898 | KnownBits WideRHS = RHS.sext(BitWidth: 2 * BitWidth); |
899 | return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth); |
900 | } |
901 | |
902 | KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) { |
903 | unsigned BitWidth = LHS.getBitWidth(); |
904 | assert(BitWidth == RHS.getBitWidth() && "Operand mismatch" ); |
905 | KnownBits WideLHS = LHS.zext(BitWidth: 2 * BitWidth); |
906 | KnownBits WideRHS = RHS.zext(BitWidth: 2 * BitWidth); |
907 | return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth); |
908 | } |
909 | |
910 | static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS, |
911 | const KnownBits &RHS, bool Exact) { |
912 | |
913 | if (!Exact) |
914 | return Known; |
915 | |
916 | // If LHS is Odd, the result is Odd no matter what. |
917 | // Odd / Odd -> Odd |
918 | // Odd / Even -> Impossible (because its exact division) |
919 | if (LHS.One[0]) |
920 | Known.One.setBit(0); |
921 | |
922 | int MinTZ = |
923 | (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros(); |
924 | int MaxTZ = |
925 | (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros(); |
926 | if (MinTZ >= 0) { |
927 | // Result has at least MinTZ trailing zeros. |
928 | Known.Zero.setLowBits(MinTZ); |
929 | if (MinTZ == MaxTZ) { |
930 | // Result has exactly MinTZ trailing zeros. |
931 | Known.One.setBit(MinTZ); |
932 | } |
933 | } else if (MaxTZ < 0) { |
934 | // Poison Result |
935 | Known.setAllZero(); |
936 | } |
937 | |
938 | // In the KnownBits exhaustive tests, we have poison inputs for exact values |
939 | // a LOT. If we have a conflict, just return all zeros. |
940 | if (Known.hasConflict()) |
941 | Known.setAllZero(); |
942 | |
943 | return Known; |
944 | } |
945 | |
946 | KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS, |
947 | bool Exact) { |
948 | // Equivalent of `udiv`. We must have caught this before it was folded. |
949 | if (LHS.isNonNegative() && RHS.isNonNegative()) |
950 | return udiv(LHS, RHS, Exact); |
951 | |
952 | unsigned BitWidth = LHS.getBitWidth(); |
953 | KnownBits Known(BitWidth); |
954 | |
955 | if (LHS.isZero() || RHS.isZero()) { |
956 | // Result is either known Zero or UB. Return Zero either way. |
957 | // Checking this earlier saves us a lot of special cases later on. |
958 | Known.setAllZero(); |
959 | return Known; |
960 | } |
961 | |
962 | std::optional<APInt> Res; |
963 | if (LHS.isNegative() && RHS.isNegative()) { |
964 | // Result non-negative. |
965 | APInt Denom = RHS.getSignedMaxValue(); |
966 | APInt Num = LHS.getSignedMinValue(); |
967 | // INT_MIN/-1 would be a poison result (impossible). Estimate the division |
968 | // as signed max (we will only set sign bit in the result). |
969 | Res = (Num.isMinSignedValue() && Denom.isAllOnes()) |
970 | ? APInt::getSignedMaxValue(numBits: BitWidth) |
971 | : Num.sdiv(RHS: Denom); |
972 | } else if (LHS.isNegative() && RHS.isNonNegative()) { |
973 | // Result is negative if Exact OR -LHS u>= RHS. |
974 | if (Exact || (-LHS.getSignedMaxValue()).uge(RHS: RHS.getSignedMaxValue())) { |
975 | APInt Denom = RHS.getSignedMinValue(); |
976 | APInt Num = LHS.getSignedMinValue(); |
977 | Res = Denom.isZero() ? Num : Num.sdiv(RHS: Denom); |
978 | } |
979 | } else if (LHS.isStrictlyPositive() && RHS.isNegative()) { |
980 | // Result is negative if Exact OR LHS u>= -RHS. |
981 | if (Exact || LHS.getSignedMinValue().uge(RHS: -RHS.getSignedMinValue())) { |
982 | APInt Denom = RHS.getSignedMaxValue(); |
983 | APInt Num = LHS.getSignedMaxValue(); |
984 | Res = Num.sdiv(RHS: Denom); |
985 | } |
986 | } |
987 | |
988 | if (Res) { |
989 | if (Res->isNonNegative()) { |
990 | unsigned LeadZ = Res->countLeadingZeros(); |
991 | Known.Zero.setHighBits(LeadZ); |
992 | } else { |
993 | unsigned LeadO = Res->countLeadingOnes(); |
994 | Known.One.setHighBits(LeadO); |
995 | } |
996 | } |
997 | |
998 | Known = divComputeLowBit(Known, LHS, RHS, Exact); |
999 | return Known; |
1000 | } |
1001 | |
1002 | KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS, |
1003 | bool Exact) { |
1004 | unsigned BitWidth = LHS.getBitWidth(); |
1005 | KnownBits Known(BitWidth); |
1006 | |
1007 | if (LHS.isZero() || RHS.isZero()) { |
1008 | // Result is either known Zero or UB. Return Zero either way. |
1009 | // Checking this earlier saves us a lot of special cases later on. |
1010 | Known.setAllZero(); |
1011 | return Known; |
1012 | } |
1013 | |
1014 | // We can figure out the minimum number of upper zero bits by doing |
1015 | // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator |
1016 | // gets larger, the number of upper zero bits increases. |
1017 | APInt MinDenom = RHS.getMinValue(); |
1018 | APInt MaxNum = LHS.getMaxValue(); |
1019 | APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(RHS: MinDenom); |
1020 | |
1021 | unsigned LeadZ = MaxRes.countLeadingZeros(); |
1022 | |
1023 | Known.Zero.setHighBits(LeadZ); |
1024 | Known = divComputeLowBit(Known, LHS, RHS, Exact); |
1025 | |
1026 | return Known; |
1027 | } |
1028 | |
1029 | KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) { |
1030 | unsigned BitWidth = LHS.getBitWidth(); |
1031 | if (!RHS.isZero() && RHS.Zero[0]) { |
1032 | // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result. |
1033 | unsigned RHSZeros = RHS.countMinTrailingZeros(); |
1034 | APInt Mask = APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: RHSZeros); |
1035 | APInt OnesMask = LHS.One & Mask; |
1036 | APInt ZerosMask = LHS.Zero & Mask; |
1037 | return KnownBits(ZerosMask, OnesMask); |
1038 | } |
1039 | return KnownBits(BitWidth); |
1040 | } |
1041 | |
1042 | KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) { |
1043 | KnownBits Known = remGetLowBits(LHS, RHS); |
1044 | if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { |
1045 | // NB: Low bits set in `remGetLowBits`. |
1046 | APInt HighBits = ~(RHS.getConstant() - 1); |
1047 | Known.Zero |= HighBits; |
1048 | return Known; |
1049 | } |
1050 | |
1051 | // Since the result is less than or equal to either operand, any leading |
1052 | // zero bits in either operand must also exist in the result. |
1053 | uint32_t Leaders = |
1054 | std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinLeadingZeros()); |
1055 | Known.Zero.setHighBits(Leaders); |
1056 | return Known; |
1057 | } |
1058 | |
1059 | KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) { |
1060 | KnownBits Known = remGetLowBits(LHS, RHS); |
1061 | if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { |
1062 | // NB: Low bits are set in `remGetLowBits`. |
1063 | APInt LowBits = RHS.getConstant() - 1; |
1064 | // If the first operand is non-negative or has all low bits zero, then |
1065 | // the upper bits are all zero. |
1066 | if (LHS.isNonNegative() || LowBits.isSubsetOf(RHS: LHS.Zero)) |
1067 | Known.Zero |= ~LowBits; |
1068 | |
1069 | // If the first operand is negative and not all low bits are zero, then |
1070 | // the upper bits are all one. |
1071 | if (LHS.isNegative() && LowBits.intersects(RHS: LHS.One)) |
1072 | Known.One |= ~LowBits; |
1073 | return Known; |
1074 | } |
1075 | |
1076 | // The sign bit is the LHS's sign bit, except when the result of the |
1077 | // remainder is zero. The magnitude of the result should be less than or |
1078 | // equal to the magnitude of the LHS. Therefore any leading zeros that exist |
1079 | // in the left hand side must also exist in the result. |
1080 | Known.Zero.setHighBits(LHS.countMinLeadingZeros()); |
1081 | return Known; |
1082 | } |
1083 | |
1084 | KnownBits &KnownBits::operator&=(const KnownBits &RHS) { |
1085 | // Result bit is 0 if either operand bit is 0. |
1086 | Zero |= RHS.Zero; |
1087 | // Result bit is 1 if both operand bits are 1. |
1088 | One &= RHS.One; |
1089 | return *this; |
1090 | } |
1091 | |
1092 | KnownBits &KnownBits::operator|=(const KnownBits &RHS) { |
1093 | // Result bit is 0 if both operand bits are 0. |
1094 | Zero &= RHS.Zero; |
1095 | // Result bit is 1 if either operand bit is 1. |
1096 | One |= RHS.One; |
1097 | return *this; |
1098 | } |
1099 | |
1100 | KnownBits &KnownBits::operator^=(const KnownBits &RHS) { |
1101 | // Result bit is 0 if both operand bits are 0 or both are 1. |
1102 | APInt Z = (Zero & RHS.Zero) | (One & RHS.One); |
1103 | // Result bit is 1 if one operand bit is 0 and the other is 1. |
1104 | One = (Zero & RHS.One) | (One & RHS.Zero); |
1105 | Zero = std::move(Z); |
1106 | return *this; |
1107 | } |
1108 | |
1109 | KnownBits KnownBits::blsi() const { |
1110 | unsigned BitWidth = getBitWidth(); |
1111 | KnownBits Known(Zero, APInt(BitWidth, 0)); |
1112 | unsigned Max = countMaxTrailingZeros(); |
1113 | Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth)); |
1114 | unsigned Min = countMinTrailingZeros(); |
1115 | if (Max == Min && Max < BitWidth) |
1116 | Known.One.setBit(Max); |
1117 | return Known; |
1118 | } |
1119 | |
1120 | KnownBits KnownBits::blsmsk() const { |
1121 | unsigned BitWidth = getBitWidth(); |
1122 | KnownBits Known(BitWidth); |
1123 | unsigned Max = countMaxTrailingZeros(); |
1124 | Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth)); |
1125 | unsigned Min = countMinTrailingZeros(); |
1126 | Known.One.setLowBits(std::min(a: Min + 1, b: BitWidth)); |
1127 | return Known; |
1128 | } |
1129 | |
1130 | void KnownBits::print(raw_ostream &OS) const { |
1131 | unsigned BitWidth = getBitWidth(); |
1132 | for (unsigned I = 0; I < BitWidth; ++I) { |
1133 | unsigned N = BitWidth - I - 1; |
1134 | if (Zero[N] && One[N]) |
1135 | OS << "!" ; |
1136 | else if (Zero[N]) |
1137 | OS << "0" ; |
1138 | else if (One[N]) |
1139 | OS << "1" ; |
1140 | else |
1141 | OS << "?" ; |
1142 | } |
1143 | } |
1144 | void KnownBits::dump() const { |
1145 | print(OS&: dbgs()); |
1146 | dbgs() << "\n" ; |
1147 | } |
1148 | |