1//===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains a class for representing known zeros and ones used by
10// computeKnownBits.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Support/KnownBits.h"
15#include "llvm/Support/Debug.h"
16#include "llvm/Support/raw_ostream.h"
17#include <cassert>
18
19using namespace llvm;
20
21KnownBits KnownBits::flipSignBit(const KnownBits &Val) {
22 unsigned SignBitPosition = Val.getBitWidth() - 1;
23 APInt Zero = Val.Zero;
24 APInt One = Val.One;
25 Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]);
26 One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]);
27 return KnownBits(Zero, One);
28}
29
30static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS,
31 bool CarryZero, bool CarryOne) {
32
33 APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero;
34 APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne;
35
36 // Compute known bits of the carry.
37 APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero);
38 APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One;
39
40 // Compute set of known bits (where all three relevant bits are known).
41 APInt LHSKnownUnion = LHS.Zero | LHS.One;
42 APInt RHSKnownUnion = RHS.Zero | RHS.One;
43 APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne;
44 APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion;
45
46 // Compute known bits of the result.
47 KnownBits KnownOut;
48 KnownOut.Zero = ~std::move(PossibleSumZero) & Known;
49 KnownOut.One = std::move(PossibleSumOne) & Known;
50 return KnownOut;
51}
52
53KnownBits KnownBits::computeForAddCarry(
54 const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) {
55 assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit");
56 return ::computeForAddCarry(
57 LHS, RHS, CarryZero: Carry.Zero.getBoolValue(), CarryOne: Carry.One.getBoolValue());
58}
59
60KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW,
61 const KnownBits &LHS,
62 const KnownBits &RHS) {
63 unsigned BitWidth = LHS.getBitWidth();
64 KnownBits KnownOut(BitWidth);
65 // This can be a relatively expensive helper, so optimistically save some
66 // work.
67 if (LHS.isUnknown() && RHS.isUnknown())
68 return KnownOut;
69
70 if (!LHS.isUnknown() && !RHS.isUnknown()) {
71 if (Add) {
72 // Sum = LHS + RHS + 0
73 KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true,
74 /*CarryOne=*/false);
75 } else {
76 // Sum = LHS + ~RHS + 1
77 KnownBits NotRHS = RHS;
78 std::swap(a&: NotRHS.Zero, b&: NotRHS.One);
79 KnownOut = ::computeForAddCarry(LHS, RHS: NotRHS, /*CarryZero=*/false,
80 /*CarryOne=*/true);
81 }
82 }
83
84 // Handle add/sub given nsw and/or nuw.
85 if (NUW) {
86 if (Add) {
87 // (add nuw X, Y)
88 APInt MinVal = LHS.getMinValue().uadd_sat(RHS: RHS.getMinValue());
89 // None of the adds can end up overflowing, so min consecutive highbits
90 // in minimum possible of X + Y must all remain set.
91 if (NSW) {
92 unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one();
93 // If we have NSW as well, we also know we can't overflow the signbit so
94 // can start counting from 1 bit back.
95 KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
96 }
97 KnownOut.One.setHighBits(MinVal.countl_one());
98 } else {
99 // (sub nuw X, Y)
100 APInt MaxVal = LHS.getMaxValue().usub_sat(RHS: RHS.getMinValue());
101 // None of the subs can overflow at any point, so any common high bits
102 // will subtract away and result in zeros.
103 if (NSW) {
104 // If we have NSW as well, we also know we can't overflow the signbit so
105 // can start counting from 1 bit back.
106 unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero();
107 KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
108 }
109 KnownOut.Zero.setHighBits(MaxVal.countl_zero());
110 }
111 }
112
113 if (NSW) {
114 APInt MinVal;
115 APInt MaxVal;
116 if (Add) {
117 // (add nsw X, Y)
118 MinVal = LHS.getSignedMinValue().sadd_sat(RHS: RHS.getSignedMinValue());
119 MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS: RHS.getSignedMaxValue());
120 } else {
121 // (sub nsw X, Y)
122 MinVal = LHS.getSignedMinValue().ssub_sat(RHS: RHS.getSignedMaxValue());
123 MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS: RHS.getSignedMinValue());
124 }
125 if (MinVal.isNonNegative()) {
126 // If min is non-negative, result will always be non-neg (can't overflow
127 // around).
128 unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one();
129 KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
130 KnownOut.Zero.setSignBit();
131 }
132 if (MaxVal.isNegative()) {
133 // If max is negative, result will always be neg (can't overflow around).
134 unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero();
135 KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1);
136 KnownOut.One.setSignBit();
137 }
138 }
139
140 // Just return 0 if the nsw/nuw is violated and we have poison.
141 if (KnownOut.hasConflict())
142 KnownOut.setAllZero();
143 return KnownOut;
144}
145
146KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS,
147 const KnownBits &Borrow) {
148 assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit");
149
150 // LHS - RHS = LHS + ~RHS + 1
151 // Carry 1 - Borrow in ::computeForAddCarry
152 std::swap(a&: RHS.Zero, b&: RHS.One);
153 return ::computeForAddCarry(LHS, RHS,
154 /*CarryZero=*/Borrow.One.getBoolValue(),
155 /*CarryOne=*/Borrow.Zero.getBoolValue());
156}
157
158KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const {
159 unsigned BitWidth = getBitWidth();
160 assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth &&
161 "Illegal sext-in-register");
162
163 if (SrcBitWidth == BitWidth)
164 return *this;
165
166 unsigned ExtBits = BitWidth - SrcBitWidth;
167 KnownBits Result;
168 Result.One = One << ExtBits;
169 Result.Zero = Zero << ExtBits;
170 Result.One.ashrInPlace(ShiftAmt: ExtBits);
171 Result.Zero.ashrInPlace(ShiftAmt: ExtBits);
172 return Result;
173}
174
175KnownBits KnownBits::makeGE(const APInt &Val) const {
176 // Count the number of leading bit positions where our underlying value is
177 // known to be less than or equal to Val.
178 unsigned N = (Zero | Val).countl_one();
179
180 // For each of those bit positions, if Val has a 1 in that bit then our
181 // underlying value must also have a 1.
182 APInt MaskedVal(Val);
183 MaskedVal.clearLowBits(loBits: getBitWidth() - N);
184 return KnownBits(Zero, One | MaskedVal);
185}
186
187KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) {
188 // If we can prove that LHS >= RHS then use LHS as the result. Likewise for
189 // RHS. Ideally our caller would already have spotted these cases and
190 // optimized away the umax operation, but we handle them here for
191 // completeness.
192 if (LHS.getMinValue().uge(RHS: RHS.getMaxValue()))
193 return LHS;
194 if (RHS.getMinValue().uge(RHS: LHS.getMaxValue()))
195 return RHS;
196
197 // If the result of the umax is LHS then it must be greater than or equal to
198 // the minimum possible value of RHS. Likewise for RHS. Any known bits that
199 // are common to these two values are also known in the result.
200 KnownBits L = LHS.makeGE(Val: RHS.getMinValue());
201 KnownBits R = RHS.makeGE(Val: LHS.getMinValue());
202 return L.intersectWith(RHS: R);
203}
204
205KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) {
206 // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0]
207 auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); };
208 return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS)));
209}
210
211KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) {
212 return flipSignBit(Val: umax(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
213}
214
215KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
216 // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0]
217 auto Flip = [](const KnownBits &Val) {
218 unsigned SignBitPosition = Val.getBitWidth() - 1;
219 APInt Zero = Val.One;
220 APInt One = Val.Zero;
221 Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]);
222 One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]);
223 return KnownBits(Zero, One);
224 };
225 return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS)));
226}
227
228KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) {
229 // If we know which argument is larger, return (sub LHS, RHS) or
230 // (sub RHS, LHS) directly.
231 if (LHS.getMinValue().uge(RHS: RHS.getMaxValue()))
232 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
233 RHS);
234 if (RHS.getMinValue().uge(RHS: LHS.getMaxValue()))
235 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS,
236 RHS: LHS);
237
238 // By construction, the subtraction in abdu never has unsigned overflow.
239 // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
240 KnownBits Diff0 =
241 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
242 KnownBits Diff1 =
243 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS);
244 return Diff0.intersectWith(RHS: Diff1);
245}
246
247KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) {
248 // If we know which argument is larger, return (sub LHS, RHS) or
249 // (sub RHS, LHS) directly.
250 if (LHS.getSignedMinValue().sge(RHS: RHS.getSignedMaxValue()))
251 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS,
252 RHS);
253 if (RHS.getSignedMinValue().sge(RHS: LHS.getSignedMaxValue()))
254 return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS,
255 RHS: LHS);
256
257 // Shift both arguments from the signed range to the unsigned range, e.g. from
258 // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like
259 // abdu does.
260 // Note that we can't just use "sub nsw" instead because abds has signed
261 // inputs but an unsigned result, which makes the overflow conditions
262 // different.
263 unsigned SignBitPosition = LHS.getBitWidth() - 1;
264 for (auto Arg : {&LHS, &RHS}) {
265 bool Tmp = Arg->Zero[SignBitPosition];
266 Arg->Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Arg->One[SignBitPosition]);
267 Arg->One.setBitVal(BitPosition: SignBitPosition, BitValue: Tmp);
268 }
269
270 // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS).
271 KnownBits Diff0 =
272 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS);
273 KnownBits Diff1 =
274 computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS);
275 return Diff0.intersectWith(RHS: Diff1);
276}
277
278static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) {
279 if (isPowerOf2_32(Value: BitWidth))
280 return MaxValue.extractBitsAsZExtValue(numBits: Log2_32(Value: BitWidth), bitPosition: 0);
281 // This is only an approximate upper bound.
282 return MaxValue.getLimitedValue(Limit: BitWidth - 1);
283}
284
285KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
286 bool NSW, bool ShAmtNonZero) {
287 unsigned BitWidth = LHS.getBitWidth();
288 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
289 KnownBits Known;
290 bool ShiftedOutZero, ShiftedOutOne;
291 Known.Zero = LHS.Zero.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutZero);
292 Known.Zero.setLowBits(ShiftAmt);
293 Known.One = LHS.One.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutOne);
294
295 // All cases returning poison have been handled by MaxShiftAmount already.
296 if (NSW) {
297 if (NUW && ShiftAmt != 0)
298 // NUW means we can assume anything shifted out was a zero.
299 ShiftedOutZero = true;
300
301 if (ShiftedOutZero)
302 Known.makeNonNegative();
303 else if (ShiftedOutOne)
304 Known.makeNegative();
305 }
306 return Known;
307 };
308
309 // Fast path for a common case when LHS is completely unknown.
310 KnownBits Known(BitWidth);
311 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
312 if (MinShiftAmount == 0 && ShAmtNonZero)
313 MinShiftAmount = 1;
314 if (LHS.isUnknown()) {
315 Known.Zero.setLowBits(MinShiftAmount);
316 if (NUW && NSW && MinShiftAmount != 0)
317 Known.makeNonNegative();
318 return Known;
319 }
320
321 // Determine maximum shift amount, taking NUW/NSW flags into account.
322 APInt MaxValue = RHS.getMaxValue();
323 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
324 if (NUW && NSW)
325 MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros() - 1);
326 if (NUW)
327 MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros());
328 if (NSW)
329 MaxShiftAmount = std::min(
330 a: MaxShiftAmount,
331 b: std::max(a: LHS.countMaxLeadingZeros(), b: LHS.countMaxLeadingOnes()) - 1);
332
333 // Fast path for common case where the shift amount is unknown.
334 if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 &&
335 isPowerOf2_32(Value: BitWidth)) {
336 Known.Zero.setLowBits(LHS.countMinTrailingZeros());
337 if (LHS.isAllOnes())
338 Known.One.setSignBit();
339 if (NSW) {
340 if (LHS.isNonNegative())
341 Known.makeNonNegative();
342 if (LHS.isNegative())
343 Known.makeNegative();
344 }
345 return Known;
346 }
347
348 // Find the common bits from all possible shifts.
349 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
350 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
351 Known.Zero.setAllBits();
352 Known.One.setAllBits();
353 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
354 ++ShiftAmt) {
355 // Skip if the shift amount is impossible.
356 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
357 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
358 continue;
359 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
360 if (Known.isUnknown())
361 break;
362 }
363
364 // All shift amounts may result in poison.
365 if (Known.hasConflict())
366 Known.setAllZero();
367 return Known;
368}
369
370KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS,
371 bool ShAmtNonZero, bool Exact) {
372 unsigned BitWidth = LHS.getBitWidth();
373 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
374 KnownBits Known = LHS;
375 Known.Zero.lshrInPlace(ShiftAmt);
376 Known.One.lshrInPlace(ShiftAmt);
377 // High bits are known zero.
378 Known.Zero.setHighBits(ShiftAmt);
379 return Known;
380 };
381
382 // Fast path for a common case when LHS is completely unknown.
383 KnownBits Known(BitWidth);
384 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
385 if (MinShiftAmount == 0 && ShAmtNonZero)
386 MinShiftAmount = 1;
387 if (LHS.isUnknown()) {
388 Known.Zero.setHighBits(MinShiftAmount);
389 return Known;
390 }
391
392 // Find the common bits from all possible shifts.
393 APInt MaxValue = RHS.getMaxValue();
394 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
395
396 // If exact, bound MaxShiftAmount to first known 1 in LHS.
397 if (Exact) {
398 unsigned FirstOne = LHS.countMaxTrailingZeros();
399 if (FirstOne < MinShiftAmount) {
400 // Always poison. Return zero because we don't like returning conflict.
401 Known.setAllZero();
402 return Known;
403 }
404 MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne);
405 }
406
407 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
408 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
409 Known.Zero.setAllBits();
410 Known.One.setAllBits();
411 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
412 ++ShiftAmt) {
413 // Skip if the shift amount is impossible.
414 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
415 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
416 continue;
417 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
418 if (Known.isUnknown())
419 break;
420 }
421
422 // All shift amounts may result in poison.
423 if (Known.hasConflict())
424 Known.setAllZero();
425 return Known;
426}
427
428KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS,
429 bool ShAmtNonZero, bool Exact) {
430 unsigned BitWidth = LHS.getBitWidth();
431 auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) {
432 KnownBits Known = LHS;
433 Known.Zero.ashrInPlace(ShiftAmt);
434 Known.One.ashrInPlace(ShiftAmt);
435 return Known;
436 };
437
438 // Fast path for a common case when LHS is completely unknown.
439 KnownBits Known(BitWidth);
440 unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth);
441 if (MinShiftAmount == 0 && ShAmtNonZero)
442 MinShiftAmount = 1;
443 if (LHS.isUnknown()) {
444 if (MinShiftAmount == BitWidth) {
445 // Always poison. Return zero because we don't like returning conflict.
446 Known.setAllZero();
447 return Known;
448 }
449 return Known;
450 }
451
452 // Find the common bits from all possible shifts.
453 APInt MaxValue = RHS.getMaxValue();
454 unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth);
455
456 // If exact, bound MaxShiftAmount to first known 1 in LHS.
457 if (Exact) {
458 unsigned FirstOne = LHS.countMaxTrailingZeros();
459 if (FirstOne < MinShiftAmount) {
460 // Always poison. Return zero because we don't like returning conflict.
461 Known.setAllZero();
462 return Known;
463 }
464 MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne);
465 }
466
467 unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue();
468 unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue();
469 Known.Zero.setAllBits();
470 Known.One.setAllBits();
471 for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount;
472 ++ShiftAmt) {
473 // Skip if the shift amount is impossible.
474 if ((ShiftAmtZeroMask & ShiftAmt) != 0 ||
475 (ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
476 continue;
477 Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt));
478 if (Known.isUnknown())
479 break;
480 }
481
482 // All shift amounts may result in poison.
483 if (Known.hasConflict())
484 Known.setAllZero();
485 return Known;
486}
487
488std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) {
489 if (LHS.isConstant() && RHS.isConstant())
490 return std::optional<bool>(LHS.getConstant() == RHS.getConstant());
491 if (LHS.One.intersects(RHS: RHS.Zero) || RHS.One.intersects(RHS: LHS.Zero))
492 return std::optional<bool>(false);
493 return std::nullopt;
494}
495
496std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) {
497 if (std::optional<bool> KnownEQ = eq(LHS, RHS))
498 return std::optional<bool>(!*KnownEQ);
499 return std::nullopt;
500}
501
502std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) {
503 // LHS >u RHS -> false if umax(LHS) <= umax(RHS)
504 if (LHS.getMaxValue().ule(RHS: RHS.getMinValue()))
505 return std::optional<bool>(false);
506 // LHS >u RHS -> true if umin(LHS) > umax(RHS)
507 if (LHS.getMinValue().ugt(RHS: RHS.getMaxValue()))
508 return std::optional<bool>(true);
509 return std::nullopt;
510}
511
512std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) {
513 if (std::optional<bool> IsUGT = ugt(LHS: RHS, RHS: LHS))
514 return std::optional<bool>(!*IsUGT);
515 return std::nullopt;
516}
517
518std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) {
519 return ugt(LHS: RHS, RHS: LHS);
520}
521
522std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) {
523 return uge(LHS: RHS, RHS: LHS);
524}
525
526std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) {
527 // LHS >s RHS -> false if smax(LHS) <= smax(RHS)
528 if (LHS.getSignedMaxValue().sle(RHS: RHS.getSignedMinValue()))
529 return std::optional<bool>(false);
530 // LHS >s RHS -> true if smin(LHS) > smax(RHS)
531 if (LHS.getSignedMinValue().sgt(RHS: RHS.getSignedMaxValue()))
532 return std::optional<bool>(true);
533 return std::nullopt;
534}
535
536std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) {
537 if (std::optional<bool> KnownSGT = sgt(LHS: RHS, RHS: LHS))
538 return std::optional<bool>(!*KnownSGT);
539 return std::nullopt;
540}
541
542std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) {
543 return sgt(LHS: RHS, RHS: LHS);
544}
545
546std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) {
547 return sge(LHS: RHS, RHS: LHS);
548}
549
550KnownBits KnownBits::abs(bool IntMinIsPoison) const {
551 // If the source's MSB is zero then we know the rest of the bits already.
552 if (isNonNegative())
553 return *this;
554
555 // Absolute value preserves trailing zero count.
556 KnownBits KnownAbs(getBitWidth());
557
558 // If the input is negative, then abs(x) == -x.
559 if (isNegative()) {
560 KnownBits Tmp = *this;
561 // Special case for IntMinIsPoison. We know the sign bit is set and we know
562 // all the rest of the bits except one to be zero. Since we have
563 // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is
564 // INT_MIN.
565 if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth())
566 Tmp.One.setBit(countMinTrailingZeros());
567
568 KnownAbs = computeForAddSub(
569 /*Add*/ false, NSW: IntMinIsPoison, /*NUW=*/false,
570 LHS: KnownBits::makeConstant(C: APInt(getBitWidth(), 0)), RHS: Tmp);
571
572 // One more special case for IntMinIsPoison. If we don't know any ones other
573 // than the signbit, we know for certain that all the unknowns can't be
574 // zero. So if we know high zero bits, but have unknown low bits, we know
575 // for certain those high-zero bits will end up as one. This is because,
576 // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up
577 // to the high bits. If we know a known INT_MIN input skip this. The result
578 // is poison anyways.
579 if (IntMinIsPoison && Tmp.countMinPopulation() == 1 &&
580 Tmp.countMaxPopulation() != 1) {
581 Tmp.One.clearSignBit();
582 Tmp.Zero.setSignBit();
583 KnownAbs.One.setBits(loBit: getBitWidth() - Tmp.countMinLeadingZeros(),
584 hiBit: getBitWidth() - 1);
585 }
586
587 } else {
588 unsigned MaxTZ = countMaxTrailingZeros();
589 unsigned MinTZ = countMinTrailingZeros();
590
591 KnownAbs.Zero.setLowBits(MinTZ);
592 // If we know the lowest set 1, then preserve it.
593 if (MaxTZ == MinTZ && MaxTZ < getBitWidth())
594 KnownAbs.One.setBit(MaxTZ);
595
596 // We only know that the absolute values's MSB will be zero if INT_MIN is
597 // poison, or there is a set bit that isn't the sign bit (otherwise it could
598 // be INT_MIN).
599 if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) {
600 KnownAbs.One.clearSignBit();
601 KnownAbs.Zero.setSignBit();
602 }
603 }
604
605 return KnownAbs;
606}
607
608static KnownBits computeForSatAddSub(bool Add, bool Signed,
609 const KnownBits &LHS,
610 const KnownBits &RHS) {
611 // We don't see NSW even for sadd/ssub as we want to check if the result has
612 // signed overflow.
613 unsigned BitWidth = LHS.getBitWidth();
614
615 std::optional<bool> Overflow;
616 // Even if we can't entirely rule out overflow, we may be able to rule out
617 // overflow in one direction. This allows us to potentially keep some of the
618 // add/sub bits. I.e if we can't overflow in the positive direction we won't
619 // clamp to INT_MAX so we can keep low 0s from the add/sub result.
620 bool MayNegClamp = true;
621 bool MayPosClamp = true;
622 if (Signed) {
623 // Easy cases we can rule out any overflow.
624 if (Add && ((LHS.isNegative() && RHS.isNonNegative()) ||
625 (LHS.isNonNegative() && RHS.isNegative())))
626 Overflow = false;
627 else if (!Add && (((LHS.isNegative() && RHS.isNegative()) ||
628 (LHS.isNonNegative() && RHS.isNonNegative()))))
629 Overflow = false;
630 else {
631 // Check if we may overflow. If we can't rule out overflow then check if
632 // we can rule out a direction at least.
633 KnownBits UnsignedLHS = LHS;
634 KnownBits UnsignedRHS = RHS;
635 // Get version of LHS/RHS with clearer signbit. This allows us to detect
636 // how the addition/subtraction might overflow into the signbit. Then
637 // using the actual known signbits of LHS/RHS, we can figure out which
638 // overflows are/aren't possible.
639 UnsignedLHS.One.clearSignBit();
640 UnsignedLHS.Zero.setSignBit();
641 UnsignedRHS.One.clearSignBit();
642 UnsignedRHS.Zero.setSignBit();
643 KnownBits Res =
644 KnownBits::computeForAddSub(Add, /*NSW=*/false,
645 /*NUW=*/false, LHS: UnsignedLHS, RHS: UnsignedRHS);
646 if (Add) {
647 if (Res.isNegative()) {
648 // Only overflow scenario is Pos + Pos.
649 MayNegClamp = false;
650 // Pos + Pos will overflow with extra signbit.
651 if (LHS.isNonNegative() && RHS.isNonNegative())
652 Overflow = true;
653 } else if (Res.isNonNegative()) {
654 // Only overflow scenario is Neg + Neg
655 MayPosClamp = false;
656 // Neg + Neg will overflow without extra signbit.
657 if (LHS.isNegative() && RHS.isNegative())
658 Overflow = true;
659 }
660 // We will never clamp to the opposite sign of N-bit result.
661 if (LHS.isNegative() || RHS.isNegative())
662 MayPosClamp = false;
663 if (LHS.isNonNegative() || RHS.isNonNegative())
664 MayNegClamp = false;
665 } else {
666 if (Res.isNegative()) {
667 // Only overflow scenario is Neg - Pos.
668 MayPosClamp = false;
669 // Neg - Pos will overflow with extra signbit.
670 if (LHS.isNegative() && RHS.isNonNegative())
671 Overflow = true;
672 } else if (Res.isNonNegative()) {
673 // Only overflow scenario is Pos - Neg.
674 MayNegClamp = false;
675 // Pos - Neg will overflow without extra signbit.
676 if (LHS.isNonNegative() && RHS.isNegative())
677 Overflow = true;
678 }
679 // We will never clamp to the opposite sign of N-bit result.
680 if (LHS.isNegative() || RHS.isNonNegative())
681 MayPosClamp = false;
682 if (LHS.isNonNegative() || RHS.isNegative())
683 MayNegClamp = false;
684 }
685 }
686 // If we have ruled out all clamping, we will never overflow.
687 if (!MayNegClamp && !MayPosClamp)
688 Overflow = false;
689 } else if (Add) {
690 // uadd.sat
691 bool Of;
692 (void)LHS.getMaxValue().uadd_ov(RHS: RHS.getMaxValue(), Overflow&: Of);
693 if (!Of) {
694 Overflow = false;
695 } else {
696 (void)LHS.getMinValue().uadd_ov(RHS: RHS.getMinValue(), Overflow&: Of);
697 if (Of)
698 Overflow = true;
699 }
700 } else {
701 // usub.sat
702 bool Of;
703 (void)LHS.getMinValue().usub_ov(RHS: RHS.getMaxValue(), Overflow&: Of);
704 if (!Of) {
705 Overflow = false;
706 } else {
707 (void)LHS.getMaxValue().usub_ov(RHS: RHS.getMinValue(), Overflow&: Of);
708 if (Of)
709 Overflow = true;
710 }
711 }
712
713 KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed,
714 /*NUW=*/!Signed, LHS, RHS);
715
716 if (Overflow) {
717 // We know whether or not we overflowed.
718 if (!(*Overflow)) {
719 // No overflow.
720 return Res;
721 }
722
723 // We overflowed
724 APInt C;
725 if (Signed) {
726 // sadd.sat / ssub.sat
727 assert(!LHS.isSignUnknown() &&
728 "We somehow know overflow without knowing input sign");
729 C = LHS.isNegative() ? APInt::getSignedMinValue(numBits: BitWidth)
730 : APInt::getSignedMaxValue(numBits: BitWidth);
731 } else if (Add) {
732 // uadd.sat
733 C = APInt::getMaxValue(numBits: BitWidth);
734 } else {
735 // uadd.sat
736 C = APInt::getMinValue(numBits: BitWidth);
737 }
738
739 Res.One = C;
740 Res.Zero = ~C;
741 return Res;
742 }
743
744 // We don't know if we overflowed.
745 if (Signed) {
746 // sadd.sat/ssub.sat
747 // We can keep our information about the sign bits.
748 if (MayPosClamp)
749 Res.Zero.clearLowBits(loBits: BitWidth - 1);
750 if (MayNegClamp)
751 Res.One.clearLowBits(loBits: BitWidth - 1);
752 } else if (Add) {
753 // uadd.sat
754 // We need to clear all the known zeros as we can only use the leading ones.
755 Res.Zero.clearAllBits();
756 } else {
757 // usub.sat
758 // We need to clear all the known ones as we can only use the leading zero.
759 Res.One.clearAllBits();
760 }
761
762 return Res;
763}
764
765KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
766 return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS);
767}
768KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) {
769 return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS);
770}
771KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) {
772 return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS);
773}
774KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) {
775 return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS);
776}
777
778static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) {
779 unsigned BitWidth = LHS.getBitWidth();
780 LHS = LHS.zext(BitWidth: BitWidth + 1);
781 RHS = RHS.zext(BitWidth: BitWidth + 1);
782 LHS =
783 computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil);
784 LHS = LHS.extractBits(NumBits: BitWidth, BitPosition: 1);
785 return LHS;
786}
787
788KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) {
789 return flipSignBit(Val: avgFloorU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
790}
791
792KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) {
793 return avgComputeU(LHS, RHS, /*IsCeil=*/false);
794}
795
796KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) {
797 return flipSignBit(Val: avgCeilU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS)));
798}
799
800KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) {
801 return avgComputeU(LHS, RHS, /*IsCeil=*/true);
802}
803
804KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
805 bool NoUndefSelfMultiply) {
806 unsigned BitWidth = LHS.getBitWidth();
807 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
808 assert((!NoUndefSelfMultiply || LHS == RHS) &&
809 "Self multiplication knownbits mismatch");
810
811 // Compute the high known-0 bits by multiplying the unsigned max of each side.
812 // Conservatively, M active bits * N active bits results in M + N bits in the
813 // result. But if we know a value is a power-of-2 for example, then this
814 // computes one more leading zero.
815 // TODO: This could be generalized to number of sign bits (negative numbers).
816 APInt UMaxLHS = LHS.getMaxValue();
817 APInt UMaxRHS = RHS.getMaxValue();
818
819 // For leading zeros in the result to be valid, the unsigned max product must
820 // fit in the bitwidth (it must not overflow).
821 bool HasOverflow;
822 APInt UMaxResult = UMaxLHS.umul_ov(RHS: UMaxRHS, Overflow&: HasOverflow);
823 unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();
824
825 // The result of the bottom bits of an integer multiply can be
826 // inferred by looking at the bottom bits of both operands and
827 // multiplying them together.
828 // We can infer at least the minimum number of known trailing bits
829 // of both operands. Depending on number of trailing zeros, we can
830 // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
831 // a and b are divisible by m and n respectively.
832 // We then calculate how many of those bits are inferrable and set
833 // the output. For example, the i8 mul:
834 // a = XXXX1100 (12)
835 // b = XXXX1110 (14)
836 // We know the bottom 3 bits are zero since the first can be divided by
837 // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
838 // Applying the multiplication to the trimmed arguments gets:
839 // XX11 (3)
840 // X111 (7)
841 // -------
842 // XX11
843 // XX11
844 // XX11
845 // XX11
846 // -------
847 // XXXXX01
848 // Which allows us to infer the 2 LSBs. Since we're multiplying the result
849 // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
850 // The proof for this can be described as:
851 // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
852 // (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
853 // umin(countTrailingZeros(C2), C6) +
854 // umin(C5 - umin(countTrailingZeros(C1), C5),
855 // C6 - umin(countTrailingZeros(C2), C6)))) - 1)
856 // %aa = shl i8 %a, C5
857 // %bb = shl i8 %b, C6
858 // %aaa = or i8 %aa, C1
859 // %bbb = or i8 %bb, C2
860 // %mul = mul i8 %aaa, %bbb
861 // %mask = and i8 %mul, C7
862 // =>
863 // %mask = i8 ((C1*C2)&C7)
864 // Where C5, C6 describe the known bits of %a, %b
865 // C1, C2 describe the known bottom bits of %a, %b.
866 // C7 describes the mask of the known bits of the result.
867 const APInt &Bottom0 = LHS.One;
868 const APInt &Bottom1 = RHS.One;
869
870 // How many times we'd be able to divide each argument by 2 (shr by 1).
871 // This gives us the number of trailing zeros on the multiplication result.
872 unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one();
873 unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one();
874 unsigned TrailZero0 = LHS.countMinTrailingZeros();
875 unsigned TrailZero1 = RHS.countMinTrailingZeros();
876 unsigned TrailZ = TrailZero0 + TrailZero1;
877
878 // Figure out the fewest known-bits operand.
879 unsigned SmallestOperand =
880 std::min(a: TrailBitsKnown0 - TrailZero0, b: TrailBitsKnown1 - TrailZero1);
881 unsigned ResultBitsKnown = std::min(a: SmallestOperand + TrailZ, b: BitWidth);
882
883 APInt BottomKnown =
884 Bottom0.getLoBits(numBits: TrailBitsKnown0) * Bottom1.getLoBits(numBits: TrailBitsKnown1);
885
886 KnownBits Res(BitWidth);
887 Res.Zero.setHighBits(LeadZ);
888 Res.Zero |= (~BottomKnown).getLoBits(numBits: ResultBitsKnown);
889 Res.One = BottomKnown.getLoBits(numBits: ResultBitsKnown);
890
891 // If we're self-multiplying then bit[1] is guaranteed to be zero.
892 if (NoUndefSelfMultiply && BitWidth > 1) {
893 assert(Res.One[1] == 0 &&
894 "Self-multiplication failed Quadratic Reciprocity!");
895 Res.Zero.setBit(1);
896 }
897
898 return Res;
899}
900
901KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) {
902 unsigned BitWidth = LHS.getBitWidth();
903 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
904 KnownBits WideLHS = LHS.sext(BitWidth: 2 * BitWidth);
905 KnownBits WideRHS = RHS.sext(BitWidth: 2 * BitWidth);
906 return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth);
907}
908
909KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) {
910 unsigned BitWidth = LHS.getBitWidth();
911 assert(BitWidth == RHS.getBitWidth() && "Operand mismatch");
912 KnownBits WideLHS = LHS.zext(BitWidth: 2 * BitWidth);
913 KnownBits WideRHS = RHS.zext(BitWidth: 2 * BitWidth);
914 return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth);
915}
916
917static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS,
918 const KnownBits &RHS, bool Exact) {
919
920 if (!Exact)
921 return Known;
922
923 // If LHS is Odd, the result is Odd no matter what.
924 // Odd / Odd -> Odd
925 // Odd / Even -> Impossible (because its exact division)
926 if (LHS.One[0])
927 Known.One.setBit(0);
928
929 int MinTZ =
930 (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros();
931 int MaxTZ =
932 (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros();
933 if (MinTZ >= 0) {
934 // Result has at least MinTZ trailing zeros.
935 Known.Zero.setLowBits(MinTZ);
936 if (MinTZ == MaxTZ) {
937 // Result has exactly MinTZ trailing zeros.
938 Known.One.setBit(MinTZ);
939 }
940 } else if (MaxTZ < 0) {
941 // Poison Result
942 Known.setAllZero();
943 }
944
945 // In the KnownBits exhaustive tests, we have poison inputs for exact values
946 // a LOT. If we have a conflict, just return all zeros.
947 if (Known.hasConflict())
948 Known.setAllZero();
949
950 return Known;
951}
952
953KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
954 bool Exact) {
955 // Equivalent of `udiv`. We must have caught this before it was folded.
956 if (LHS.isNonNegative() && RHS.isNonNegative())
957 return udiv(LHS, RHS, Exact);
958
959 unsigned BitWidth = LHS.getBitWidth();
960 KnownBits Known(BitWidth);
961
962 if (LHS.isZero() || RHS.isZero()) {
963 // Result is either known Zero or UB. Return Zero either way.
964 // Checking this earlier saves us a lot of special cases later on.
965 Known.setAllZero();
966 return Known;
967 }
968
969 std::optional<APInt> Res;
970 if (LHS.isNegative() && RHS.isNegative()) {
971 // Result non-negative.
972 APInt Denom = RHS.getSignedMaxValue();
973 APInt Num = LHS.getSignedMinValue();
974 // INT_MIN/-1 would be a poison result (impossible). Estimate the division
975 // as signed max (we will only set sign bit in the result).
976 Res = (Num.isMinSignedValue() && Denom.isAllOnes())
977 ? APInt::getSignedMaxValue(numBits: BitWidth)
978 : Num.sdiv(RHS: Denom);
979 } else if (LHS.isNegative() && RHS.isNonNegative()) {
980 // Result is negative if Exact OR -LHS u>= RHS.
981 if (Exact || (-LHS.getSignedMaxValue()).uge(RHS: RHS.getSignedMaxValue())) {
982 APInt Denom = RHS.getSignedMinValue();
983 APInt Num = LHS.getSignedMinValue();
984 Res = Denom.isZero() ? Num : Num.sdiv(RHS: Denom);
985 }
986 } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
987 // Result is negative if Exact OR LHS u>= -RHS.
988 if (Exact || LHS.getSignedMinValue().uge(RHS: -RHS.getSignedMinValue())) {
989 APInt Denom = RHS.getSignedMaxValue();
990 APInt Num = LHS.getSignedMaxValue();
991 Res = Num.sdiv(RHS: Denom);
992 }
993 }
994
995 if (Res) {
996 if (Res->isNonNegative()) {
997 unsigned LeadZ = Res->countLeadingZeros();
998 Known.Zero.setHighBits(LeadZ);
999 } else {
1000 unsigned LeadO = Res->countLeadingOnes();
1001 Known.One.setHighBits(LeadO);
1002 }
1003 }
1004
1005 Known = divComputeLowBit(Known, LHS, RHS, Exact);
1006 return Known;
1007}
1008
1009KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS,
1010 bool Exact) {
1011 unsigned BitWidth = LHS.getBitWidth();
1012 KnownBits Known(BitWidth);
1013
1014 if (LHS.isZero() || RHS.isZero()) {
1015 // Result is either known Zero or UB. Return Zero either way.
1016 // Checking this earlier saves us a lot of special cases later on.
1017 Known.setAllZero();
1018 return Known;
1019 }
1020
1021 // We can figure out the minimum number of upper zero bits by doing
1022 // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator
1023 // gets larger, the number of upper zero bits increases.
1024 APInt MinDenom = RHS.getMinValue();
1025 APInt MaxNum = LHS.getMaxValue();
1026 APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(RHS: MinDenom);
1027
1028 unsigned LeadZ = MaxRes.countLeadingZeros();
1029
1030 Known.Zero.setHighBits(LeadZ);
1031 Known = divComputeLowBit(Known, LHS, RHS, Exact);
1032
1033 return Known;
1034}
1035
1036KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) {
1037 unsigned BitWidth = LHS.getBitWidth();
1038 if (!RHS.isZero() && RHS.Zero[0]) {
1039 // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result.
1040 unsigned RHSZeros = RHS.countMinTrailingZeros();
1041 APInt Mask = APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: RHSZeros);
1042 APInt OnesMask = LHS.One & Mask;
1043 APInt ZerosMask = LHS.Zero & Mask;
1044 return KnownBits(ZerosMask, OnesMask);
1045 }
1046 return KnownBits(BitWidth);
1047}
1048
1049KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) {
1050 KnownBits Known = remGetLowBits(LHS, RHS);
1051 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1052 // NB: Low bits set in `remGetLowBits`.
1053 APInt HighBits = ~(RHS.getConstant() - 1);
1054 Known.Zero |= HighBits;
1055 return Known;
1056 }
1057
1058 // Since the result is less than or equal to either operand, any leading
1059 // zero bits in either operand must also exist in the result.
1060 uint32_t Leaders =
1061 std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinLeadingZeros());
1062 Known.Zero.setHighBits(Leaders);
1063 return Known;
1064}
1065
1066KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) {
1067 KnownBits Known = remGetLowBits(LHS, RHS);
1068 if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) {
1069 // NB: Low bits are set in `remGetLowBits`.
1070 APInt LowBits = RHS.getConstant() - 1;
1071 // If the first operand is non-negative or has all low bits zero, then
1072 // the upper bits are all zero.
1073 if (LHS.isNonNegative() || LowBits.isSubsetOf(RHS: LHS.Zero))
1074 Known.Zero |= ~LowBits;
1075
1076 // If the first operand is negative and not all low bits are zero, then
1077 // the upper bits are all one.
1078 if (LHS.isNegative() && LowBits.intersects(RHS: LHS.One))
1079 Known.One |= ~LowBits;
1080 return Known;
1081 }
1082
1083 // The sign bit is the LHS's sign bit, except when the result of the
1084 // remainder is zero. The magnitude of the result should be less than or
1085 // equal to the magnitude of either operand.
1086 if (LHS.isNegative() && Known.isNonZero())
1087 Known.One.setHighBits(
1088 std::max(a: LHS.countMinLeadingOnes(), b: RHS.countMinSignBits()));
1089 else if (LHS.isNonNegative())
1090 Known.Zero.setHighBits(
1091 std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinSignBits()));
1092 return Known;
1093}
1094
1095KnownBits &KnownBits::operator&=(const KnownBits &RHS) {
1096 // Result bit is 0 if either operand bit is 0.
1097 Zero |= RHS.Zero;
1098 // Result bit is 1 if both operand bits are 1.
1099 One &= RHS.One;
1100 return *this;
1101}
1102
1103KnownBits &KnownBits::operator|=(const KnownBits &RHS) {
1104 // Result bit is 0 if both operand bits are 0.
1105 Zero &= RHS.Zero;
1106 // Result bit is 1 if either operand bit is 1.
1107 One |= RHS.One;
1108 return *this;
1109}
1110
1111KnownBits &KnownBits::operator^=(const KnownBits &RHS) {
1112 // Result bit is 0 if both operand bits are 0 or both are 1.
1113 APInt Z = (Zero & RHS.Zero) | (One & RHS.One);
1114 // Result bit is 1 if one operand bit is 0 and the other is 1.
1115 One = (Zero & RHS.One) | (One & RHS.Zero);
1116 Zero = std::move(Z);
1117 return *this;
1118}
1119
1120KnownBits KnownBits::blsi() const {
1121 unsigned BitWidth = getBitWidth();
1122 KnownBits Known(Zero, APInt(BitWidth, 0));
1123 unsigned Max = countMaxTrailingZeros();
1124 Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth));
1125 unsigned Min = countMinTrailingZeros();
1126 if (Max == Min && Max < BitWidth)
1127 Known.One.setBit(Max);
1128 return Known;
1129}
1130
1131KnownBits KnownBits::blsmsk() const {
1132 unsigned BitWidth = getBitWidth();
1133 KnownBits Known(BitWidth);
1134 unsigned Max = countMaxTrailingZeros();
1135 Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth));
1136 unsigned Min = countMinTrailingZeros();
1137 Known.One.setLowBits(std::min(a: Min + 1, b: BitWidth));
1138 return Known;
1139}
1140
1141void KnownBits::print(raw_ostream &OS) const {
1142 unsigned BitWidth = getBitWidth();
1143 for (unsigned I = 0; I < BitWidth; ++I) {
1144 unsigned N = BitWidth - I - 1;
1145 if (Zero[N] && One[N])
1146 OS << "!";
1147 else if (Zero[N])
1148 OS << "0";
1149 else if (One[N])
1150 OS << "1";
1151 else
1152 OS << "?";
1153 }
1154}
1155
1156#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1157LLVM_DUMP_METHOD void KnownBits::dump() const {
1158 print(dbgs());
1159 dbgs() << "\n";
1160}
1161#endif
1162