1//===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a class for representing known zeros and ones used by
10// computeKnownBits.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Support/KnownBits.h"
15#include "llvm/ADT/Sequence.h"
16#include "llvm/Support/Debug.h"
17#include "llvm/Support/raw_ostream.h"
18#include <cassert>
19
20using namespace llvm;
21
22KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
23 unsigned SignBitPosition = Val.getBitWidth() - 1;
24 APInt Zero = Val.Zero;
25 APInt One = Val.One;
26 Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]);
27 One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]);
28 return KnownBits(Zero, One);
29}
30
31static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
32 bool CarryZero, bool CarryOne) {
33
34 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
35 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
36
37 // Compute known bits of the carry.
38 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
39 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
40
41 // Compute set of known bits (where all three relevant bits are known).
42 APInt LHSKnownUnion = LHS.Zero | LHS.One;
43 APInt RHSKnownUnion = RHS.Zero | RHS.One;
44 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
45 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
46
47 // Compute known bits of the result.
48 KnownBits KnownOut;
49 KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
50 KnownOut.One = std::move(PossibleSumOne) & Known;
51 return KnownOut;
52}
53
54KnownBits KnownBits::computeForAddCarry(
55 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
56 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
57 return ::computeForAddCarry(
58 LHS, RHS, CarryZero: Carry.Zero.getBoolValue(), CarryOne: Carry.One.getBoolValue());
59}
60
/// Known bits of `LHS + RHS` (when \p Add) or `LHS - RHS` (otherwise).
/// \p NSW / \p NUW state that the operation does not wrap in the signed /
/// unsigned sense; they are used to derive additional known high bits.
/// If the flags contradict the computed bits (the result would be poison),
/// an all-zero result is returned instead of a conflicting one.
KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
                                      const KnownBits &LHS,
                                      const KnownBits &RHS) {
  unsigned BitWidth = LHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  // This can be a relatively expensive helper, so optimistically save some
  // work.
  if (LHS.isUnknown() && RHS.isUnknown())
    return KnownOut;

  // The carry-based bitwise computation only runs when both operands have at
  // least one known bit; otherwise KnownOut stays unknown and only the flag
  // handling below can contribute.
  if (!LHS.isUnknown() && !RHS.isUnknown()) {
    if (Add) {
      // Sum = LHS + RHS + 0
      KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
                                      /*CarryOne=*/false);
    } else {
      // Sum = LHS + ~RHS + 1
      // ~RHS is formed by swapping RHS's known-zero and known-one masks.
      KnownBits NotRHS = RHS;
      std::swap(a&: NotRHS.Zero, b&: NotRHS.One);
      KnownOut = ::computeForAddCarry(LHS, RHS: NotRHS, /*CarryZero=*/false,
                                      /*CarryOne=*/true);
    }
  }

  // Handle add/sub given nsw and/or nuw.
  if (NUW) {
    if (Add) {
      // (add nuw X, Y)
      APInt MinVal = LHS.getMinValue().uadd_sat(RHS: RHS.getMinValue());
      // None of the adds can end up overflowing, so min consecutive highbits
      // in minimum possible of X + Y must all remain set.
      if (NSW) {
        unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one();
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
      }
      KnownOut.One.setHighBits(MinVal.countl_one());
    } else {
      // (sub nuw X, Y)
      APInt MaxVal = LHS.getMaxValue().usub_sat(RHS: RHS.getMinValue());
      // None of the subs can overflow at any point, so any common high bits
      // will subtract away and result in zeros.
      if (NSW) {
        // If we have NSW as well, we also know we can't overflow the signbit so
        // can start counting from 1 bit back.
        unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero();
        KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
      }
      KnownOut.Zero.setHighBits(MaxVal.countl_zero());
    }
  }

  if (NSW) {
    // Saturating arithmetic gives the signed range of the non-wrapping result.
    APInt MinVal;
    APInt MaxVal;
    if (Add) {
      // (add nsw X, Y)
      MinVal = LHS.getSignedMinValue().sadd_sat(RHS: RHS.getSignedMinValue());
      MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS: RHS.getSignedMaxValue());
    } else {
      // (sub nsw X, Y)
      MinVal = LHS.getSignedMinValue().ssub_sat(RHS: RHS.getSignedMaxValue());
      MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS: RHS.getSignedMinValue());
    }
    if (MinVal.isNonNegative()) {
      // If min is non-negative, result will always be non-neg (can't overflow
      // around).
      unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one();
      KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
      KnownOut.Zero.setSignBit();
    }
    if (MaxVal.isNegative()) {
      // If max is negative, result will always be neg (can't overflow around).
      unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero();
      KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
      KnownOut.One.setSignBit();
    }
  }

  // Just return 0 if the nsw/nuw is violated and we have poison.
  if (KnownOut.hasConflict())
    KnownOut.setAllZero();
  return KnownOut;
}
146
147KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
148 const KnownBits &Borrow) {
149 assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");
150
151 // LHS - RHS = LHS + ~RHS + 1
152 // Carry 1 - Borrow in ::computeForAddCarry
153 std::swap(a&: RHS.Zero, b&: RHS.One);
154 return ::computeForAddCarry(LHS, RHS,
155 /*CarryZero=*/Borrow.One.getBoolValue(),
156 /*CarryOne=*/Borrow.Zero.getBoolValue());
157}
158
159KnownBits KnownBits::truncSSat(unsigned BitWidth) const {
160 unsigned InputBits = getBitWidth();
161 APInt MinInRange = APInt::getSignedMinValue(numBits: BitWidth).sext(width: InputBits);
162 APInt MaxInRange = APInt::getSignedMaxValue(numBits: BitWidth).sext(width: InputBits);
163 APInt InputMin = getSignedMinValue();
164 APInt InputMax = getSignedMaxValue();
165 KnownBits Known(BitWidth);
166
167 // Case 1: All values fit - just truncate
168 if (InputMin.sge(RHS: MinInRange) && InputMax.sle(RHS: MaxInRange)) {
169 Known = trunc(BitWidth);
170 }
171 // Case 2: All saturate to min
172 else if (InputMax.slt(RHS: MinInRange)) {
173 Known = KnownBits::makeConstant(C: APInt::getSignedMinValue(numBits: BitWidth));
174 }
175 // Case 3: All saturate to max
176 else if (InputMin.sgt(RHS: MaxInRange)) {
177 Known = KnownBits::makeConstant(C: APInt::getSignedMaxValue(numBits: BitWidth));
178 }
179 // Case 4: All non-negative, some fit, some saturate to max
180 else if (InputMin.isNonNegative()) {
181 // Output: truncated OR max saturation
182 // Max saturation has only sign bit as 0
183 Known.Zero = APInt(BitWidth, 0);
184 Known.Zero.setBit(BitWidth - 1); // Sign bit always 0
185 // Max saturation has all lower bits as 1, so preserve InputOneLower
186 Known.One = One.trunc(width: BitWidth);
187 Known.One.clearBit(BitPosition: BitWidth - 1); // Sign bit is 0, not 1
188 }
189 // Case 5: All negative, some fit, some saturate to min
190 else if (InputMax.isNegative()) {
191 // Output: truncated OR min saturation
192 // Min saturation has all lower bits as 0, so preserve InputZeroLower
193 Known.Zero = Zero.trunc(width: BitWidth);
194 Known.Zero.clearBit(BitPosition: BitWidth - 1); // Sign bit is 1, not 0
195 // Min saturation has only sign bit as 1
196 Known.One = APInt(BitWidth, 0);
197 Known.One.setBit(BitWidth - 1); // Sign bit always 1
198 }
199 // Case 6: Mixed positive and negative
200 else {
201 // Output: min saturation, truncated, or max saturation
202 APInt InputZeroLower = Zero.trunc(width: BitWidth);
203 APInt InputOneLower = One.trunc(width: BitWidth);
204 APInt MinSat = APInt::getSignedMinValue(numBits: BitWidth);
205 APInt MaxSat = APInt::getSignedMaxValue(numBits: BitWidth);
206 KnownBits MinSatKB = KnownBits::makeConstant(C: MinSat);
207 KnownBits MaxSatKB = KnownBits::makeConstant(C: MaxSat);
208 if (InputMax.sle(RHS: MaxInRange)) {
209 // Positive values fit, only negatives saturate to min
210 Known.Zero = InputZeroLower & MinSatKB.Zero;
211 Known.One = InputOneLower & MinSatKB.One;
212 } else if (InputMin.sge(RHS: MinInRange)) {
213 // Negative values fit, only positives saturate to max
214 Known.Zero = InputZeroLower & MaxSatKB.Zero;
215 Known.One = InputOneLower & MaxSatKB.One;
216 } else {
217 // Both positive and negative values might saturate
218 Known.Zero = InputZeroLower & MinSatKB.Zero & MaxSatKB.Zero;
219 Known.One = InputOneLower & MinSatKB.One & MaxSatKB.One;
220 }
221 }
222 return Known;
223}
224
225KnownBits KnownBits::truncSSatU(unsigned BitWidth) const {
226 unsigned InputBits = getBitWidth();
227 APInt MaxInRange = APInt::getAllOnes(numBits: BitWidth).zext(width: InputBits);
228 APInt InputMin = getSignedMinValue();
229 APInt InputMax = getSignedMaxValue();
230 KnownBits Known(BitWidth);
231
232 if (isNegative()) {
233 Known.setAllZero();
234 } else if (InputMin.isNonNegative() && InputMax.ule(RHS: MaxInRange)) {
235 Known = trunc(BitWidth);
236 } else if (InputMin.isNonNegative() && InputMin.ugt(RHS: MaxInRange)) {
237 Known.setAllOnes();
238 } else if (InputMin.isNonNegative()) {
239 // All non-negative but mixed: some fit, some saturate to all-ones
240 // No common zero bits (saturation is all-ones)
241 Known.Zero = APInt(BitWidth, 0);
242 // Common one bits: bits that are 1 in input stay 1
243 Known.One = One.trunc(width: BitWidth);
244 } else {
245 // Mixed: sign bit unknown
246 if (InputMax.ule(RHS: MaxInRange)) {
247 // Positive values all fit (no saturation to all-ones)
248 // Output: all-zeros (negatives) OR truncated (positives)
249 Known.Zero = Zero.trunc(width: BitWidth);
250 Known.One = APInt(BitWidth, 0);
251 } else {
252 // Positive values might saturate to all-ones
253 // Output: all-zeros OR truncated OR all-ones
254 // No common bits
255 Known.Zero = APInt(BitWidth, 0);
256 Known.One = APInt(BitWidth, 0);
257 }
258 }
259 return Known;
260}
261
262KnownBits KnownBits::truncUSat(unsigned BitWidth) const {
263 unsigned InputBits = getBitWidth();
264 APInt MaxInRange = APInt::getLowBitsSet(numBits: InputBits, loBitsSet: BitWidth);
265 APInt InputMax = getMaxValue();
266 APInt InputMin = getMinValue();
267 KnownBits Known(BitWidth);
268
269 if (InputMax.ule(RHS: MaxInRange)) {
270 Known = trunc(BitWidth);
271 } else if (InputMin.ugt(RHS: MaxInRange)) {
272 Known.setAllOnes();
273 } else {
274 Known.resetAll();
275 Known.One = One.trunc(width: BitWidth);
276 }
277 return Known;
278}
279
280KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
281 unsigned BitWidth = getBitWidth();
282 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
283 "Illegal sext-in-register");
284
285 if (SrcBitWidth == BitWidth)
286 return *this;
287
288 unsigned ExtBits = BitWidth - SrcBitWidth;
289 KnownBits Result;
290 Result.One = One << ExtBits;
291 Result.Zero = Zero << ExtBits;
292 Result.One.ashrInPlace(ShiftAmt: ExtBits);
293 Result.Zero.ashrInPlace(ShiftAmt: ExtBits);
294 return Result;
295}
296
297KnownBits KnownBits::makeGE(const APInt &Val) const {
298 // Count the number of leading bit positions where our underlying value is
299 // known to be less than or equal to Val.
300 unsigned N = (Zero | Val).countl_one();
301
302 // For each of those bit positions, if Val has a 1 in that bit then our
303 // underlying value must also have a 1.
304 APInt MaskedVal(Val);
305 MaskedVal.clearLowBits(loBits: getBitWidth() - N);
306 return KnownBits(Zero, One | MaskedVal);
307}
308
309KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
310 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
311 // RHS. Ideally our caller would already have spotted these cases and
312 // optimized away the umax operation, but we handle them here for
313 // completeness.
314 if (LHS.getMinValue().uge(RHS: RHS.getMaxValue()))
315 return LHS;
316 if (RHS.getMinValue().uge(RHS: LHS.getMaxValue()))
317 return RHS;
318
319 // If the result of the umax is LHS then it must be greater than or equal to
320 // the minimum possible value of RHS. Likewise for RHS. Any known bits that
321 // are common to these two values are also known in the result.
322 KnownBits L = LHS.makeGE(Val: RHS.getMinValue());
323 KnownBits R = RHS.makeGE(Val: LHS.getMinValue());
324 return L.intersectWith(RHS: R);
325}
326
327KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
328 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
329 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
330 return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS)));
331}
332
333KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
334 return flipSignBit(Val: umax(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
335}
336
337KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
338 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
339 auto Flip = [](const KnownBits &Val) {
340 unsigned SignBitPosition = Val.getBitWidth() - 1;
341 APInt Zero = Val.One;
342 APInt One = Val.Zero;
343 Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]);
344 One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]);
345 return KnownBits(Zero, One);
346 };
347 return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS)));
348}
349
350KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
351 // If we know which argument is larger, return (sub LHS, RHS) or
352 // (sub RHS, LHS) directly.
353 if (LHS.getMinValue().uge(RHS: RHS.getMaxValue()))
354 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
355 RHS);
356 if (RHS.getMinValue().uge(RHS: LHS.getMaxValue()))
357 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS,
358 RHS: LHS);
359
360 // By construction, the subtraction in abdu never has unsigned overflow.
361 // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
362 KnownBits Diff0 =
363 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
364 KnownBits Diff1 =
365 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS);
366 return Diff0.intersectWith(RHS: Diff1);
367}
368
369KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
370 // If we know which argument is larger, return (sub LHS, RHS) or
371 // (sub RHS, LHS) directly.
372 if (LHS.getSignedMinValue().sge(RHS: RHS.getSignedMaxValue()))
373 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
374 RHS);
375 if (RHS.getSignedMinValue().sge(RHS: LHS.getSignedMaxValue()))
376 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS,
377 RHS: LHS);
378
379 // Shift both arguments from the signed range to the unsigned range, e.g. from
380 // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
381 // abdu does.
382 // Note that we can't just use "sub nsw" instead because abds has signed
383 // inputs but an unsigned result, which makes the overflow conditions
384 // different.
385 unsigned SignBitPosition = LHS.getBitWidth() - 1;
386 for (auto Arg : {&LHS, &RHS}) {
387 bool Tmp = Arg->Zero[SignBitPosition];
388 Arg->Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Arg->One[SignBitPosition]);
389 Arg->One.setBitVal(BitPosition: SignBitPosition, BitValue: Tmp);
390 }
391
392 // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
393 KnownBits Diff0 =
394 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
395 KnownBits Diff1 =
396 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS);
397 return Diff0.intersectWith(RHS: Diff1);
398}
399
400static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
401 if (isPowerOf2_32(Value: BitWidth))
402 return MaxValue.extractBitsAsZExtValue(numBits: Log2_32(Value: BitWidth), bitPosition: 0);
403 // This is only an approximate upper bound.
404 return MaxValue.getLimitedValue(Limit: BitWidth - 1);
405}
406
/// Known bits of `LHS << RHS`.
/// \p NUW / \p NSW: the shift does not wrap (unsigned / signed). These are
/// used to bound the shift amount and to infer the result's sign bit.
/// \p ShAmtNonZero: the shift amount is known to be non-zero.
/// If every feasible shift amount yields poison, an all-zero result is
/// returned instead of a conflict.
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
                         bool NSW, bool ShAmtNonZero) {
  unsigned BitWidth = LHS.getBitWidth();
  // Known bits of LHS shifted left by one constant amount. Vacated low bits
  // are known zero. ShiftedOutZero / ShiftedOutOne record whether a
  // known-zero / known-one bit of LHS was shifted out.
  auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
    KnownBits Known;
    bool ShiftedOutZero, ShiftedOutOne;
    Known.Zero = LHS.Zero.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutZero);
    Known.Zero.setLowBits(ShiftAmt);
    Known.One = LHS.One.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutOne);

    // All cases returning poison have been handled by MaxShiftAmount already.
    if (NSW) {
      if (NUW && ShiftAmt != 0)
        // NUW means we can assume anything shifted out was a zero.
        ShiftedOutZero = true;

      // Under NSW the result's sign must match the bits shifted out.
      if (ShiftedOutZero)
        Known.makeNonNegative();
      else if (ShiftedOutOne)
        Known.makeNegative();
    }
    return Known;
  };

  // Fast path for a common case when LHS is completely unknown.
  KnownBits Known(BitWidth);
  unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
  if (MinShiftAmount == 0 && ShAmtNonZero)
    MinShiftAmount = 1;
  if (LHS.isUnknown()) {
    Known.Zero.setLowBits(MinShiftAmount);
    if (NUW && NSW && MinShiftAmount != 0)
      Known.makeNonNegative();
    return Known;
  }

  // Determine maximum shift amount, taking NUW/NSW flags into account.
  // NUW+NSW: only zeros may be shifted out and the result's sign bit must
  // stay clear, so the shift is at most (leading zeros - 1).
  // NUW: only zeros may be shifted out.
  // NSW: shifted-out bits and the result's sign bit must all equal the
  // original sign bit.
  // NOTE(review): when countMaxLeadingZeros() is 0 the "- 1" wraps to
  // UINT_MAX; std::min then leaves the bound unchanged and the NUW check
  // below clamps it.
  APInt MaxValue = RHS.getMaxValue();
  unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
  if (NUW && NSW)
    MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros() - 1);
  if (NUW)
    MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros());
  if (NSW)
    MaxShiftAmount = std::min(
        a: MaxShiftAmount,
        b: std::max(a: LHS.countMaxLeadingZeros(), b: LHS.countMaxLeadingOnes()) - 1);

  // Fast path for common case where the shift amount is unknown.
  if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
      isPowerOf2_32(Value: BitWidth)) {
    // Trailing zeros of LHS survive every shift.
    Known.Zero.setLowBits(LHS.countMinTrailingZeros());
    // An all-ones LHS keeps its top bit set for any shift < BitWidth.
    if (LHS.isAllOnes())
      Known.One.setSignBit();
    if (NSW) {
      if (LHS.isNonNegative())
        Known.makeNonNegative();
      if (LHS.isNegative())
        Known.makeNegative();
    }
    return Known;
  }

  // Find the common bits from all possible shifts.
  unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
  unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
  // Start from "conflict" so the first intersectWith yields that shift's bits.
  Known.setAllConflict();
  for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
       ++ShiftAmt) {
    // Skip if the shift amount is impossible.
    if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
        (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
      continue;
    Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
    // Nothing more can be learned once everything is unknown.
    if (Known.isUnknown())
      break;
  }

  // All shift amounts may result in poison.
  if (Known.hasConflict())
    Known.setAllZero();
  return Known;
}
490
491KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
492 bool ShAmtNonZero, bool Exact) {
493 unsigned BitWidth = LHS.getBitWidth();
494 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
495 KnownBits Known = LHS;
496 Known >>= ShiftAmt;
497 // High bits are known zero.
498 Known.Zero.setHighBits(ShiftAmt);
499 return Known;
500 };
501
502 // Fast path for a common case when LHS is completely unknown.
503 KnownBits Known(BitWidth);
504 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
505 if (MinShiftAmount == 0 && ShAmtNonZero)
506 MinShiftAmount = 1;
507 if (LHS.isUnknown()) {
508 Known.Zero.setHighBits(MinShiftAmount);
509 return Known;
510 }
511
512 // Find the common bits from all possible shifts.
513 APInt MaxValue = RHS.getMaxValue();
514 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
515
516 // If exact, bound MaxShiftAmount to first known 1 in LHS.
517 if (Exact) {
518 unsigned FirstOne = LHS.countMaxTrailingZeros();
519 if (FirstOne < MinShiftAmount) {
520 // Always poison. Return zero because we don't like returning conflict.
521 Known.setAllZero();
522 return Known;
523 }
524 MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne);
525 }
526
527 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
528 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
529 Known.setAllConflict();
530 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
531 ++ShiftAmt) {
532 // Skip if the shift amount is impossible.
533 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
534 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
535 continue;
536 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
537 if (Known.isUnknown())
538 break;
539 }
540
541 // All shift amounts may result in poison.
542 if (Known.hasConflict())
543 Known.setAllZero();
544 return Known;
545}
546
547KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
548 bool ShAmtNonZero, bool Exact) {
549 unsigned BitWidth = LHS.getBitWidth();
550 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
551 KnownBits Known = LHS;
552 Known.Zero.ashrInPlace(ShiftAmt);
553 Known.One.ashrInPlace(ShiftAmt);
554 return Known;
555 };
556
557 // Fast path for a common case when LHS is completely unknown.
558 KnownBits Known(BitWidth);
559 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
560 if (MinShiftAmount == 0 && ShAmtNonZero)
561 MinShiftAmount = 1;
562 if (LHS.isUnknown()) {
563 if (MinShiftAmount == BitWidth) {
564 // Always poison. Return zero because we don't like returning conflict.
565 Known.setAllZero();
566 return Known;
567 }
568 return Known;
569 }
570
571 // Find the common bits from all possible shifts.
572 APInt MaxValue = RHS.getMaxValue();
573 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
574
575 // If exact, bound MaxShiftAmount to first known 1 in LHS.
576 if (Exact) {
577 unsigned FirstOne = LHS.countMaxTrailingZeros();
578 if (FirstOne < MinShiftAmount) {
579 // Always poison. Return zero because we don't like returning conflict.
580 Known.setAllZero();
581 return Known;
582 }
583 MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne);
584 }
585
586 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
587 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
588 Known.setAllConflict();
589 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
590 ++ShiftAmt) {
591 // Skip if the shift amount is impossible.
592 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
593 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
594 continue;
595 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
596 if (Known.isUnknown())
597 break;
598 }
599
600 // All shift amounts may result in poison.
601 if (Known.hasConflict())
602 Known.setAllZero();
603 return Known;
604}
605
606KnownBits KnownBits::clmul(const KnownBits &LHS, const KnownBits &RHS) {
607 KnownBits Res =
608 makeConstant(C: APIntOps::clmul(LHS: LHS.getMinValue(), RHS: RHS.getMinValue()));
609
610 // This is the same operation as clmul except it accumulates the result with
611 // an OR instead of an XOR.
612 auto ClMulOr = [](const APInt &LHS, const APInt &RHS) {
613 unsigned BW = LHS.getBitWidth();
614 assert(BW == RHS.getBitWidth() && "Operand mismatch");
615 APInt Result(BW, 0);
616 for (unsigned I :
617 seq(Size: std::min(a: RHS.getActiveBits(), b: BW - LHS.countr_zero())))
618 if (RHS[I])
619 Result |= LHS << I;
620 return Result;
621 };
622
623 // Bits in the result are known if, for every corresponding pair of input
624 // bits, both input bits are known or either input bit is known to be zero.
625 APInt Known = ~(ClMulOr(~LHS.Zero & ~LHS.One, ~RHS.Zero) |
626 ClMulOr(~LHS.Zero, ~RHS.Zero & ~RHS.One));
627 Res.Zero &= Known;
628 Res.One &= Known;
629
630 return Res;
631}
632
633std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
634 if (LHS.isConstant() && RHS.isConstant())
635 return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
636 if (LHS.One.intersects(RHS: RHS.Zero) || RHS.One.intersects(RHS: LHS.Zero))
637 return std::optional<bool>(false);
638 return std::nullopt;
639}
640
641std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
642 if (std::optional<bool> KnownEQ = eq(LHS, RHS))
643 return std::optional<bool>(!*KnownEQ);
644 return std::nullopt;
645}
646
647std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
648 // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
649 if (LHS.getMaxValue().ule(RHS: RHS.getMinValue()))
650 return std::optional<bool>(false);
651 // LHS >u RHS -> true if umin(LHS) > umax(RHS)
652 if (LHS.getMinValue().ugt(RHS: RHS.getMaxValue()))
653 return std::optional<bool>(true);
654 return std::nullopt;
655}
656
657std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
658 if (std::optional<bool> IsUGT = ugt(LHS: RHS, RHS: LHS))
659 return std::optional<bool>(!*IsUGT);
660 return std::nullopt;
661}
662
663std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
664 return ugt(LHS: RHS, RHS: LHS);
665}
666
667std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
668 return uge(LHS: RHS, RHS: LHS);
669}
670
671std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
672 // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
673 if (LHS.getSignedMaxValue().sle(RHS: RHS.getSignedMinValue()))
674 return std::optional<bool>(false);
675 // LHS >s RHS -> true if smin(LHS) > smax(RHS)
676 if (LHS.getSignedMinValue().sgt(RHS: RHS.getSignedMaxValue()))
677 return std::optional<bool>(true);
678 return std::nullopt;
679}
680
681std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
682 if (std::optional<bool> KnownSGT = sgt(LHS: RHS, RHS: LHS))
683 return std::optional<bool>(!*KnownSGT);
684 return std::nullopt;
685}
686
687std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
688 return sgt(LHS: RHS, RHS: LHS);
689}
690
691std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
692 return sge(LHS: RHS, RHS: LHS);
693}
694
/// Known bits of the absolute value of this value.
/// \p IntMinIsPoison: an INT_MIN input is treated as poison. This lets the
/// result's sign bit be assumed clear and enables the extra inferences in the
/// negative-input path.
KnownBits KnownBits::abs(bool IntMinIsPoison) const {
  // If the source's MSB is zero then we know the rest of the bits already.
  if (isNonNegative())
    return *this;

  // Start from a fully unknown result. (The unknown-sign path below relies on
  // the absolute value preserving the trailing zero count.)
  KnownBits KnownAbs(getBitWidth());

  // If the input is negative, then abs(x) == -x.
  if (isNegative()) {
    KnownBits Tmp = *this;
    // Special case for IntMinIsPoison. We know the sign bit is set and we know
    // all the rest of the bits except one to be zero. Since we have
    // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
    // INT_MIN.
    if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
      Tmp.One.setBit(countMinTrailingZeros());

    // Compute 0 - Tmp. NSW is claimed only when INT_MIN is poison.
    KnownAbs = computeForAddSub(
        /*Add*/ false, NSW: IntMinIsPoison, /*NUW=*/false,
        LHS: KnownBits::makeConstant(C: APInt(getBitWidth(), 0)), RHS: Tmp);

    // One more special case for IntMinIsPoison. If we don't know any ones other
    // than the signbit, we know for certain that all the unknowns can't be
    // zero. So if we know high zero bits, but have unknown low bits, we know
    // for certain those high-zero bits will end up as one. This is because,
    // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
    // to the high bits. If we know a known INT_MIN input skip this. The result
    // is poison anyways.
    if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
        Tmp.countMaxPopulation() != 1) {
      // Clear the sign bit in Tmp so countMinLeadingZeros() measures the known
      // zero run just below it.
      Tmp.One.clearSignBit();
      Tmp.Zero.setSignBit();
      KnownAbs.One.setBits(loBit: getBitWidth() - Tmp.countMinLeadingZeros(),
                           hiBit: getBitWidth() - 1);
    }

  } else {
    // Sign unknown: the result is either the input or its negation.
    unsigned MaxTZ = countMaxTrailingZeros();
    unsigned MinTZ = countMinTrailingZeros();

    KnownAbs.Zero.setLowBits(MinTZ);
    // If we know the lowest set 1, then preserve it.
    if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
      KnownAbs.One.setBit(MaxTZ);

    // We only know that the absolute value's MSB will be zero if INT_MIN is
    // poison, or there is a set bit that isn't the sign bit (otherwise it could
    // be INT_MIN).
    if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
      KnownAbs.One.clearSignBit();
      KnownAbs.Zero.setSignBit();
    }
  }

  return KnownAbs;
}
752
753KnownBits KnownBits::reduceAdd(unsigned NumElts) const {
754 if (NumElts == 0)
755 return KnownBits(getBitWidth());
756
757 unsigned BitWidth = getBitWidth();
758 KnownBits Result(BitWidth);
759
760 if (isConstant())
761 // If all elements are the same constant, we can simply compute it
762 return KnownBits::makeConstant(C: NumElts * getConstant());
763
764 // The main idea is as follows.
765 //
766 // If KnownBits for each element has L leading zeros then
767 // X_i < 2^(W - L) for every i from [1, N].
768 //
769 // ADD X_i <= ADD max(X_i) = N * max(X_i)
770 // < N * 2^(W - L)
771 // < 2^(W - L + ceil(log2(N)))
772 //
773 // As the result, we can conclude that
774 //
775 // L' = L - ceil(log2(N))
776 //
777 // Similar logic can be applied to leading ones.
778 unsigned LostBits = Log2_32_Ceil(Value: NumElts);
779
780 if (isNonNegative()) {
781 unsigned LeadingZeros = countMinLeadingZeros();
782 LeadingZeros = LeadingZeros > LostBits ? LeadingZeros - LostBits : 0;
783 Result.Zero.setHighBits(LeadingZeros);
784 } else if (isNegative()) {
785 unsigned LeadingOnes = countMinLeadingOnes();
786 LeadingOnes = LeadingOnes > LostBits ? LeadingOnes - LostBits : 0;
787 Result.One.setHighBits(LeadingOnes);
788 }
789
790 return Result;
791}
792
/// Shared known-bits transfer function for the saturating add/sub operations
/// sadd.sat, ssub.sat, uadd.sat and usub.sat.
///
/// \param Add    true for addition, false for subtraction.
/// \param Signed true for the signed variants (clamp to INT_MIN/INT_MAX),
///               false for the unsigned variants (clamp to UINT_MAX / 0).
/// \param LHS    known bits of the first operand.
/// \param RHS    known bits of the second operand (same bit width as LHS).
/// \returns Known bits of the saturated result. The function first tries to
///          decide whether the operation overflows. If overflow is ruled out,
///          the plain add/sub known bits are returned. If overflow is certain,
///          the exact clamp constant is returned. Otherwise the add/sub known
///          bits are weakened so they also admit the possible clamp value(s).
static KnownBits computeForSatAddSub(bool Add, bool Signed,
                                     const KnownBits &LHS,
                                     const KnownBits &RHS) {
  // NSW is deliberately not set in the overflow probe below (even for
  // sadd/ssub), because the probe must be able to observe signed overflow.
  unsigned BitWidth = LHS.getBitWidth();

  // Tri-state overflow verdict: true = always overflows, false = never
  // overflows, empty = unknown.
  std::optional<bool> Overflow;
  // Even if we can't entirely rule out overflow, we may be able to rule out
  // overflow in one direction. This allows us to potentially keep some of the
  // add/sub bits. E.g. if we can't overflow in the positive direction we won't
  // clamp to INT_MAX, so we can keep low 0s from the add/sub result.
  bool MayNegClamp = true;
  bool MayPosClamp = true;
  if (Signed) {
    // Easy cases where any overflow is ruled out: adding operands of opposite
    // sign, or subtracting operands of the same sign.
    if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
                (LHS.isNonNegative() && RHS.isNegative())))
      Overflow = false;
    else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
                       (LHS.isNonNegative() && RHS.isNonNegative()))))
      Overflow = false;
    else {
      // Check if we may overflow. If we can't rule out overflow then check if
      // we can rule out a direction at least.
      KnownBits UnsignedLHS = LHS;
      KnownBits UnsignedRHS = RHS;
      // Get versions of LHS/RHS with the sign bit known clear. The sign bit of
      // their sum/difference is then exactly the carry (add) or borrow (sub)
      // out of the low BitWidth-1 bits into the sign position. Combined with
      // the actual known sign bits of LHS/RHS, this tells us which overflows
      // are/aren't possible.
      UnsignedLHS.One.clearSignBit();
      UnsignedLHS.Zero.setSignBit();
      UnsignedRHS.One.clearSignBit();
      UnsignedRHS.Zero.setSignBit();
      KnownBits Res =
          KnownBits::computeForAddSub(Add, /*NSW=*/false,
                                      /*NUW=*/false, LHS: UnsignedLHS, RHS: UnsignedRHS);
      if (Add) {
        if (Res.isNegative()) {
          // A carry into the sign bit is known: the only overflow scenario is
          // Pos + Pos.
          MayNegClamp = false;
          // Pos + Pos will overflow with the extra carried-in sign bit.
          if (LHS.isNonNegative() && RHS.isNonNegative())
            Overflow = true;
        } else if (Res.isNonNegative()) {
          // No carry into the sign bit: the only overflow scenario is
          // Neg + Neg.
          MayPosClamp = false;
          // Neg + Neg will overflow without the extra carried-in sign bit.
          if (LHS.isNegative() && RHS.isNegative())
            Overflow = true;
        }
        // Positive overflow needs Pos + Pos and negative overflow needs
        // Neg + Neg, so a known-negative operand rules out clamping to INT_MAX
        // and a known-non-negative operand rules out clamping to INT_MIN.
        if (LHS.isNegative() || RHS.isNegative())
          MayPosClamp = false;
        if (LHS.isNonNegative() || RHS.isNonNegative())
          MayNegClamp = false;
      } else {
        if (Res.isNegative()) {
          // A borrow out of the sign bit is known: the only overflow scenario
          // is Neg - Pos.
          MayPosClamp = false;
          // Neg - Pos will overflow with the extra borrowed sign bit.
          if (LHS.isNegative() && RHS.isNonNegative())
            Overflow = true;
        } else if (Res.isNonNegative()) {
          // No borrow: the only overflow scenario is Pos - Neg.
          MayNegClamp = false;
          // Pos - Neg will overflow without the extra borrowed sign bit.
          if (LHS.isNonNegative() && RHS.isNegative())
            Overflow = true;
        }
        // Positive overflow needs Pos - Neg and negative overflow needs
        // Neg - Pos, so these known signs rule out the corresponding clamp.
        if (LHS.isNegative() || RHS.isNonNegative())
          MayPosClamp = false;
        if (LHS.isNonNegative() || RHS.isNegative())
          MayNegClamp = false;
      }
    }
    // If we have ruled out all clamping, we will never overflow.
    if (!MayNegClamp && !MayPosClamp)
      Overflow = false;
  } else if (Add) {
    // uadd.sat: overflow is impossible if even max + max fits, and certain if
    // even min + min does not fit.
    bool Of;
    (void)LHS.getMaxValue().uadd_ov(RHS: RHS.getMaxValue(), Overflow&: Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMinValue().uadd_ov(RHS: RHS.getMinValue(), Overflow&: Of);
      if (Of)
        Overflow = true;
    }
  } else {
    // usub.sat: underflow is impossible if even min - max does not borrow, and
    // certain if even max - min borrows.
    bool Of;
    (void)LHS.getMinValue().usub_ov(RHS: RHS.getMaxValue(), Overflow&: Of);
    if (!Of) {
      Overflow = false;
    } else {
      (void)LHS.getMaxValue().usub_ov(RHS: RHS.getMinValue(), Overflow&: Of);
      if (Of)
        Overflow = true;
    }
  }

  // Known bits of the unsaturated result assuming no wrap (NSW for signed,
  // NUW for unsigned). This describes the result whenever no clamping happens,
  // which is why it is returned unchanged when overflow is ruled out.
  KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
                                              /*NUW=*/!Signed, LHS, RHS);

  if (Overflow) {
    // We know whether or not we overflowed.
    if (!(*Overflow)) {
      // No overflow.
      return Res;
    }

    // We always overflow: the result is exactly the clamp constant.
    APInt C;
    if (Signed) {
      // sadd.sat / ssub.sat: clamp toward the sign of LHS.
      assert(!LHS.isSignUnknown() &&
             "We somehow know overflow without knowing input sign");
      C = LHS.isNegative() ? APInt::getSignedMinValue(numBits: BitWidth)
                           : APInt::getSignedMaxValue(numBits: BitWidth);
    } else if (Add) {
      // uadd.sat: clamp to UINT_MAX.
      C = APInt::getMaxValue(numBits: BitWidth);
    } else {
      // usub.sat: clamp to 0.
      C = APInt::getMinValue(numBits: BitWidth);
    }

    Res.One = C;
    Res.Zero = ~C;
    return Res;
  }

  // We don't know if we overflowed.
  if (Signed) {
    // sadd.sat/ssub.sat
    // We can keep our information about the sign bit. A possible INT_MAX clamp
    // (0b0111...1) has all low bits set, so drop known zeros there. A possible
    // INT_MIN clamp (0b1000...0) has all low bits clear, so drop known ones.
    if (MayPosClamp)
      Res.Zero.clearLowBits(loBits: BitWidth - 1);
    if (MayNegClamp)
      Res.One.clearLowBits(loBits: BitWidth - 1);
  } else if (Add) {
    // uadd.sat
    // The result may be UINT_MAX (all ones), so clear all the known zeros. We
    // can only use the leading ones.
    Res.Zero.clearAllBits();
  } else {
    // usub.sat
    // The result may be 0, so clear all the known ones. We can only use the
    // leading zeros.
    Res.One.clearAllBits();
  }

  return Res;
}
949
950KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
951 return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
952}
953KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
954 return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
955}
956KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
957 return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
958}
959KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
960 return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
961}
962
963static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
964 unsigned BitWidth = LHS.getBitWidth();
965 LHS = LHS.zext(BitWidth: BitWidth + 1);
966 RHS = RHS.zext(BitWidth: BitWidth + 1);
967 LHS =
968 computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
969 LHS = LHS.extractBits(NumBits: BitWidth, BitPosition: 1);
970 return LHS;
971}
972
973KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
974 return flipSignBit(Val: avgFloorU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
975}
976
977KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
978 return avgComputeU(LHS, RHS, /*IsCeil=*/false);
979}
980
981KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
982 return flipSignBit(Val: avgCeilU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
983}
984
985KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
986 return avgComputeU(LHS, RHS, /*IsCeil=*/true);
987}
988
/// Known bits of the wrapping product LHS * RHS (BitWidth bits).
///
/// \param NoUndefSelfMultiply true when the product is known to be X * X for a
///        single value X. As the name suggests, this presumably also means X is
///        not undef. It enables extra facts about squares. LHS == RHS is
///        asserted in that case.
/// \returns Known leading zeros (from the unsigned max product), known low
///          bits (from the operands' known trailing bits), plus the square-only
///          facts when NoUndefSelfMultiply is set.
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                         bool NoUndefSelfMultiply) {
  unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
  assert((!NoUndefSelfMultiply || LHS == RHS) &&
         "Self multiplication knownbits mismatch");

  // Compute the high known-0 bits by multiplying the unsigned max of each side.
  // Conservatively, M active bits * N active bits results in M + N bits in the
  // result. But if we know a value is a power-of-2 for example, then this
  // computes one more leading zero.
  // TODO: This could be generalized to number of sign bits (negative numbers).
  APInt UMaxLHS = LHS.getMaxValue();
  APInt UMaxRHS = RHS.getMaxValue();

  // For leading zeros in the result to be valid, the unsigned max product must
  // fit in the bitwidth (it must not overflow).
  bool HasOverflow;
  APInt UMaxResult = UMaxLHS.umul_ov(RHS: UMaxRHS, Overflow&: HasOverflow);
  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();

  // The result of the bottom bits of an integer multiply can be
  // inferred by looking at the bottom bits of both operands and
  // multiplying them together.
  // We can infer at least the minimum number of known trailing bits
  // of both operands. Depending on number of trailing zeros, we can
  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
  // a and b are divisible by m and n respectively.
  // We then calculate how many of those bits are inferrable and set
  // the output. For example, the i8 mul:
  //  a = XXXX1100 (12)
  //  b = XXXX1110 (14)
  // We know the bottom 3 bits are zero since the first can be divided by
  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
  // Applying the multiplication to the trimmed arguments gets:
  //    XX11 (3)
  //    X111 (7)
  // -------
  //    XX11
  //   XX11
  //  XX11
  // XX11
  // -------
  // XXXXX01
  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
  // The proof for this can be described as:
  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
  //                    umin(countTrailingZeros(C2), C6) +
  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
  // %aa = shl i8 %a, C5
  // %bb = shl i8 %b, C6
  // %aaa = or i8 %aa, C1
  // %bbb = or i8 %bb, C2
  // %mul = mul i8 %aaa, %bbb
  // %mask = and i8 %mul, C7
  //   =>
  // %mask = i8 ((C1*C2)&C7)
  // Where C5, C6 describe the known bits of %a, %b
  // C1, C2 describe the known bottom bits of %a, %b.
  // C7 describes the mask of the known bits of the result.

  // TrailBitsKnown* = length of the fully-known low run of each operand.
  // TrailZero* = number of low bits known to be zero, i.e. how many times we
  // could divide each operand by 2 (shr by 1). Their sum is the number of
  // trailing zeros guaranteed in the product.
  unsigned TrailBitsKnownLHS = (LHS.Zero | LHS.One).countr_one();
  unsigned TrailBitsKnownRHS = (RHS.Zero | RHS.One).countr_one();
  unsigned TrailZeroLHS = LHS.countMinTrailingZeros();
  unsigned TrailZeroRHS = RHS.countMinTrailingZeros();
  unsigned TrailZ = TrailZeroLHS + TrailZeroRHS;

  // Figure out the operand with the fewest known bits above its trailing
  // zeros. That bounds how many product bits above TrailZ are known.
  unsigned SmallestOperand = std::min(a: TrailBitsKnownLHS - TrailZeroLHS,
                                      b: TrailBitsKnownRHS - TrailZeroRHS);
  unsigned ResultBitsKnown = std::min(a: SmallestOperand + TrailZ, b: BitWidth);

  // Only the lower ResultBitsKnown bits of this product are used below.
  // Known-one bits of an operand that lie above its known low run only
  // contribute to product bits at or above ResultBitsKnown. Multiplying the
  // full One masks therefore gives the same low bits as multiplying just the
  // known low runs.
  APInt BottomKnown = LHS.One * RHS.One;

  KnownBits Res(BitWidth);
  Res.Zero.setHighBits(LeadZ);
  Res.Zero |= (~BottomKnown).getLoBits(numBits: ResultBitsKnown);
  Res.One = BottomKnown.getLoBits(numBits: ResultBitsKnown);

  if (NoUndefSelfMultiply) {
    // If X has at least TZ trailing zeroes, then bit (2 * TZ + 1) must be zero.
    unsigned TwoTZP1 = 2 * TrailZeroLHS + 1;
    if (TwoTZP1 < BitWidth)
      Res.Zero.setBit(TwoTZP1);

    // If X has exactly TZ trailing zeros (bit TZ known one), then bit
    // (2 * TZ + 2) must also be zero.
    if (TrailZeroLHS < BitWidth && LHS.One[TrailZeroLHS]) {
      unsigned TwoTZP2 = TwoTZP1 + 1;
      if (TwoTZP2 < BitWidth)
        Res.Zero.setBit(TwoTZP2);
    }
  }

  return Res;
}
1091
1092KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
1093 unsigned BitWidth = LHS.getBitWidth();
1094 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
1095 KnownBits WideLHS = LHS.sext(BitWidth: 2 * BitWidth);
1096 KnownBits WideRHS = RHS.sext(BitWidth: 2 * BitWidth);
1097 return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth);
1098}
1099
1100KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
1101 unsigned BitWidth = LHS.getBitWidth();
1102 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
1103 KnownBits WideLHS = LHS.zext(BitWidth: 2 * BitWidth);
1104 KnownBits WideRHS = RHS.zext(BitWidth: 2 * BitWidth);
1105 return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth);
1106}
1107
1108static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
1109 const KnownBits &RHS, bool Exact) {
1110
1111 if (!Exact)
1112 return Known;
1113
1114 // If LHS is Odd, the result is Odd no matter what.
1115 // Odd / Odd -> Odd
1116 // Odd / Even -> Impossible (because its exact division)
1117 if (LHS.One[0])
1118 Known.One.setBit(0);
1119
1120 int MinTZ =
1121 (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
1122 int MaxTZ =
1123 (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
1124 if (MinTZ >= 0) {
1125 // Result has at least MinTZ trailing zeros.
1126 Known.Zero.setLowBits(MinTZ);
1127 if (MinTZ == MaxTZ) {
1128 // Result has exactly MinTZ trailing zeros.
1129 Known.One.setBit(MinTZ);
1130 }
1131 } else if (MaxTZ < 0) {
1132 // Poison Result
1133 Known.setAllZero();
1134 }
1135
1136 // In the KnownBits exhaustive tests, we have poison inputs for exact values
1137 // a LOT. If we have a conflict, just return all zeros.
1138 if (Known.hasConflict())
1139 Known.setAllZero();
1140
1141 return Known;
1142}
1143
/// Known bits of the signed quotient LHS sdiv RHS.
///
/// The high bits come from the quotient of largest magnitude that the operand
/// ranges allow. When the result sign is known, that bound gives leading zeros
/// (non-negative result) or leading ones (negative result). The low bits are
/// refined by divComputeLowBit when \p Exact is set.
KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  // With both operands non-negative, sdiv is the same as udiv, so reuse that
  // analysis. (Presumably such an sdiv would normally already have been
  // folded to udiv.)
  if (LHS.isNonNegative() && RHS.isNonNegative())
    return udiv(LHS, RHS, Exact);

  unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  if (LHS.isZero() || RHS.isZero()) {
    // Result is either known Zero or UB. Return Zero either way.
    // Checking this earlier saves us a lot of special cases later on.
    Known.setAllZero();
    return Known;
  }

  // Quotient of largest magnitude for the known sign case, if one applies.
  std::optional<APInt> Res;
  if (LHS.isNegative() && RHS.isNegative()) {
    // Result non-negative. The most negative dividend over the divisor closest
    // to zero gives the largest quotient.
    APInt Denom = RHS.getSignedMaxValue();
    APInt Num = LHS.getSignedMinValue();
    // INT_MIN/-1 would be a poison result (impossible). Estimate the division
    // as signed max (we will only set sign bit in the result).
    Res = (Num.isMinSignedValue() && Denom.isAllOnes())
              ? APInt::getSignedMaxValue(numBits: BitWidth)
              : Num.sdiv(RHS: Denom);
  } else if (LHS.isNegative() && RHS.isNonNegative()) {
    // Result is negative if Exact OR -LHS u>= RHS, i.e. the smallest dividend
    // magnitude is at least the largest divisor. Otherwise the quotient may
    // be 0.
    if (Exact || (-LHS.getSignedMaxValue()).uge(RHS: RHS.getSignedMaxValue())) {
      // Most negative dividend over the smallest divisor. A possibly-zero
      // divisor is treated as 1, since dividing by zero is UB.
      APInt Denom = RHS.getSignedMinValue();
      APInt Num = LHS.getSignedMinValue();
      Res = Denom.isZero() ? Num : Num.sdiv(RHS: Denom);
    }
  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
    // Result is negative if Exact OR LHS u>= -RHS, i.e. the smallest dividend
    // is at least the largest divisor magnitude. Otherwise the quotient may
    // be 0.
    if (Exact || LHS.getSignedMinValue().uge(RHS: -RHS.getSignedMinValue())) {
      // Largest dividend over the negative divisor closest to zero.
      APInt Denom = RHS.getSignedMaxValue();
      APInt Num = LHS.getSignedMaxValue();
      Res = Num.sdiv(RHS: Denom);
    }
  }

  if (Res) {
    // Every possible quotient has at least as many leading sign bits as the
    // extreme one computed above.
    if (Res->isNonNegative()) {
      unsigned LeadZ = Res->countLeadingZeros();
      Known.Zero.setHighBits(LeadZ);
    } else {
      unsigned LeadO = Res->countLeadingOnes();
      Known.One.setHighBits(LeadO);
    }
  }

  Known = divComputeLowBit(Known, LHS, RHS, Exact);
  return Known;
}
1199
1200KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
1201 bool Exact) {
1202 unsigned BitWidth = LHS.getBitWidth();
1203 KnownBits Known(BitWidth);
1204
1205 if (LHS.isZero() || RHS.isZero()) {
1206 // Result is either known Zero or UB. Return Zero either way.
1207 // Checking this earlier saves us a lot of special cases later on.
1208 Known.setAllZero();
1209 return Known;
1210 }
1211
1212 // We can figure out the minimum number of upper zero bits by doing
1213 // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
1214 // gets larger, the number of upper zero bits increases.
1215 APInt MinDenom = RHS.getMinValue();
1216 APInt MaxNum = LHS.getMaxValue();
1217 APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(RHS: MinDenom);
1218
1219 unsigned LeadZ = MaxRes.countLeadingZeros();
1220
1221 Known.Zero.setHighBits(LeadZ);
1222 Known = divComputeLowBit(Known, LHS, RHS, Exact);
1223
1224 return Known;
1225}
1226
1227KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
1228 unsigned BitWidth = LHS.getBitWidth();
1229 if (!RHS.isZero() && RHS.Zero[0]) {
1230 // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result.
1231 unsigned RHSZeros = RHS.countMinTrailingZeros();
1232 APInt Mask = APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: RHSZeros);
1233 APInt OnesMask = LHS.One & Mask;
1234 APInt ZerosMask = LHS.Zero & Mask;
1235 return KnownBits(ZerosMask, OnesMask);
1236 }
1237 return KnownBits(BitWidth);
1238}
1239
1240KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
1241 KnownBits Known = remGetLowBits(LHS, RHS);
1242 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1243 // NB: Low bits set in `remGetLowBits`.
1244 APInt HighBits = ~(RHS.getConstant() - 1);
1245 Known.Zero |= std::move(HighBits);
1246 return Known;
1247 }
1248
1249 // Since the result is less than or equal to either operand, any leading
1250 // zero bits in either operand must also exist in the result.
1251 uint32_t Leaders =
1252 std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinLeadingZeros());
1253 Known.Zero.setHighBits(Leaders);
1254 return Known;
1255}
1256
1257KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
1258 KnownBits Known = remGetLowBits(LHS, RHS);
1259 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1260 // NB: Low bits are set in `remGetLowBits`.
1261 APInt LowBits = RHS.getConstant() - 1;
1262 // If the first operand is non-negative or has all low bits zero, then
1263 // the upper bits are all zero.
1264 if (LHS.isNonNegative() || LowBits.isSubsetOf(RHS: LHS.Zero))
1265 Known.Zero |= ~LowBits;
1266
1267 // If the first operand is negative and not all low bits are zero, then
1268 // the upper bits are all one.
1269 if (LHS.isNegative() && LowBits.intersects(RHS: LHS.One))
1270 Known.One |= ~LowBits;
1271 return Known;
1272 }
1273
1274 // The sign bit is the LHS's sign bit, except when the result of the
1275 // remainder is zero. The magnitude of the result should be less than or
1276 // equal to the magnitude of either operand.
1277 if (LHS.isNegative() && Known.isNonZero())
1278 Known.One.setHighBits(
1279 std::max(a: LHS.countMinLeadingOnes(), b: RHS.countMinSignBits()));
1280 else if (LHS.isNonNegative())
1281 Known.Zero.setHighBits(
1282 std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinSignBits()));
1283 return Known;
1284}
1285
1286KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
1287 // Result bit is 0 if either operand bit is 0.
1288 Zero |= RHS.Zero;
1289 // Result bit is 1 if both operand bits are 1.
1290 One &= RHS.One;
1291 return *this;
1292}
1293
1294KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
1295 // Result bit is 0 if both operand bits are 0.
1296 Zero &= RHS.Zero;
1297 // Result bit is 1 if either operand bit is 1.
1298 One |= RHS.One;
1299 return *this;
1300}
1301
1302KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
1303 // Result bit is 0 if both operand bits are 0 or both are 1.
1304 APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
1305 // Result bit is 1 if one operand bit is 0 and the other is 1.
1306 One = (Zero & RHS.One) | (One & RHS.Zero);
1307 Zero = std::move(Z);
1308 return *this;
1309}
1310
1311KnownBits KnownBits::blsi() const {
1312 unsigned BitWidth = getBitWidth();
1313 KnownBits Known(Zero, APInt(BitWidth, 0));
1314 unsigned Max = countMaxTrailingZeros();
1315 Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth));
1316 unsigned Min = countMinTrailingZeros();
1317 if (Max == Min && Max < BitWidth)
1318 Known.One.setBit(Max);
1319 return Known;
1320}
1321
1322KnownBits KnownBits::blsmsk() const {
1323 unsigned BitWidth = getBitWidth();
1324 KnownBits Known(BitWidth);
1325 unsigned Max = countMaxTrailingZeros();
1326 Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth));
1327 unsigned Min = countMinTrailingZeros();
1328 Known.One.setLowBits(std::min(a: Min + 1, b: BitWidth));
1329 return Known;
1330}
1331
1332void KnownBits::print(raw_ostream &OS) const {
1333 unsigned BitWidth = getBitWidth();
1334 for (unsigned I = 0; I < BitWidth; ++I) {
1335 unsigned N = BitWidth - I - 1;
1336 if (Zero[N] && One[N])
1337 OS << "!";
1338 else if (Zero[N])
1339 OS << "0";
1340 else if (One[N])
1341 OS << "1";
1342 else
1343 OS << "?";
1344 }
1345}
1346
1347#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1348LLVM_DUMP_METHOD void KnownBits::dump() const {
1349 print(dbgs());
1350 dbgs() << "\n";
1351}
1352#endif
1353