//===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a class for representing known zeros and ones used by
// computeKnownBits.
//
//===----------------------------------------------------------------------===//
13
14#include "llvm/Support/KnownBits.h"
15#include "llvm/ADT/Sequence.h"
16#include "llvm/Support/Debug.h"
17#include "llvm/Support/raw_ostream.h"
18#include <cassert>
19
20using namespace llvm;
21
22KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
23 unsigned SignBitPosition = Val.getBitWidth() - 1;
24 APInt Zero = Val.Zero;
25 APInt One = Val.One;
26 Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]);
27 One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]);
28 return KnownBits(Zero, One);
29}
30
31static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
32 bool CarryZero, bool CarryOne) {
33
34 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
35 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
36
37 // Compute known bits of the carry.
38 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
39 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
40
41 // Compute set of known bits (where all three relevant bits are known).
42 APInt LHSKnownUnion = LHS.Zero | LHS.One;
43 APInt RHSKnownUnion = RHS.Zero | RHS.One;
44 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
45 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
46
47 // Compute known bits of the result.
48 KnownBits KnownOut;
49 KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
50 KnownOut.One = std::move(PossibleSumOne) & Known;
51 return KnownOut;
52}
53
54KnownBits KnownBits::computeForAddCarry(
55 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
56 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
57 return ::computeForAddCarry(
58 LHS, RHS, CarryZero: Carry.Zero.getBoolValue(), CarryOne: Carry.One.getBoolValue());
59}
60
61KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
62 const KnownBits &LHS,
63 const KnownBits &RHS) {
64 unsigned BitWidth = LHS.getBitWidth();
65 KnownBits KnownOut(BitWidth);
66 // This can be a relatively expensive helper, so optimistically save some
67 // work.
68 if (LHS.isUnknown() && RHS.isUnknown())
69 return KnownOut;
70
71 if (!LHS.isUnknown() && !RHS.isUnknown()) {
72 if (Add) {
73 // Sum = LHS + RHS + 0
74 KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
75 /*CarryOne=*/false);
76 } else {
77 // Sum = LHS + ~RHS + 1
78 KnownBits NotRHS = RHS;
79 std::swap(a&: NotRHS.Zero, b&: NotRHS.One);
80 KnownOut = ::computeForAddCarry(LHS, RHS: NotRHS, /*CarryZero=*/false,
81 /*CarryOne=*/true);
82 }
83 }
84
85 // Handle add/sub given nsw and/or nuw.
86 if (NUW) {
87 if (Add) {
88 // (add nuw X, Y)
89 APInt MinVal = LHS.getMinValue().uadd_sat(RHS: RHS.getMinValue());
90 // None of the adds can end up overflowing, so min consecutive highbits
91 // in minimum possible of X + Y must all remain set.
92 if (NSW) {
93 unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one();
94 // If we have NSW as well, we also know we can't overflow the signbit so
95 // can start counting from 1 bit back.
96 KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
97 }
98 KnownOut.One.setHighBits(MinVal.countl_one());
99 } else {
100 // (sub nuw X, Y)
101 APInt MaxVal = LHS.getMaxValue().usub_sat(RHS: RHS.getMinValue());
102 // None of the subs can overflow at any point, so any common high bits
103 // will subtract away and result in zeros.
104 if (NSW) {
105 // If we have NSW as well, we also know we can't overflow the signbit so
106 // can start counting from 1 bit back.
107 unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero();
108 KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
109 }
110 KnownOut.Zero.setHighBits(MaxVal.countl_zero());
111 }
112 }
113
114 if (NSW) {
115 APInt MinVal;
116 APInt MaxVal;
117 if (Add) {
118 // (add nsw X, Y)
119 MinVal = LHS.getSignedMinValue().sadd_sat(RHS: RHS.getSignedMinValue());
120 MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS: RHS.getSignedMaxValue());
121 } else {
122 // (sub nsw X, Y)
123 MinVal = LHS.getSignedMinValue().ssub_sat(RHS: RHS.getSignedMaxValue());
124 MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS: RHS.getSignedMinValue());
125 }
126 if (MinVal.isNonNegative()) {
127 // If min is non-negative, result will always be non-neg (can't overflow
128 // around).
129 unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one();
130 KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
131 KnownOut.Zero.setSignBit();
132 }
133 if (MaxVal.isNegative()) {
134 // If max is negative, result will always be neg (can't overflow around).
135 unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero();
136 KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
137 KnownOut.One.setSignBit();
138 }
139 }
140
141 // Just return 0 if the nsw/nuw is violated and we have poison.
142 if (KnownOut.hasConflict())
143 KnownOut.setAllZero();
144 return KnownOut;
145}
146
147KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
148 const KnownBits &Borrow) {
149 assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");
150
151 // LHS - RHS = LHS + ~RHS + 1
152 // Carry 1 - Borrow in ::computeForAddCarry
153 std::swap(a&: RHS.Zero, b&: RHS.One);
154 return ::computeForAddCarry(LHS, RHS,
155 /*CarryZero=*/Borrow.One.getBoolValue(),
156 /*CarryOne=*/Borrow.Zero.getBoolValue());
157}
158
159KnownBits KnownBits::truncSSat(unsigned BitWidth) const {
160 unsigned InputBits = getBitWidth();
161 APInt MinInRange = APInt::getSignedMinValue(numBits: BitWidth).sext(width: InputBits);
162 APInt MaxInRange = APInt::getSignedMaxValue(numBits: BitWidth).sext(width: InputBits);
163 APInt InputMin = getSignedMinValue();
164 APInt InputMax = getSignedMaxValue();
165 KnownBits Known(BitWidth);
166
167 // Case 1: All values fit - just truncate
168 if (InputMin.sge(RHS: MinInRange) && InputMax.sle(RHS: MaxInRange)) {
169 Known = trunc(BitWidth);
170 }
171 // Case 2: All saturate to min
172 else if (InputMax.slt(RHS: MinInRange)) {
173 Known = KnownBits::makeConstant(C: APInt::getSignedMinValue(numBits: BitWidth));
174 }
175 // Case 3: All saturate to max
176 else if (InputMin.sgt(RHS: MaxInRange)) {
177 Known = KnownBits::makeConstant(C: APInt::getSignedMaxValue(numBits: BitWidth));
178 }
179 // Case 4: All non-negative, some fit, some saturate to max
180 else if (InputMin.isNonNegative()) {
181 // Output: truncated OR max saturation
182 // Max saturation has only sign bit as 0
183 Known.Zero = APInt(BitWidth, 0);
184 Known.Zero.setBit(BitWidth - 1); // Sign bit always 0
185 // Max saturation has all lower bits as 1, so preserve InputOneLower
186 Known.One = One.trunc(width: BitWidth);
187 Known.One.clearBit(BitPosition: BitWidth - 1); // Sign bit is 0, not 1
188 }
189 // Case 5: All negative, some fit, some saturate to min
190 else if (InputMax.isNegative()) {
191 // Output: truncated OR min saturation
192 // Min saturation has all lower bits as 0, so preserve InputZeroLower
193 Known.Zero = Zero.trunc(width: BitWidth);
194 Known.Zero.clearBit(BitPosition: BitWidth - 1); // Sign bit is 1, not 0
195 // Min saturation has only sign bit as 1
196 Known.One = APInt(BitWidth, 0);
197 Known.One.setBit(BitWidth - 1); // Sign bit always 1
198 }
199 // Case 6: Mixed positive and negative
200 else {
201 // Output: min saturation, truncated, or max saturation
202 APInt InputZeroLower = Zero.trunc(width: BitWidth);
203 APInt InputOneLower = One.trunc(width: BitWidth);
204 APInt MinSat = APInt::getSignedMinValue(numBits: BitWidth);
205 APInt MaxSat = APInt::getSignedMaxValue(numBits: BitWidth);
206 KnownBits MinSatKB = KnownBits::makeConstant(C: MinSat);
207 KnownBits MaxSatKB = KnownBits::makeConstant(C: MaxSat);
208 if (InputMax.sle(RHS: MaxInRange)) {
209 // Positive values fit, only negatives saturate to min
210 Known.Zero = InputZeroLower & MinSatKB.Zero;
211 Known.One = InputOneLower & MinSatKB.One;
212 } else if (InputMin.sge(RHS: MinInRange)) {
213 // Negative values fit, only positives saturate to max
214 Known.Zero = InputZeroLower & MaxSatKB.Zero;
215 Known.One = InputOneLower & MaxSatKB.One;
216 } else {
217 // Both positive and negative values might saturate
218 Known.Zero = InputZeroLower & MinSatKB.Zero & MaxSatKB.Zero;
219 Known.One = InputOneLower & MinSatKB.One & MaxSatKB.One;
220 }
221 }
222 return Known;
223}
224
225KnownBits KnownBits::truncSSatU(unsigned BitWidth) const {
226 unsigned InputBits = getBitWidth();
227 APInt MaxInRange = APInt::getAllOnes(numBits: BitWidth).zext(width: InputBits);
228 APInt InputMin = getSignedMinValue();
229 APInt InputMax = getSignedMaxValue();
230 KnownBits Known(BitWidth);
231
232 if (isNegative()) {
233 Known.setAllZero();
234 } else if (InputMin.isNonNegative() && InputMax.ule(RHS: MaxInRange)) {
235 Known = trunc(BitWidth);
236 } else if (InputMin.isNonNegative() && InputMin.ugt(RHS: MaxInRange)) {
237 Known.setAllOnes();
238 } else if (InputMin.isNonNegative()) {
239 // All non-negative but mixed: some fit, some saturate to all-ones
240 // No common zero bits (saturation is all-ones)
241 Known.Zero = APInt(BitWidth, 0);
242 // Common one bits: bits that are 1 in input stay 1
243 Known.One = One.trunc(width: BitWidth);
244 } else {
245 // Mixed: sign bit unknown
246 if (InputMax.ule(RHS: MaxInRange)) {
247 // Positive values all fit (no saturation to all-ones)
248 // Output: all-zeros (negatives) OR truncated (positives)
249 Known.Zero = Zero.trunc(width: BitWidth);
250 Known.One = APInt(BitWidth, 0);
251 } else {
252 // Positive values might saturate to all-ones
253 // Output: all-zeros OR truncated OR all-ones
254 // No common bits
255 Known.Zero = APInt(BitWidth, 0);
256 Known.One = APInt(BitWidth, 0);
257 }
258 }
259 return Known;
260}
261
262KnownBits KnownBits::truncUSat(unsigned BitWidth) const {
263 unsigned InputBits = getBitWidth();
264 APInt MaxInRange = APInt::getLowBitsSet(numBits: InputBits, loBitsSet: BitWidth);
265 APInt InputMax = getMaxValue();
266 APInt InputMin = getMinValue();
267 KnownBits Known(BitWidth);
268
269 if (InputMax.ule(RHS: MaxInRange)) {
270 Known = trunc(BitWidth);
271 } else if (InputMin.ugt(RHS: MaxInRange)) {
272 Known.setAllOnes();
273 } else {
274 Known.resetAll();
275 Known.One = One.trunc(width: BitWidth);
276 }
277 return Known;
278}
279
280KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
281 unsigned BitWidth = getBitWidth();
282 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
283 "Illegal sext-in-register");
284
285 if (SrcBitWidth == BitWidth)
286 return *this;
287
288 unsigned ExtBits = BitWidth - SrcBitWidth;
289 KnownBits Result;
290 Result.One = One << ExtBits;
291 Result.Zero = Zero << ExtBits;
292 Result.One.ashrInPlace(ShiftAmt: ExtBits);
293 Result.Zero.ashrInPlace(ShiftAmt: ExtBits);
294 return Result;
295}
296
297KnownBits KnownBits::makeGE(const APInt &Val) const {
298 // Count the number of leading bit positions where our underlying value is
299 // known to be less than or equal to Val.
300 unsigned N = (Zero | Val).countl_one();
301
302 // For each of those bit positions, if Val has a 1 in that bit then our
303 // underlying value must also have a 1.
304 APInt MaskedVal(Val);
305 MaskedVal.clearLowBits(loBits: getBitWidth() - N);
306 return KnownBits(Zero, One | MaskedVal);
307}
308
309KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
310 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
311 // RHS. Ideally our caller would already have spotted these cases and
312 // optimized away the umax operation, but we handle them here for
313 // completeness.
314 if (LHS.getMinValue().uge(RHS: RHS.getMaxValue()))
315 return LHS;
316 if (RHS.getMinValue().uge(RHS: LHS.getMaxValue()))
317 return RHS;
318
319 // If the result of the umax is LHS then it must be greater than or equal to
320 // the minimum possible value of RHS. Likewise for RHS. Any known bits that
321 // are common to these two values are also known in the result.
322 KnownBits L = LHS.makeGE(Val: RHS.getMinValue());
323 KnownBits R = RHS.makeGE(Val: LHS.getMinValue());
324 return L.intersectWith(RHS: R);
325}
326
327KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
328 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
329 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
330 return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS)));
331}
332
333KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
334 return flipSignBit(Val: umax(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
335}
336
337KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
338 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
339 auto Flip = [](const KnownBits &Val) {
340 unsigned SignBitPosition = Val.getBitWidth() - 1;
341 APInt Zero = Val.One;
342 APInt One = Val.Zero;
343 Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]);
344 One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]);
345 return KnownBits(Zero, One);
346 };
347 return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS)));
348}
349
350KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
351 // If we know which argument is larger, return (sub LHS, RHS) or
352 // (sub RHS, LHS) directly.
353 if (LHS.getMinValue().uge(RHS: RHS.getMaxValue()))
354 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
355 RHS);
356 if (RHS.getMinValue().uge(RHS: LHS.getMaxValue()))
357 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS,
358 RHS: LHS);
359
360 // By construction, the subtraction in abdu never has unsigned overflow.
361 // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
362 KnownBits Diff0 =
363 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
364 KnownBits Diff1 =
365 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS);
366 return Diff0.intersectWith(RHS: Diff1);
367}
368
369KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
370 // If we know which argument is larger, return (sub LHS, RHS) or
371 // (sub RHS, LHS) directly.
372 if (LHS.getSignedMinValue().sge(RHS: RHS.getSignedMaxValue()))
373 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
374 RHS);
375 if (RHS.getSignedMinValue().sge(RHS: LHS.getSignedMaxValue()))
376 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS,
377 RHS: LHS);
378
379 // Shift both arguments from the signed range to the unsigned range, e.g. from
380 // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
381 // abdu does.
382 // Note that we can't just use "sub nsw" instead because abds has signed
383 // inputs but an unsigned result, which makes the overflow conditions
384 // different.
385 unsigned SignBitPosition = LHS.getBitWidth() - 1;
386 for (auto Arg : {&LHS, &RHS}) {
387 bool Tmp = Arg->Zero[SignBitPosition];
388 Arg->Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Arg->One[SignBitPosition]);
389 Arg->One.setBitVal(BitPosition: SignBitPosition, BitValue: Tmp);
390 }
391
392 // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
393 KnownBits Diff0 =
394 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
395 KnownBits Diff1 =
396 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS);
397 return Diff0.intersectWith(RHS: Diff1);
398}
399
400static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
401 if (isPowerOf2_32(Value: BitWidth))
402 // Clamp to the shift amount's width: a narrower amount is already
403 // < BitWidth, so this stays a valid upper bound.
404 return MaxValue.extractBitsAsZExtValue(
405 numBits: std::min(a: Log2_32(Value: BitWidth), b: MaxValue.getBitWidth()), bitPosition: 0);
406 // This is only an approximate upper bound.
407 return MaxValue.getLimitedValue(Limit: BitWidth - 1);
408}
409
410KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
411 bool NSW, bool ShAmtNonZero) {
412 unsigned BitWidth = LHS.getBitWidth();
413 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
414 KnownBits Known;
415 bool ShiftedOutZero, ShiftedOutOne;
416 Known.Zero = LHS.Zero.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutZero);
417 Known.Zero.setLowBits(ShiftAmt);
418 Known.One = LHS.One.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutOne);
419
420 // All cases returning poison have been handled by MaxShiftAmount already.
421 if (NSW) {
422 if (NUW && ShiftAmt != 0)
423 // NUW means we can assume anything shifted out was a zero.
424 ShiftedOutZero = true;
425
426 if (ShiftedOutZero)
427 Known.makeNonNegative();
428 else if (ShiftedOutOne)
429 Known.makeNegative();
430 }
431 return Known;
432 };
433
434 // Fast path for a common case when LHS is completely unknown.
435 KnownBits Known(BitWidth);
436 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
437 if (MinShiftAmount == 0 && ShAmtNonZero)
438 MinShiftAmount = 1;
439 if (LHS.isUnknown()) {
440 Known.Zero.setLowBits(MinShiftAmount);
441 if (NUW && NSW && MinShiftAmount != 0)
442 Known.makeNonNegative();
443 return Known;
444 }
445
446 // Determine maximum shift amount, taking NUW/NSW flags into account.
447 APInt MaxValue = RHS.getMaxValue();
448 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
449 if (NUW && NSW)
450 MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros() - 1);
451 if (NUW)
452 MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros());
453 if (NSW)
454 MaxShiftAmount = std::min(
455 a: MaxShiftAmount,
456 b: std::max(a: LHS.countMaxLeadingZeros(), b: LHS.countMaxLeadingOnes()) - 1);
457
458 // Fast path for common case where the shift amount is unknown.
459 if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
460 isPowerOf2_32(Value: BitWidth)) {
461 Known.Zero.setLowBits(LHS.countMinTrailingZeros());
462 if (LHS.isAllOnes())
463 Known.One.setSignBit();
464 if (NSW) {
465 if (LHS.isNonNegative())
466 Known.makeNonNegative();
467 if (LHS.isNegative())
468 Known.makeNegative();
469 }
470 return Known;
471 }
472
473 // Find the common bits from all possible shifts.
474 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
475 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
476 Known.setAllConflict();
477 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
478 ++ShiftAmt) {
479 // Skip if the shift amount is impossible.
480 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
481 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
482 continue;
483 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
484 if (Known.isUnknown())
485 break;
486 }
487
488 // All shift amounts may result in poison.
489 if (Known.hasConflict())
490 Known.setAllZero();
491 return Known;
492}
493
494KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
495 bool ShAmtNonZero, bool Exact) {
496 unsigned BitWidth = LHS.getBitWidth();
497 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
498 KnownBits Known = LHS;
499 Known >>= ShiftAmt;
500 // High bits are known zero.
501 Known.Zero.setHighBits(ShiftAmt);
502 return Known;
503 };
504
505 // Fast path for a common case when LHS is completely unknown.
506 KnownBits Known(BitWidth);
507 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
508 if (MinShiftAmount == 0 && ShAmtNonZero)
509 MinShiftAmount = 1;
510 if (LHS.isUnknown()) {
511 Known.Zero.setHighBits(MinShiftAmount);
512 return Known;
513 }
514
515 // Find the common bits from all possible shifts.
516 APInt MaxValue = RHS.getMaxValue();
517 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
518
519 // If exact, bound MaxShiftAmount to first known 1 in LHS.
520 if (Exact) {
521 unsigned FirstOne = LHS.countMaxTrailingZeros();
522 if (FirstOne < MinShiftAmount) {
523 // Always poison. Return zero because we don't like returning conflict.
524 Known.setAllZero();
525 return Known;
526 }
527 MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne);
528 }
529
530 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
531 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
532 Known.setAllConflict();
533 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
534 ++ShiftAmt) {
535 // Skip if the shift amount is impossible.
536 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
537 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
538 continue;
539 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
540 if (Known.isUnknown())
541 break;
542 }
543
544 // All shift amounts may result in poison.
545 if (Known.hasConflict())
546 Known.setAllZero();
547 return Known;
548}
549
550KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
551 bool ShAmtNonZero, bool Exact) {
552 unsigned BitWidth = LHS.getBitWidth();
553 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
554 KnownBits Known = LHS;
555 Known.Zero.ashrInPlace(ShiftAmt);
556 Known.One.ashrInPlace(ShiftAmt);
557 return Known;
558 };
559
560 // Fast path for a common case when LHS is completely unknown.
561 KnownBits Known(BitWidth);
562 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
563 if (MinShiftAmount == 0 && ShAmtNonZero)
564 MinShiftAmount = 1;
565 if (LHS.isUnknown()) {
566 if (MinShiftAmount == BitWidth) {
567 // Always poison. Return zero because we don't like returning conflict.
568 Known.setAllZero();
569 return Known;
570 }
571 return Known;
572 }
573
574 // Find the common bits from all possible shifts.
575 APInt MaxValue = RHS.getMaxValue();
576 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
577
578 // If exact, bound MaxShiftAmount to first known 1 in LHS.
579 if (Exact) {
580 unsigned FirstOne = LHS.countMaxTrailingZeros();
581 if (FirstOne < MinShiftAmount) {
582 // Always poison. Return zero because we don't like returning conflict.
583 Known.setAllZero();
584 return Known;
585 }
586 MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne);
587 }
588
589 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
590 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
591 Known.setAllConflict();
592 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
593 ++ShiftAmt) {
594 // Skip if the shift amount is impossible.
595 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
596 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
597 continue;
598 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
599 if (Known.isUnknown())
600 break;
601 }
602
603 // All shift amounts may result in poison.
604 if (Known.hasConflict())
605 Known.setAllZero();
606 return Known;
607}
608
609KnownBits KnownBits::fshl(const KnownBits &LHS, const KnownBits &RHS,
610 const APInt &Amt) {
611 return KnownBits(APIntOps::fshl(Hi: LHS.Zero, Lo: RHS.Zero, Shift: Amt),
612 APIntOps::fshl(Hi: LHS.One, Lo: RHS.One, Shift: Amt));
613}
614
615KnownBits KnownBits::fshr(const KnownBits &LHS, const KnownBits &RHS,
616 const APInt &Amt) {
617 return KnownBits(APIntOps::fshr(Hi: LHS.Zero, Lo: RHS.Zero, Shift: Amt),
618 APIntOps::fshr(Hi: LHS.One, Lo: RHS.One, Shift: Amt));
619}
620
621KnownBits KnownBits::clmul(const KnownBits &LHS, const KnownBits &RHS) {
622 KnownBits Res =
623 makeConstant(C: APIntOps::clmul(LHS: LHS.getMinValue(), RHS: RHS.getMinValue()));
624
625 // This is the same operation as clmul except it accumulates the result with
626 // an OR instead of an XOR.
627 auto ClMulOr = [](const APInt &LHS, const APInt &RHS) {
628 unsigned BW = LHS.getBitWidth();
629 assert(BW == RHS.getBitWidth() && "Operand mismatch");
630 APInt Result(BW, 0);
631 for (unsigned I :
632 seq(Size: std::min(a: RHS.getActiveBits(), b: BW - LHS.countr_zero())))
633 if (RHS[I])
634 Result |= LHS << I;
635 return Result;
636 };
637
638 // Bits in the result are known if, for every corresponding pair of input
639 // bits, both input bits are known or either input bit is known to be zero.
640 APInt Known = ~(ClMulOr(~LHS.Zero & ~LHS.One, ~RHS.Zero) |
641 ClMulOr(~LHS.Zero, ~RHS.Zero & ~RHS.One));
642 Res.Zero &= Known;
643 Res.One &= Known;
644
645 return Res;
646}
647
648KnownBits KnownBits::pext(const KnownBits &Val, const KnownBits &Mask) {
649 unsigned BitWidth = Val.getBitWidth();
650 KnownBits Res(BitWidth);
651 // We start by asserting that bits cannot be 0 and cannot be 1, then clear
652 // Res bits where we know a bit could have some value.
653 Res.setAllConflict();
654
655 // For each source position I where Mask[I] could be 1, the output position J
656 // lies in [M0, M1], where M0 and M1 track the range of possible 1-bit counts
657 // seen so far in Mask. Note that M0=M1 as long as bits in Mask are known;
658 // otherwise, the range of possible output positions widens.
659 unsigned M0 = 0, M1 = 0;
660 for (unsigned I = 0; I < BitWidth; ++I) {
661 if (!Mask.Zero[I]) {
662 // Mask[I] could be 1, so we decide what value the Res bits could have.
663 if (!Val.Zero[I])
664 // Val[I] could be 1 => Res[J] for J in [M0, M1] could be 1
665 Res.Zero.clearBits(LoBit: M0, HiBit: M1 + 1);
666 if (!Val.One[I])
667 // Val[I] could be 0 => Res[J] for J in [M0, M1] could be 0
668 Res.One.clearBits(LoBit: M0, HiBit: M1 + 1);
669 }
670 if (Mask.One[I])
671 ++M0, ++M1;
672 else if (!Mask.Zero[I])
673 ++M1;
674 }
675
676 // Output bits at J >= M0 may have no source (popcount(Mask) may be <= J), so
677 // they may be 0.
678 Res.One.clearBits(LoBit: M0, HiBit: BitWidth);
679 return Res;
680}
681
682KnownBits KnownBits::pdep(const KnownBits &Val, const KnownBits &Mask) {
683 unsigned BitWidth = Val.getBitWidth();
684 KnownBits Res(BitWidth);
685 // We start by asserting that bits cannot be 0 and cannot be 1, then clear
686 // Res bits where we know a bit could have some value.
687 Res.setAllConflict();
688
689 // For each output position I where Mask[I] could be 1, the source position J
690 // lies in [M0, M1], where M0 and M1 track the range of possible 1-bit counts
691 // seen so far in Mask. Note that M0=M1 as long as bits in Mask are known;
692 // otherwise, the range of possible source positions widens.
693 unsigned M0 = 0, M1 = 0;
694 for (unsigned I = 0; I < BitWidth; ++I) {
695 if (!Mask.One[I])
696 // Mask[I] could be 0 => Res[I] could be 0
697 Res.One.clearBit(BitPosition: I);
698 if (!Mask.Zero[I]) {
699 // Mask[I] could be 1, so we check what value the Val bits could have.
700 APInt Range = APInt::getBitsSet(numBits: BitWidth, loBit: M0, hiBit: M1 + 1);
701 if (!Range.isSubsetOf(RHS: Val.One))
702 // Any Val[J] for J in [M0, M1] could be 0 => Res[I] could be 0
703 Res.One.clearBit(BitPosition: I);
704 if (!Range.isSubsetOf(RHS: Val.Zero))
705 // Any Val[J] for J in [M0, M1] could be 1 => Res[I] could be 1
706 Res.Zero.clearBit(BitPosition: I);
707 }
708 if (Mask.One[I])
709 ++M0, ++M1;
710 else if (!Mask.Zero[I])
711 ++M1;
712 }
713 return Res;
714}
715
716std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
717 if (LHS.isConstant() && RHS.isConstant())
718 return LHS.getConstant() == RHS.getConstant();
719 if (LHS.One.intersects(RHS: RHS.Zero) || RHS.One.intersects(RHS: LHS.Zero))
720 return false;
721 return std::nullopt;
722}
723
724std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
725 if (std::optional<bool> KnownEQ = eq(LHS, RHS))
726 return !*KnownEQ;
727 return std::nullopt;
728}
729
730std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
731 // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
732 if (LHS.getMaxValue().ule(RHS: RHS.getMinValue()))
733 return false;
734 // LHS >u RHS -> true if umin(LHS) > umax(RHS)
735 if (LHS.getMinValue().ugt(RHS: RHS.getMaxValue()))
736 return true;
737 return std::nullopt;
738}
739
740std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
741 if (std::optional<bool> IsUGT = ugt(LHS: RHS, RHS: LHS))
742 return !*IsUGT;
743 return std::nullopt;
744}
745
746std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
747 return ugt(LHS: RHS, RHS: LHS);
748}
749
750std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
751 return uge(LHS: RHS, RHS: LHS);
752}
753
754std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
755 // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
756 if (LHS.getSignedMaxValue().sle(RHS: RHS.getSignedMinValue()))
757 return false;
758 // LHS >s RHS -> true if smin(LHS) > smax(RHS)
759 if (LHS.getSignedMinValue().sgt(RHS: RHS.getSignedMaxValue()))
760 return true;
761 return std::nullopt;
762}
763
764std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
765 if (std::optional<bool> KnownSGT = sgt(LHS: RHS, RHS: LHS))
766 return !*KnownSGT;
767 return std::nullopt;
768}
769
770std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
771 return sgt(LHS: RHS, RHS: LHS);
772}
773
774std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
775 return sge(LHS: RHS, RHS: LHS);
776}
777
778KnownBits KnownBits::abs(bool IntMinIsPoison) const {
779 // If the source's MSB is zero then we know the rest of the bits already.
780 if (isNonNegative())
781 return *this;
782
783 // Absolute value preserves trailing zero count.
784 KnownBits KnownAbs(getBitWidth());
785
786 // If the input is negative, then abs(x) == -x.
787 if (isNegative()) {
788 KnownBits Tmp = *this;
789 // Special case for IntMinIsPoison. We know the sign bit is set and we know
790 // all the rest of the bits except one to be zero. Since we have
791 // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
792 // INT_MIN.
793 if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
794 Tmp.One.setBit(countMinTrailingZeros());
795
796 KnownAbs = computeForAddSub(
797 /*Add*/ false, NSW: IntMinIsPoison, /*NUW=*/false,
798 LHS: KnownBits::makeConstant(C: APInt(getBitWidth(), 0)), RHS: Tmp);
799
800 // One more special case for IntMinIsPoison. If we don't know any ones other
801 // than the signbit, we know for certain that all the unknowns can't be
802 // zero. So if we know high zero bits, but have unknown low bits, we know
803 // for certain those high-zero bits will end up as one. This is because,
804 // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
805 // to the high bits. If we know a known INT_MIN input skip this. The result
806 // is poison anyways.
807 if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
808 Tmp.countMaxPopulation() != 1) {
809 Tmp.One.clearSignBit();
810 Tmp.Zero.setSignBit();
811 KnownAbs.One.setBits(loBit: getBitWidth() - Tmp.countMinLeadingZeros(),
812 hiBit: getBitWidth() - 1);
813 }
814
815 } else {
816 unsigned MaxTZ = countMaxTrailingZeros();
817 unsigned MinTZ = countMinTrailingZeros();
818
819 KnownAbs.Zero.setLowBits(MinTZ);
820 // If we know the lowest set 1, then preserve it.
821 if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
822 KnownAbs.One.setBit(MaxTZ);
823
824 // We only know that the absolute values's MSB will be zero if INT_MIN is
825 // poison, or there is a set bit that isn't the sign bit (otherwise it could
826 // be INT_MIN).
827 if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
828 KnownAbs.One.clearSignBit();
829 KnownAbs.Zero.setSignBit();
830 }
831 }
832
833 return KnownAbs;
834}
835
836KnownBits KnownBits::reduceAdd(unsigned NumElts) const {
837 if (NumElts == 0)
838 return KnownBits(getBitWidth());
839
840 unsigned BitWidth = getBitWidth();
841 KnownBits Result(BitWidth);
842
843 if (isConstant())
844 // If all elements are the same constant, we can simply compute it
845 return KnownBits::makeConstant(C: NumElts * getConstant());
846
847 // The main idea is as follows.
848 //
849 // If KnownBits for each element has L leading zeros then
850 // X_i < 2^(W - L) for every i from [1, N].
851 //
852 // ADD X_i <= ADD max(X_i) = N * max(X_i)
853 // < N * 2^(W - L)
854 // < 2^(W - L + ceil(log2(N)))
855 //
856 // As the result, we can conclude that
857 //
858 // L' = L - ceil(log2(N))
859 //
860 // Similar logic can be applied to leading ones.
861 unsigned LostBits = Log2_32_Ceil(Value: NumElts);
862
863 if (isNonNegative()) {
864 unsigned LeadingZeros = countMinLeadingZeros();
865 LeadingZeros = LeadingZeros > LostBits ? LeadingZeros - LostBits : 0;
866 Result.Zero.setHighBits(LeadingZeros);
867 } else if (isNegative()) {
868 unsigned LeadingOnes = countMinLeadingOnes();
869 LeadingOnes = LeadingOnes > LostBits ? LeadingOnes - LostBits : 0;
870 Result.One.setHighBits(LeadingOnes);
871 }
872
873 return Result;
874}
875
876static KnownBits computeForSatAddSub(bool Add, bool Signed,
877 const KnownBits &LHS,
878 const KnownBits &RHS) {
879 // We don't see NSW even for sadd/ssub as we want to check if the result has
880 // signed overflow.
881 unsigned BitWidth = LHS.getBitWidth();
882
883 std::optional<bool> Overflow;
884 // Even if we can't entirely rule out overflow, we may be able to rule out
885 // overflow in one direction. This allows us to potentially keep some of the
886 // add/sub bits. I.e if we can't overflow in the positive direction we won't
887 // clamp to INT_MAX so we can keep low 0s from the add/sub result.
888 bool MayNegClamp = true;
889 bool MayPosClamp = true;
890 if (Signed) {
891 // Easy cases we can rule out any overflow.
892 if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
893 (LHS.isNonNegative() && RHS.isNegative())))
894 Overflow = false;
895 else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
896 (LHS.isNonNegative() && RHS.isNonNegative()))))
897 Overflow = false;
898 else {
899 // Check if we may overflow. If we can't rule out overflow then check if
900 // we can rule out a direction at least.
901 KnownBits UnsignedLHS = LHS;
902 KnownBits UnsignedRHS = RHS;
903 // Get version of LHS/RHS with clearer signbit. This allows us to detect
904 // how the addition/subtraction might overflow into the signbit. Then
905 // using the actual known signbits of LHS/RHS, we can figure out which
906 // overflows are/aren't possible.
907 UnsignedLHS.One.clearSignBit();
908 UnsignedLHS.Zero.setSignBit();
909 UnsignedRHS.One.clearSignBit();
910 UnsignedRHS.Zero.setSignBit();
911 KnownBits Res =
912 KnownBits::computeForAddSub(Add, /*NSW=*/false,
913 /*NUW=*/false, LHS: UnsignedLHS, RHS: UnsignedRHS);
914 if (Add) {
915 if (Res.isNegative()) {
916 // Only overflow scenario is Pos + Pos.
917 MayNegClamp = false;
918 // Pos + Pos will overflow with extra signbit.
919 if (LHS.isNonNegative() && RHS.isNonNegative())
920 Overflow = true;
921 } else if (Res.isNonNegative()) {
922 // Only overflow scenario is Neg + Neg
923 MayPosClamp = false;
924 // Neg + Neg will overflow without extra signbit.
925 if (LHS.isNegative() && RHS.isNegative())
926 Overflow = true;
927 }
928 // We will never clamp to the opposite sign of N-bit result.
929 if (LHS.isNegative() || RHS.isNegative())
930 MayPosClamp = false;
931 if (LHS.isNonNegative() || RHS.isNonNegative())
932 MayNegClamp = false;
933 } else {
934 if (Res.isNegative()) {
935 // Only overflow scenario is Neg - Pos.
936 MayPosClamp = false;
937 // Neg - Pos will overflow with extra signbit.
938 if (LHS.isNegative() && RHS.isNonNegative())
939 Overflow = true;
940 } else if (Res.isNonNegative()) {
941 // Only overflow scenario is Pos - Neg.
942 MayNegClamp = false;
943 // Pos - Neg will overflow without extra signbit.
944 if (LHS.isNonNegative() && RHS.isNegative())
945 Overflow = true;
946 }
947 // We will never clamp to the opposite sign of N-bit result.
948 if (LHS.isNegative() || RHS.isNonNegative())
949 MayPosClamp = false;
950 if (LHS.isNonNegative() || RHS.isNegative())
951 MayNegClamp = false;
952 }
953 }
954 // If we have ruled out all clamping, we will never overflow.
955 if (!MayNegClamp && !MayPosClamp)
956 Overflow = false;
957 } else if (Add) {
958 // uadd.sat
959 bool Of;
960 (void)LHS.getMaxValue().uadd_ov(RHS: RHS.getMaxValue(), Overflow&: Of);
961 if (!Of) {
962 Overflow = false;
963 } else {
964 (void)LHS.getMinValue().uadd_ov(RHS: RHS.getMinValue(), Overflow&: Of);
965 if (Of)
966 Overflow = true;
967 }
968 } else {
969 // usub.sat
970 bool Of;
971 (void)LHS.getMinValue().usub_ov(RHS: RHS.getMaxValue(), Overflow&: Of);
972 if (!Of) {
973 Overflow = false;
974 } else {
975 (void)LHS.getMaxValue().usub_ov(RHS: RHS.getMinValue(), Overflow&: Of);
976 if (Of)
977 Overflow = true;
978 }
979 }
980
981 KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
982 /*NUW=*/!Signed, LHS, RHS);
983
984 if (Overflow) {
985 // We know whether or not we overflowed.
986 if (!(*Overflow)) {
987 // No overflow.
988 return Res;
989 }
990
991 // We overflowed
992 APInt C;
993 if (Signed) {
994 // sadd.sat / ssub.sat
995 assert(!LHS.isSignUnknown() &&
996 "We somehow know overflow without knowing input sign");
997 C = LHS.isNegative() ? APInt::getSignedMinValue(numBits: BitWidth)
998 : APInt::getSignedMaxValue(numBits: BitWidth);
999 } else if (Add) {
1000 // uadd.sat
1001 C = APInt::getMaxValue(numBits: BitWidth);
1002 } else {
1003 // uadd.sat
1004 C = APInt::getMinValue(numBits: BitWidth);
1005 }
1006
1007 Res.One = C;
1008 Res.Zero = ~C;
1009 return Res;
1010 }
1011
1012 // We don't know if we overflowed.
1013 if (Signed) {
1014 // sadd.sat/ssub.sat
1015 // We can keep our information about the sign bits.
1016 if (MayPosClamp)
1017 Res.Zero.clearLowBits(loBits: BitWidth - 1);
1018 if (MayNegClamp)
1019 Res.One.clearLowBits(loBits: BitWidth - 1);
1020 } else if (Add) {
1021 // uadd.sat
1022 // We need to clear all the known zeros as we can only use the leading ones.
1023 Res.Zero.clearAllBits();
1024 } else {
1025 // usub.sat
1026 // We need to clear all the known ones as we can only use the leading zero.
1027 Res.One.clearAllBits();
1028 }
1029
1030 return Res;
1031}
1032
1033KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
1034 return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
1035}
1036KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
1037 return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
1038}
1039KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
1040 return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
1041}
1042KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
1043 return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
1044}
1045
1046static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
1047 unsigned BitWidth = LHS.getBitWidth();
1048 LHS = LHS.zext(BitWidth: BitWidth + 1);
1049 RHS = RHS.zext(BitWidth: BitWidth + 1);
1050 LHS =
1051 computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
1052 LHS = LHS.extractBits(NumBits: BitWidth, BitPosition: 1);
1053 return LHS;
1054}
1055
1056KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
1057 return flipSignBit(Val: avgFloorU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
1058}
1059
1060KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
1061 return avgComputeU(LHS, RHS, /*IsCeil=*/false);
1062}
1063
1064KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
1065 return flipSignBit(Val: avgCeilU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
1066}
1067
1068KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
1069 return avgComputeU(LHS, RHS, /*IsCeil=*/true);
1070}
1071
1072KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
1073 bool NoUndefSelfMultiply) {
1074 unsigned BitWidth = LHS.getBitWidth();
1075 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
1076 assert((!NoUndefSelfMultiply || LHS == RHS) &&
1077 "Self multiplication knownbits mismatch");
1078
1079 // Compute the high known-0 bits by multiplying the unsigned max of each side.
1080 // Conservatively, M active bits * N active bits results in M + N bits in the
1081 // result. But if we know a value is a power-of-2 for example, then this
1082 // computes one more leading zero.
1083 // TODO: This could be generalized to number of sign bits (negative numbers).
1084 APInt UMaxLHS = LHS.getMaxValue();
1085 APInt UMaxRHS = RHS.getMaxValue();
1086
1087 // For leading zeros in the result to be valid, the unsigned max product must
1088 // fit in the bitwidth (it must not overflow).
1089 bool HasOverflow;
1090 APInt UMaxResult = UMaxLHS.umul_ov(RHS: UMaxRHS, Overflow&: HasOverflow);
1091 unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
1092
1093 // The result of the bottom bits of an integer multiply can be
1094 // inferred by looking at the bottom bits of both operands and
1095 // multiplying them together.
1096 // We can infer at least the minimum number of known trailing bits
1097 // of both operands. Depending on number of trailing zeros, we can
1098 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
1099 // a and b are divisible by m and n respectively.
1100 // We then calculate how many of those bits are inferrable and set
1101 // the output. For example, the i8 mul:
1102 // a = XXXX1100 (12)
1103 // b = XXXX1110 (14)
1104 // We know the bottom 3 bits are zero since the first can be divided by
1105 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
1106 // Applying the multiplication to the trimmed arguments gets:
1107 // XX11 (3)
1108 // X111 (7)
1109 // -------
1110 // XX11
1111 // XX11
1112 // XX11
1113 // XX11
1114 // -------
1115 // XXXXX01
1116 // Which allows us to infer the 2 LSBs. Since we're multiplying the result
1117 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
1118 // The proof for this can be described as:
1119 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
1120 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
1121 // umin(countTrailingZeros(C2), C6) +
1122 // umin(C5 - umin(countTrailingZeros(C1), C5),
1123 // C6 - umin(countTrailingZeros(C2), C6)))) - 1)
1124 // %aa = shl i8 %a, C5
1125 // %bb = shl i8 %b, C6
1126 // %aaa = or i8 %aa, C1
1127 // %bbb = or i8 %bb, C2
1128 // %mul = mul i8 %aaa, %bbb
1129 // %mask = and i8 %mul, C7
1130 // =>
1131 // %mask = i8 ((C1*C2)&C7)
1132 // Where C5, C6 describe the known bits of %a, %b
1133 // C1, C2 describe the known bottom bits of %a, %b.
1134 // C7 describes the mask of the known bits of the result.
1135
1136 // How many times we'd be able to divide each argument by 2 (shr by 1).
1137 // This gives us the number of trailing zeros on the multiplication result.
1138 unsigned TrailBitsKnownLHS = (LHS.Zero | LHS.One).countr_one();
1139 unsigned TrailBitsKnownRHS = (RHS.Zero | RHS.One).countr_one();
1140 unsigned TrailZeroLHS = LHS.countMinTrailingZeros();
1141 unsigned TrailZeroRHS = RHS.countMinTrailingZeros();
1142 unsigned TrailZ = TrailZeroLHS + TrailZeroRHS;
1143
1144 // Figure out the fewest known-bits operand.
1145 unsigned SmallestOperand = std::min(a: TrailBitsKnownLHS - TrailZeroLHS,
1146 b: TrailBitsKnownRHS - TrailZeroRHS);
1147 unsigned ResultBitsKnown = std::min(a: SmallestOperand + TrailZ, b: BitWidth);
1148
1149 // The lower ResultBitsKnown bits of this are known.
1150 APInt BottomKnown = LHS.One * RHS.One;
1151
1152 KnownBits Res(BitWidth);
1153 Res.Zero.setHighBits(LeadZ);
1154 Res.Zero |= (~BottomKnown).getLoBits(numBits: ResultBitsKnown);
1155 Res.One = BottomKnown.getLoBits(numBits: ResultBitsKnown);
1156
1157 if (NoUndefSelfMultiply) {
1158 // If X has at least TZ trailing zeroes, then bit (2 * TZ + 1) must be zero.
1159 unsigned TwoTZP1 = 2 * TrailZeroLHS + 1;
1160 if (TwoTZP1 < BitWidth)
1161 Res.Zero.setBit(TwoTZP1);
1162
1163 // If X has exactly TZ trailing zeros, then bit (2 * TZ + 2) must also be
1164 // zero.
1165 if (TrailZeroLHS < BitWidth && LHS.One[TrailZeroLHS]) {
1166 unsigned TwoTZP2 = TwoTZP1 + 1;
1167 if (TwoTZP2 < BitWidth)
1168 Res.Zero.setBit(TwoTZP2);
1169 }
1170 }
1171
1172 return Res;
1173}
1174
1175KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
1176 unsigned BitWidth = LHS.getBitWidth();
1177 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
1178 KnownBits WideLHS = LHS.sext(BitWidth: 2 * BitWidth);
1179 KnownBits WideRHS = RHS.sext(BitWidth: 2 * BitWidth);
1180 return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth);
1181}
1182
1183KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
1184 unsigned BitWidth = LHS.getBitWidth();
1185 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
1186 KnownBits WideLHS = LHS.zext(BitWidth: 2 * BitWidth);
1187 KnownBits WideRHS = RHS.zext(BitWidth: 2 * BitWidth);
1188 return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth);
1189}
1190
1191static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
1192 const KnownBits &RHS, bool Exact) {
1193
1194 if (!Exact)
1195 return Known;
1196
1197 // If LHS is Odd, the result is Odd no matter what.
1198 // Odd / Odd -> Odd
1199 // Odd / Even -> Impossible (because its exact division)
1200 if (LHS.One[0])
1201 Known.One.setBit(0);
1202
1203 int MinTZ =
1204 (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
1205 int MaxTZ =
1206 (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
1207 if (MinTZ >= 0) {
1208 // Result has at least MinTZ trailing zeros.
1209 Known.Zero.setLowBits(MinTZ);
1210 if (MinTZ == MaxTZ) {
1211 // Result has exactly MinTZ trailing zeros.
1212 Known.One.setBit(MinTZ);
1213 }
1214 } else if (MaxTZ < 0) {
1215 // Poison Result
1216 Known.setAllZero();
1217 }
1218
1219 // In the KnownBits exhaustive tests, we have poison inputs for exact values
1220 // a LOT. If we have a conflict, just return all zeros.
1221 if (Known.hasConflict())
1222 Known.setAllZero();
1223
1224 return Known;
1225}
1226
1227KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
1228 bool Exact) {
1229 // Equivalent of `udiv`. We must have caught this before it was folded.
1230 if (LHS.isNonNegative() && RHS.isNonNegative())
1231 return udiv(LHS, RHS, Exact);
1232
1233 unsigned BitWidth = LHS.getBitWidth();
1234 KnownBits Known(BitWidth);
1235
1236 if (LHS.isZero() || RHS.isZero()) {
1237 // Result is either known Zero or UB. Return Zero either way.
1238 // Checking this earlier saves us a lot of special cases later on.
1239 Known.setAllZero();
1240 return Known;
1241 }
1242
1243 std::optional<APInt> Res;
1244 if (LHS.isNegative() && RHS.isNegative()) {
1245 // Result non-negative.
1246 APInt Denom = RHS.getSignedMaxValue();
1247 APInt Num = LHS.getSignedMinValue();
1248 // INT_MIN/-1 would be a poison result (impossible). Estimate the division
1249 // as signed max (we will only set sign bit in the result).
1250 Res = (Num.isMinSignedValue() && Denom.isAllOnes())
1251 ? APInt::getSignedMaxValue(numBits: BitWidth)
1252 : Num.sdiv(RHS: Denom);
1253 } else if (LHS.isNegative() && RHS.isNonNegative()) {
1254 // Result is negative if Exact OR -LHS u>= RHS.
1255 if (Exact || (-LHS.getSignedMaxValue()).uge(RHS: RHS.getSignedMaxValue())) {
1256 APInt Denom = RHS.getSignedMinValue();
1257 APInt Num = LHS.getSignedMinValue();
1258 Res = Denom.isZero() ? Num : Num.sdiv(RHS: Denom);
1259 }
1260 } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
1261 // Result is negative if Exact OR LHS u>= -RHS.
1262 if (Exact || LHS.getSignedMinValue().uge(RHS: -RHS.getSignedMinValue())) {
1263 APInt Denom = RHS.getSignedMaxValue();
1264 APInt Num = LHS.getSignedMaxValue();
1265 Res = Num.sdiv(RHS: Denom);
1266 }
1267 }
1268
1269 if (Res) {
1270 if (Res->isNonNegative()) {
1271 unsigned LeadZ = Res->countLeadingZeros();
1272 Known.Zero.setHighBits(LeadZ);
1273 } else {
1274 unsigned LeadO = Res->countLeadingOnes();
1275 Known.One.setHighBits(LeadO);
1276 }
1277 }
1278
1279 Known = divComputeLowBit(Known, LHS, RHS, Exact);
1280 return Known;
1281}
1282
1283KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
1284 bool Exact) {
1285 unsigned BitWidth = LHS.getBitWidth();
1286 KnownBits Known(BitWidth);
1287
1288 if (LHS.isZero() || RHS.isZero()) {
1289 // Result is either known Zero or UB. Return Zero either way.
1290 // Checking this earlier saves us a lot of special cases later on.
1291 Known.setAllZero();
1292 return Known;
1293 }
1294
1295 // We can figure out the minimum number of upper zero bits by doing
1296 // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
1297 // gets larger, the number of upper zero bits increases.
1298 APInt MinDenom = RHS.getMinValue();
1299 APInt MaxNum = LHS.getMaxValue();
1300 APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(RHS: MinDenom);
1301
1302 unsigned LeadZ = MaxRes.countLeadingZeros();
1303
1304 Known.Zero.setHighBits(LeadZ);
1305 Known = divComputeLowBit(Known, LHS, RHS, Exact);
1306
1307 return Known;
1308}
1309
1310KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
1311 unsigned BitWidth = LHS.getBitWidth();
1312 if (!RHS.isZero() && RHS.Zero[0]) {
1313 // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result.
1314 unsigned RHSZeros = RHS.countMinTrailingZeros();
1315 APInt Mask = APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: RHSZeros);
1316 APInt OnesMask = LHS.One & Mask;
1317 APInt ZerosMask = LHS.Zero & Mask;
1318 return KnownBits(ZerosMask, OnesMask);
1319 }
1320 return KnownBits(BitWidth);
1321}
1322
1323KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
1324 KnownBits Known = remGetLowBits(LHS, RHS);
1325 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1326 // NB: Low bits set in `remGetLowBits`.
1327 APInt HighBits = ~(RHS.getConstant() - 1);
1328 Known.Zero |= std::move(HighBits);
1329 return Known;
1330 }
1331
1332 // Since the result is less than or equal to either operand, any leading
1333 // zero bits in either operand must also exist in the result.
1334 uint32_t Leaders =
1335 std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinLeadingZeros());
1336 Known.Zero.setHighBits(Leaders);
1337 return Known;
1338}
1339
1340KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
1341 KnownBits Known = remGetLowBits(LHS, RHS);
1342 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1343 // NB: Low bits are set in `remGetLowBits`.
1344 APInt LowBits = RHS.getConstant() - 1;
1345 // If the first operand is non-negative or has all low bits zero, then
1346 // the upper bits are all zero.
1347 if (LHS.isNonNegative() || LowBits.isSubsetOf(RHS: LHS.Zero))
1348 Known.Zero |= ~LowBits;
1349
1350 // If the first operand is negative and not all low bits are zero, then
1351 // the upper bits are all one.
1352 if (LHS.isNegative() && LowBits.intersects(RHS: LHS.One))
1353 Known.One |= ~LowBits;
1354 return Known;
1355 }
1356
1357 // The sign bit is the LHS's sign bit, except when the result of the
1358 // remainder is zero. The magnitude of the result should be less than or
1359 // equal to the magnitude of either operand.
1360 if (LHS.isNegative() && Known.isNonZero())
1361 Known.One.setHighBits(
1362 std::max(a: LHS.countMinLeadingOnes(), b: RHS.countMinSignBits()));
1363 else if (LHS.isNonNegative())
1364 Known.Zero.setHighBits(
1365 std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinSignBits()));
1366 return Known;
1367}
1368
1369KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
1370 // Result bit is 0 if either operand bit is 0.
1371 Zero |= RHS.Zero;
1372 // Result bit is 1 if both operand bits are 1.
1373 One &= RHS.One;
1374 return *this;
1375}
1376
1377KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
1378 // Result bit is 0 if both operand bits are 0.
1379 Zero &= RHS.Zero;
1380 // Result bit is 1 if either operand bit is 1.
1381 One |= RHS.One;
1382 return *this;
1383}
1384
1385KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
1386 // Result bit is 0 if both operand bits are 0 or both are 1.
1387 APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
1388 // Result bit is 1 if one operand bit is 0 and the other is 1.
1389 One = (Zero & RHS.One) | (One & RHS.Zero);
1390 Zero = std::move(Z);
1391 return *this;
1392}
1393
1394KnownBits KnownBits::blsi() const {
1395 unsigned BitWidth = getBitWidth();
1396 KnownBits Known(Zero, APInt(BitWidth, 0));
1397 unsigned Max = countMaxTrailingZeros();
1398 Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth));
1399 unsigned Min = countMinTrailingZeros();
1400 if (Max == Min && Max < BitWidth)
1401 Known.One.setBit(Max);
1402 return Known;
1403}
1404
1405KnownBits KnownBits::blsmsk() const {
1406 unsigned BitWidth = getBitWidth();
1407 KnownBits Known(BitWidth);
1408 unsigned Max = countMaxTrailingZeros();
1409 Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth));
1410 unsigned Min = countMinTrailingZeros();
1411 Known.One.setLowBits(std::min(a: Min + 1, b: BitWidth));
1412 return Known;
1413}
1414
1415void KnownBits::print(raw_ostream &OS) const {
1416 unsigned BitWidth = getBitWidth();
1417 for (unsigned I = 0; I < BitWidth; ++I) {
1418 unsigned N = BitWidth - I - 1;
1419 if (Zero[N] && One[N])
1420 OS << "!";
1421 else if (Zero[N])
1422 OS << "0";
1423 else if (One[N])
1424 OS << "1";
1425 else
1426 OS << "?";
1427 }
1428}
1429
1430#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1431LLVM_DUMP_METHOD void KnownBits::dump() const {
1432 print(dbgs());
1433 dbgs() << "\n";
1434}
1435#endif
1436