1 | //===-- KnownBits.cpp - Stores known zeros/ones ---------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file contains a class for representing known zeros and ones used by |
10 | // computeKnownBits. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "llvm/Support/KnownBits.h" |
15 | #include "llvm/Support/Debug.h" |
16 | #include "llvm/Support/raw_ostream.h" |
17 | #include <cassert> |
18 | |
19 | using namespace llvm; |
20 | |
21 | KnownBits KnownBits::flipSignBit(const KnownBits &Val) { |
22 | unsigned SignBitPosition = Val.getBitWidth() - 1; |
23 | APInt Zero = Val.Zero; |
24 | APInt One = Val.One; |
25 | Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]); |
26 | One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]); |
27 | return KnownBits(Zero, One); |
28 | } |
29 | |
30 | static KnownBits computeForAddCarry(const KnownBits &LHS, const KnownBits &RHS, |
31 | bool CarryZero, bool CarryOne) { |
32 | |
33 | APInt PossibleSumZero = LHS.getMaxValue() + RHS.getMaxValue() + !CarryZero; |
34 | APInt PossibleSumOne = LHS.getMinValue() + RHS.getMinValue() + CarryOne; |
35 | |
36 | // Compute known bits of the carry. |
37 | APInt CarryKnownZero = ~(PossibleSumZero ^ LHS.Zero ^ RHS.Zero); |
38 | APInt CarryKnownOne = PossibleSumOne ^ LHS.One ^ RHS.One; |
39 | |
40 | // Compute set of known bits (where all three relevant bits are known). |
41 | APInt LHSKnownUnion = LHS.Zero | LHS.One; |
42 | APInt RHSKnownUnion = RHS.Zero | RHS.One; |
43 | APInt CarryKnownUnion = std::move(CarryKnownZero) | CarryKnownOne; |
44 | APInt Known = std::move(LHSKnownUnion) & RHSKnownUnion & CarryKnownUnion; |
45 | |
46 | // Compute known bits of the result. |
47 | KnownBits KnownOut; |
48 | KnownOut.Zero = ~std::move(PossibleSumZero) & Known; |
49 | KnownOut.One = std::move(PossibleSumOne) & Known; |
50 | return KnownOut; |
51 | } |
52 | |
53 | KnownBits KnownBits::computeForAddCarry( |
54 | const KnownBits &LHS, const KnownBits &RHS, const KnownBits &Carry) { |
55 | assert(Carry.getBitWidth() == 1 && "Carry must be 1-bit" ); |
56 | return ::computeForAddCarry( |
57 | LHS, RHS, CarryZero: Carry.Zero.getBoolValue(), CarryOne: Carry.One.getBoolValue()); |
58 | } |
59 | |
60 | KnownBits KnownBits::computeForAddSub(bool Add, bool NSW, bool NUW, |
61 | const KnownBits &LHS, |
62 | const KnownBits &RHS) { |
63 | unsigned BitWidth = LHS.getBitWidth(); |
64 | KnownBits KnownOut(BitWidth); |
65 | // This can be a relatively expensive helper, so optimistically save some |
66 | // work. |
67 | if (LHS.isUnknown() && RHS.isUnknown()) |
68 | return KnownOut; |
69 | |
70 | if (!LHS.isUnknown() && !RHS.isUnknown()) { |
71 | if (Add) { |
72 | // Sum = LHS + RHS + 0 |
73 | KnownOut = ::computeForAddCarry(LHS, RHS, /*CarryZero=*/true, |
74 | /*CarryOne=*/false); |
75 | } else { |
76 | // Sum = LHS + ~RHS + 1 |
77 | KnownBits NotRHS = RHS; |
78 | std::swap(a&: NotRHS.Zero, b&: NotRHS.One); |
79 | KnownOut = ::computeForAddCarry(LHS, RHS: NotRHS, /*CarryZero=*/false, |
80 | /*CarryOne=*/true); |
81 | } |
82 | } |
83 | |
84 | // Handle add/sub given nsw and/or nuw. |
85 | if (NUW) { |
86 | if (Add) { |
87 | // (add nuw X, Y) |
88 | APInt MinVal = LHS.getMinValue().uadd_sat(RHS: RHS.getMinValue()); |
89 | // None of the adds can end up overflowing, so min consecutive highbits |
90 | // in minimum possible of X + Y must all remain set. |
91 | if (NSW) { |
92 | unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one(); |
93 | // If we have NSW as well, we also know we can't overflow the signbit so |
94 | // can start counting from 1 bit back. |
95 | KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
96 | } |
97 | KnownOut.One.setHighBits(MinVal.countl_one()); |
98 | } else { |
99 | // (sub nuw X, Y) |
100 | APInt MaxVal = LHS.getMaxValue().usub_sat(RHS: RHS.getMinValue()); |
101 | // None of the subs can overflow at any point, so any common high bits |
102 | // will subtract away and result in zeros. |
103 | if (NSW) { |
104 | // If we have NSW as well, we also know we can't overflow the signbit so |
105 | // can start counting from 1 bit back. |
106 | unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero(); |
107 | KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
108 | } |
109 | KnownOut.Zero.setHighBits(MaxVal.countl_zero()); |
110 | } |
111 | } |
112 | |
113 | if (NSW) { |
114 | APInt MinVal; |
115 | APInt MaxVal; |
116 | if (Add) { |
117 | // (add nsw X, Y) |
118 | MinVal = LHS.getSignedMinValue().sadd_sat(RHS: RHS.getSignedMinValue()); |
119 | MaxVal = LHS.getSignedMaxValue().sadd_sat(RHS: RHS.getSignedMaxValue()); |
120 | } else { |
121 | // (sub nsw X, Y) |
122 | MinVal = LHS.getSignedMinValue().ssub_sat(RHS: RHS.getSignedMaxValue()); |
123 | MaxVal = LHS.getSignedMaxValue().ssub_sat(RHS: RHS.getSignedMinValue()); |
124 | } |
125 | if (MinVal.isNonNegative()) { |
126 | // If min is non-negative, result will always be non-neg (can't overflow |
127 | // around). |
128 | unsigned NumBits = MinVal.trunc(width: BitWidth - 1).countl_one(); |
129 | KnownOut.One.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
130 | KnownOut.Zero.setSignBit(); |
131 | } |
132 | if (MaxVal.isNegative()) { |
133 | // If max is negative, result will always be neg (can't overflow around). |
134 | unsigned NumBits = MaxVal.trunc(width: BitWidth - 1).countl_zero(); |
135 | KnownOut.Zero.setBits(loBit: BitWidth - 1 - NumBits, hiBit: BitWidth - 1); |
136 | KnownOut.One.setSignBit(); |
137 | } |
138 | } |
139 | |
140 | // Just return 0 if the nsw/nuw is violated and we have poison. |
141 | if (KnownOut.hasConflict()) |
142 | KnownOut.setAllZero(); |
143 | return KnownOut; |
144 | } |
145 | |
146 | KnownBits KnownBits::computeForSubBorrow(const KnownBits &LHS, KnownBits RHS, |
147 | const KnownBits &Borrow) { |
148 | assert(Borrow.getBitWidth() == 1 && "Borrow must be 1-bit" ); |
149 | |
150 | // LHS - RHS = LHS + ~RHS + 1 |
151 | // Carry 1 - Borrow in ::computeForAddCarry |
152 | std::swap(a&: RHS.Zero, b&: RHS.One); |
153 | return ::computeForAddCarry(LHS, RHS, |
154 | /*CarryZero=*/Borrow.One.getBoolValue(), |
155 | /*CarryOne=*/Borrow.Zero.getBoolValue()); |
156 | } |
157 | |
158 | KnownBits KnownBits::sextInReg(unsigned SrcBitWidth) const { |
159 | unsigned BitWidth = getBitWidth(); |
160 | assert(0 < SrcBitWidth && SrcBitWidth <= BitWidth && |
161 | "Illegal sext-in-register" ); |
162 | |
163 | if (SrcBitWidth == BitWidth) |
164 | return *this; |
165 | |
166 | unsigned ExtBits = BitWidth - SrcBitWidth; |
167 | KnownBits Result; |
168 | Result.One = One << ExtBits; |
169 | Result.Zero = Zero << ExtBits; |
170 | Result.One.ashrInPlace(ShiftAmt: ExtBits); |
171 | Result.Zero.ashrInPlace(ShiftAmt: ExtBits); |
172 | return Result; |
173 | } |
174 | |
175 | KnownBits KnownBits::makeGE(const APInt &Val) const { |
176 | // Count the number of leading bit positions where our underlying value is |
177 | // known to be less than or equal to Val. |
178 | unsigned N = (Zero | Val).countl_one(); |
179 | |
180 | // For each of those bit positions, if Val has a 1 in that bit then our |
181 | // underlying value must also have a 1. |
182 | APInt MaskedVal(Val); |
183 | MaskedVal.clearLowBits(loBits: getBitWidth() - N); |
184 | return KnownBits(Zero, One | MaskedVal); |
185 | } |
186 | |
187 | KnownBits KnownBits::umax(const KnownBits &LHS, const KnownBits &RHS) { |
188 | // If we can prove that LHS >= RHS then use LHS as the result. Likewise for |
189 | // RHS. Ideally our caller would already have spotted these cases and |
190 | // optimized away the umax operation, but we handle them here for |
191 | // completeness. |
192 | if (LHS.getMinValue().uge(RHS: RHS.getMaxValue())) |
193 | return LHS; |
194 | if (RHS.getMinValue().uge(RHS: LHS.getMaxValue())) |
195 | return RHS; |
196 | |
197 | // If the result of the umax is LHS then it must be greater than or equal to |
198 | // the minimum possible value of RHS. Likewise for RHS. Any known bits that |
199 | // are common to these two values are also known in the result. |
200 | KnownBits L = LHS.makeGE(Val: RHS.getMinValue()); |
201 | KnownBits R = RHS.makeGE(Val: LHS.getMinValue()); |
202 | return L.intersectWith(RHS: R); |
203 | } |
204 | |
205 | KnownBits KnownBits::umin(const KnownBits &LHS, const KnownBits &RHS) { |
206 | // Flip the range of values: [0, 0xFFFFFFFF] <-> [0xFFFFFFFF, 0] |
207 | auto Flip = [](const KnownBits &Val) { return KnownBits(Val.One, Val.Zero); }; |
208 | return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS))); |
209 | } |
210 | |
211 | KnownBits KnownBits::smax(const KnownBits &LHS, const KnownBits &RHS) { |
212 | return flipSignBit(Val: umax(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS))); |
213 | } |
214 | |
215 | KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) { |
216 | // Flip the range of values: [-0x80000000, 0x7FFFFFFF] <-> [0xFFFFFFFF, 0] |
217 | auto Flip = [](const KnownBits &Val) { |
218 | unsigned SignBitPosition = Val.getBitWidth() - 1; |
219 | APInt Zero = Val.One; |
220 | APInt One = Val.Zero; |
221 | Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Val.Zero[SignBitPosition]); |
222 | One.setBitVal(BitPosition: SignBitPosition, BitValue: Val.One[SignBitPosition]); |
223 | return KnownBits(Zero, One); |
224 | }; |
225 | return Flip(umax(LHS: Flip(LHS), RHS: Flip(RHS))); |
226 | } |
227 | |
228 | KnownBits KnownBits::abdu(const KnownBits &LHS, const KnownBits &RHS) { |
229 | // If we know which argument is larger, return (sub LHS, RHS) or |
230 | // (sub RHS, LHS) directly. |
231 | if (LHS.getMinValue().uge(RHS: RHS.getMaxValue())) |
232 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, |
233 | RHS); |
234 | if (RHS.getMinValue().uge(RHS: LHS.getMaxValue())) |
235 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS, |
236 | RHS: LHS); |
237 | |
238 | // By construction, the subtraction in abdu never has unsigned overflow. |
239 | // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS). |
240 | KnownBits Diff0 = |
241 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS); |
242 | KnownBits Diff1 = |
243 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS); |
244 | return Diff0.intersectWith(RHS: Diff1); |
245 | } |
246 | |
247 | KnownBits KnownBits::abds(KnownBits LHS, KnownBits RHS) { |
248 | // If we know which argument is larger, return (sub LHS, RHS) or |
249 | // (sub RHS, LHS) directly. |
250 | if (LHS.getSignedMinValue().sge(RHS: RHS.getSignedMaxValue())) |
251 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS, |
252 | RHS); |
253 | if (RHS.getSignedMinValue().sge(RHS: LHS.getSignedMaxValue())) |
254 | return computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/false, LHS: RHS, |
255 | RHS: LHS); |
256 | |
257 | // Shift both arguments from the signed range to the unsigned range, e.g. from |
258 | // [-0x80, 0x7F] to [0, 0xFF]. This allows us to use "sub nuw" below just like |
259 | // abdu does. |
260 | // Note that we can't just use "sub nsw" instead because abds has signed |
261 | // inputs but an unsigned result, which makes the overflow conditions |
262 | // different. |
263 | unsigned SignBitPosition = LHS.getBitWidth() - 1; |
264 | for (auto Arg : {&LHS, &RHS}) { |
265 | bool Tmp = Arg->Zero[SignBitPosition]; |
266 | Arg->Zero.setBitVal(BitPosition: SignBitPosition, BitValue: Arg->One[SignBitPosition]); |
267 | Arg->One.setBitVal(BitPosition: SignBitPosition, BitValue: Tmp); |
268 | } |
269 | |
270 | // Find the common bits between (sub nuw LHS, RHS) and (sub nuw RHS, LHS). |
271 | KnownBits Diff0 = |
272 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS, RHS); |
273 | KnownBits Diff1 = |
274 | computeForAddSub(/*Add=*/false, /*NSW=*/false, /*NUW=*/true, LHS: RHS, RHS: LHS); |
275 | return Diff0.intersectWith(RHS: Diff1); |
276 | } |
277 | |
278 | static unsigned getMaxShiftAmount(const APInt &MaxValue, unsigned BitWidth) { |
279 | if (isPowerOf2_32(Value: BitWidth)) |
280 | return MaxValue.extractBitsAsZExtValue(numBits: Log2_32(Value: BitWidth), bitPosition: 0); |
281 | // This is only an approximate upper bound. |
282 | return MaxValue.getLimitedValue(Limit: BitWidth - 1); |
283 | } |
284 | |
285 | KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW, |
286 | bool NSW, bool ShAmtNonZero) { |
287 | unsigned BitWidth = LHS.getBitWidth(); |
288 | auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { |
289 | KnownBits Known; |
290 | bool ShiftedOutZero, ShiftedOutOne; |
291 | Known.Zero = LHS.Zero.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutZero); |
292 | Known.Zero.setLowBits(ShiftAmt); |
293 | Known.One = LHS.One.ushl_ov(Amt: ShiftAmt, Overflow&: ShiftedOutOne); |
294 | |
295 | // All cases returning poison have been handled by MaxShiftAmount already. |
296 | if (NSW) { |
297 | if (NUW && ShiftAmt != 0) |
298 | // NUW means we can assume anything shifted out was a zero. |
299 | ShiftedOutZero = true; |
300 | |
301 | if (ShiftedOutZero) |
302 | Known.makeNonNegative(); |
303 | else if (ShiftedOutOne) |
304 | Known.makeNegative(); |
305 | } |
306 | return Known; |
307 | }; |
308 | |
309 | // Fast path for a common case when LHS is completely unknown. |
310 | KnownBits Known(BitWidth); |
311 | unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth); |
312 | if (MinShiftAmount == 0 && ShAmtNonZero) |
313 | MinShiftAmount = 1; |
314 | if (LHS.isUnknown()) { |
315 | Known.Zero.setLowBits(MinShiftAmount); |
316 | if (NUW && NSW && MinShiftAmount != 0) |
317 | Known.makeNonNegative(); |
318 | return Known; |
319 | } |
320 | |
321 | // Determine maximum shift amount, taking NUW/NSW flags into account. |
322 | APInt MaxValue = RHS.getMaxValue(); |
323 | unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); |
324 | if (NUW && NSW) |
325 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros() - 1); |
326 | if (NUW) |
327 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: LHS.countMaxLeadingZeros()); |
328 | if (NSW) |
329 | MaxShiftAmount = std::min( |
330 | a: MaxShiftAmount, |
331 | b: std::max(a: LHS.countMaxLeadingZeros(), b: LHS.countMaxLeadingOnes()) - 1); |
332 | |
333 | // Fast path for common case where the shift amount is unknown. |
334 | if (MinShiftAmount == 0 && MaxShiftAmount == BitWidth - 1 && |
335 | isPowerOf2_32(Value: BitWidth)) { |
336 | Known.Zero.setLowBits(LHS.countMinTrailingZeros()); |
337 | if (LHS.isAllOnes()) |
338 | Known.One.setSignBit(); |
339 | if (NSW) { |
340 | if (LHS.isNonNegative()) |
341 | Known.makeNonNegative(); |
342 | if (LHS.isNegative()) |
343 | Known.makeNegative(); |
344 | } |
345 | return Known; |
346 | } |
347 | |
348 | // Find the common bits from all possible shifts. |
349 | unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue(); |
350 | unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue(); |
351 | Known.Zero.setAllBits(); |
352 | Known.One.setAllBits(); |
353 | for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; |
354 | ++ShiftAmt) { |
355 | // Skip if the shift amount is impossible. |
356 | if ((ShiftAmtZeroMask & ShiftAmt) != 0 || |
357 | (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) |
358 | continue; |
359 | Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt)); |
360 | if (Known.isUnknown()) |
361 | break; |
362 | } |
363 | |
364 | // All shift amounts may result in poison. |
365 | if (Known.hasConflict()) |
366 | Known.setAllZero(); |
367 | return Known; |
368 | } |
369 | |
370 | KnownBits KnownBits::lshr(const KnownBits &LHS, const KnownBits &RHS, |
371 | bool ShAmtNonZero, bool Exact) { |
372 | unsigned BitWidth = LHS.getBitWidth(); |
373 | auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { |
374 | KnownBits Known = LHS; |
375 | Known.Zero.lshrInPlace(ShiftAmt); |
376 | Known.One.lshrInPlace(ShiftAmt); |
377 | // High bits are known zero. |
378 | Known.Zero.setHighBits(ShiftAmt); |
379 | return Known; |
380 | }; |
381 | |
382 | // Fast path for a common case when LHS is completely unknown. |
383 | KnownBits Known(BitWidth); |
384 | unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth); |
385 | if (MinShiftAmount == 0 && ShAmtNonZero) |
386 | MinShiftAmount = 1; |
387 | if (LHS.isUnknown()) { |
388 | Known.Zero.setHighBits(MinShiftAmount); |
389 | return Known; |
390 | } |
391 | |
392 | // Find the common bits from all possible shifts. |
393 | APInt MaxValue = RHS.getMaxValue(); |
394 | unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); |
395 | |
396 | // If exact, bound MaxShiftAmount to first known 1 in LHS. |
397 | if (Exact) { |
398 | unsigned FirstOne = LHS.countMaxTrailingZeros(); |
399 | if (FirstOne < MinShiftAmount) { |
400 | // Always poison. Return zero because we don't like returning conflict. |
401 | Known.setAllZero(); |
402 | return Known; |
403 | } |
404 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne); |
405 | } |
406 | |
407 | unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue(); |
408 | unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue(); |
409 | Known.Zero.setAllBits(); |
410 | Known.One.setAllBits(); |
411 | for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; |
412 | ++ShiftAmt) { |
413 | // Skip if the shift amount is impossible. |
414 | if ((ShiftAmtZeroMask & ShiftAmt) != 0 || |
415 | (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) |
416 | continue; |
417 | Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt)); |
418 | if (Known.isUnknown()) |
419 | break; |
420 | } |
421 | |
422 | // All shift amounts may result in poison. |
423 | if (Known.hasConflict()) |
424 | Known.setAllZero(); |
425 | return Known; |
426 | } |
427 | |
428 | KnownBits KnownBits::ashr(const KnownBits &LHS, const KnownBits &RHS, |
429 | bool ShAmtNonZero, bool Exact) { |
430 | unsigned BitWidth = LHS.getBitWidth(); |
431 | auto ShiftByConst = [&](const KnownBits &LHS, unsigned ShiftAmt) { |
432 | KnownBits Known = LHS; |
433 | Known.Zero.ashrInPlace(ShiftAmt); |
434 | Known.One.ashrInPlace(ShiftAmt); |
435 | return Known; |
436 | }; |
437 | |
438 | // Fast path for a common case when LHS is completely unknown. |
439 | KnownBits Known(BitWidth); |
440 | unsigned MinShiftAmount = RHS.getMinValue().getLimitedValue(Limit: BitWidth); |
441 | if (MinShiftAmount == 0 && ShAmtNonZero) |
442 | MinShiftAmount = 1; |
443 | if (LHS.isUnknown()) { |
444 | if (MinShiftAmount == BitWidth) { |
445 | // Always poison. Return zero because we don't like returning conflict. |
446 | Known.setAllZero(); |
447 | return Known; |
448 | } |
449 | return Known; |
450 | } |
451 | |
452 | // Find the common bits from all possible shifts. |
453 | APInt MaxValue = RHS.getMaxValue(); |
454 | unsigned MaxShiftAmount = getMaxShiftAmount(MaxValue, BitWidth); |
455 | |
456 | // If exact, bound MaxShiftAmount to first known 1 in LHS. |
457 | if (Exact) { |
458 | unsigned FirstOne = LHS.countMaxTrailingZeros(); |
459 | if (FirstOne < MinShiftAmount) { |
460 | // Always poison. Return zero because we don't like returning conflict. |
461 | Known.setAllZero(); |
462 | return Known; |
463 | } |
464 | MaxShiftAmount = std::min(a: MaxShiftAmount, b: FirstOne); |
465 | } |
466 | |
467 | unsigned ShiftAmtZeroMask = RHS.Zero.zextOrTrunc(width: 32).getZExtValue(); |
468 | unsigned ShiftAmtOneMask = RHS.One.zextOrTrunc(width: 32).getZExtValue(); |
469 | Known.Zero.setAllBits(); |
470 | Known.One.setAllBits(); |
471 | for (unsigned ShiftAmt = MinShiftAmount; ShiftAmt <= MaxShiftAmount; |
472 | ++ShiftAmt) { |
473 | // Skip if the shift amount is impossible. |
474 | if ((ShiftAmtZeroMask & ShiftAmt) != 0 || |
475 | (ShiftAmtOneMask | ShiftAmt) != ShiftAmt) |
476 | continue; |
477 | Known = Known.intersectWith(RHS: ShiftByConst(LHS, ShiftAmt)); |
478 | if (Known.isUnknown()) |
479 | break; |
480 | } |
481 | |
482 | // All shift amounts may result in poison. |
483 | if (Known.hasConflict()) |
484 | Known.setAllZero(); |
485 | return Known; |
486 | } |
487 | |
488 | std::optional<bool> KnownBits::eq(const KnownBits &LHS, const KnownBits &RHS) { |
489 | if (LHS.isConstant() && RHS.isConstant()) |
490 | return std::optional<bool>(LHS.getConstant() == RHS.getConstant()); |
491 | if (LHS.One.intersects(RHS: RHS.Zero) || RHS.One.intersects(RHS: LHS.Zero)) |
492 | return std::optional<bool>(false); |
493 | return std::nullopt; |
494 | } |
495 | |
496 | std::optional<bool> KnownBits::ne(const KnownBits &LHS, const KnownBits &RHS) { |
497 | if (std::optional<bool> KnownEQ = eq(LHS, RHS)) |
498 | return std::optional<bool>(!*KnownEQ); |
499 | return std::nullopt; |
500 | } |
501 | |
502 | std::optional<bool> KnownBits::ugt(const KnownBits &LHS, const KnownBits &RHS) { |
503 | // LHS >u RHS -> false if umax(LHS) <= umax(RHS) |
504 | if (LHS.getMaxValue().ule(RHS: RHS.getMinValue())) |
505 | return std::optional<bool>(false); |
506 | // LHS >u RHS -> true if umin(LHS) > umax(RHS) |
507 | if (LHS.getMinValue().ugt(RHS: RHS.getMaxValue())) |
508 | return std::optional<bool>(true); |
509 | return std::nullopt; |
510 | } |
511 | |
512 | std::optional<bool> KnownBits::uge(const KnownBits &LHS, const KnownBits &RHS) { |
513 | if (std::optional<bool> IsUGT = ugt(LHS: RHS, RHS: LHS)) |
514 | return std::optional<bool>(!*IsUGT); |
515 | return std::nullopt; |
516 | } |
517 | |
518 | std::optional<bool> KnownBits::ult(const KnownBits &LHS, const KnownBits &RHS) { |
519 | return ugt(LHS: RHS, RHS: LHS); |
520 | } |
521 | |
522 | std::optional<bool> KnownBits::ule(const KnownBits &LHS, const KnownBits &RHS) { |
523 | return uge(LHS: RHS, RHS: LHS); |
524 | } |
525 | |
526 | std::optional<bool> KnownBits::sgt(const KnownBits &LHS, const KnownBits &RHS) { |
527 | // LHS >s RHS -> false if smax(LHS) <= smax(RHS) |
528 | if (LHS.getSignedMaxValue().sle(RHS: RHS.getSignedMinValue())) |
529 | return std::optional<bool>(false); |
530 | // LHS >s RHS -> true if smin(LHS) > smax(RHS) |
531 | if (LHS.getSignedMinValue().sgt(RHS: RHS.getSignedMaxValue())) |
532 | return std::optional<bool>(true); |
533 | return std::nullopt; |
534 | } |
535 | |
536 | std::optional<bool> KnownBits::sge(const KnownBits &LHS, const KnownBits &RHS) { |
537 | if (std::optional<bool> KnownSGT = sgt(LHS: RHS, RHS: LHS)) |
538 | return std::optional<bool>(!*KnownSGT); |
539 | return std::nullopt; |
540 | } |
541 | |
542 | std::optional<bool> KnownBits::slt(const KnownBits &LHS, const KnownBits &RHS) { |
543 | return sgt(LHS: RHS, RHS: LHS); |
544 | } |
545 | |
546 | std::optional<bool> KnownBits::sle(const KnownBits &LHS, const KnownBits &RHS) { |
547 | return sge(LHS: RHS, RHS: LHS); |
548 | } |
549 | |
550 | KnownBits KnownBits::abs(bool IntMinIsPoison) const { |
551 | // If the source's MSB is zero then we know the rest of the bits already. |
552 | if (isNonNegative()) |
553 | return *this; |
554 | |
555 | // Absolute value preserves trailing zero count. |
556 | KnownBits KnownAbs(getBitWidth()); |
557 | |
558 | // If the input is negative, then abs(x) == -x. |
559 | if (isNegative()) { |
560 | KnownBits Tmp = *this; |
561 | // Special case for IntMinIsPoison. We know the sign bit is set and we know |
562 | // all the rest of the bits except one to be zero. Since we have |
563 | // IntMinIsPoison, that final bit MUST be a one, as otherwise the input is |
564 | // INT_MIN. |
565 | if (IntMinIsPoison && (Zero.popcount() + 2) == getBitWidth()) |
566 | Tmp.One.setBit(countMinTrailingZeros()); |
567 | |
568 | KnownAbs = computeForAddSub( |
569 | /*Add*/ false, NSW: IntMinIsPoison, /*NUW=*/false, |
570 | LHS: KnownBits::makeConstant(C: APInt(getBitWidth(), 0)), RHS: Tmp); |
571 | |
572 | // One more special case for IntMinIsPoison. If we don't know any ones other |
573 | // than the signbit, we know for certain that all the unknowns can't be |
574 | // zero. So if we know high zero bits, but have unknown low bits, we know |
575 | // for certain those high-zero bits will end up as one. This is because, |
576 | // the low bits can't be all zeros, so the +1 in (~x + 1) cannot carry up |
577 | // to the high bits. If we know a known INT_MIN input skip this. The result |
578 | // is poison anyways. |
579 | if (IntMinIsPoison && Tmp.countMinPopulation() == 1 && |
580 | Tmp.countMaxPopulation() != 1) { |
581 | Tmp.One.clearSignBit(); |
582 | Tmp.Zero.setSignBit(); |
583 | KnownAbs.One.setBits(loBit: getBitWidth() - Tmp.countMinLeadingZeros(), |
584 | hiBit: getBitWidth() - 1); |
585 | } |
586 | |
587 | } else { |
588 | unsigned MaxTZ = countMaxTrailingZeros(); |
589 | unsigned MinTZ = countMinTrailingZeros(); |
590 | |
591 | KnownAbs.Zero.setLowBits(MinTZ); |
592 | // If we know the lowest set 1, then preserve it. |
593 | if (MaxTZ == MinTZ && MaxTZ < getBitWidth()) |
594 | KnownAbs.One.setBit(MaxTZ); |
595 | |
596 | // We only know that the absolute values's MSB will be zero if INT_MIN is |
597 | // poison, or there is a set bit that isn't the sign bit (otherwise it could |
598 | // be INT_MIN). |
599 | if (IntMinIsPoison || (!One.isZero() && !One.isMinSignedValue())) { |
600 | KnownAbs.One.clearSignBit(); |
601 | KnownAbs.Zero.setSignBit(); |
602 | } |
603 | } |
604 | |
605 | return KnownAbs; |
606 | } |
607 | |
608 | static KnownBits computeForSatAddSub(bool Add, bool Signed, |
609 | const KnownBits &LHS, |
610 | const KnownBits &RHS) { |
611 | // We don't see NSW even for sadd/ssub as we want to check if the result has |
612 | // signed overflow. |
613 | unsigned BitWidth = LHS.getBitWidth(); |
614 | |
615 | std::optional<bool> Overflow; |
616 | // Even if we can't entirely rule out overflow, we may be able to rule out |
617 | // overflow in one direction. This allows us to potentially keep some of the |
618 | // add/sub bits. I.e if we can't overflow in the positive direction we won't |
619 | // clamp to INT_MAX so we can keep low 0s from the add/sub result. |
620 | bool MayNegClamp = true; |
621 | bool MayPosClamp = true; |
622 | if (Signed) { |
623 | // Easy cases we can rule out any overflow. |
624 | if (Add && ((LHS.isNegative() && RHS.isNonNegative()) || |
625 | (LHS.isNonNegative() && RHS.isNegative()))) |
626 | Overflow = false; |
627 | else if (!Add && (((LHS.isNegative() && RHS.isNegative()) || |
628 | (LHS.isNonNegative() && RHS.isNonNegative())))) |
629 | Overflow = false; |
630 | else { |
631 | // Check if we may overflow. If we can't rule out overflow then check if |
632 | // we can rule out a direction at least. |
633 | KnownBits UnsignedLHS = LHS; |
634 | KnownBits UnsignedRHS = RHS; |
635 | // Get version of LHS/RHS with clearer signbit. This allows us to detect |
636 | // how the addition/subtraction might overflow into the signbit. Then |
637 | // using the actual known signbits of LHS/RHS, we can figure out which |
638 | // overflows are/aren't possible. |
639 | UnsignedLHS.One.clearSignBit(); |
640 | UnsignedLHS.Zero.setSignBit(); |
641 | UnsignedRHS.One.clearSignBit(); |
642 | UnsignedRHS.Zero.setSignBit(); |
643 | KnownBits Res = |
644 | KnownBits::computeForAddSub(Add, /*NSW=*/false, |
645 | /*NUW=*/false, LHS: UnsignedLHS, RHS: UnsignedRHS); |
646 | if (Add) { |
647 | if (Res.isNegative()) { |
648 | // Only overflow scenario is Pos + Pos. |
649 | MayNegClamp = false; |
650 | // Pos + Pos will overflow with extra signbit. |
651 | if (LHS.isNonNegative() && RHS.isNonNegative()) |
652 | Overflow = true; |
653 | } else if (Res.isNonNegative()) { |
654 | // Only overflow scenario is Neg + Neg |
655 | MayPosClamp = false; |
656 | // Neg + Neg will overflow without extra signbit. |
657 | if (LHS.isNegative() && RHS.isNegative()) |
658 | Overflow = true; |
659 | } |
660 | // We will never clamp to the opposite sign of N-bit result. |
661 | if (LHS.isNegative() || RHS.isNegative()) |
662 | MayPosClamp = false; |
663 | if (LHS.isNonNegative() || RHS.isNonNegative()) |
664 | MayNegClamp = false; |
665 | } else { |
666 | if (Res.isNegative()) { |
667 | // Only overflow scenario is Neg - Pos. |
668 | MayPosClamp = false; |
669 | // Neg - Pos will overflow with extra signbit. |
670 | if (LHS.isNegative() && RHS.isNonNegative()) |
671 | Overflow = true; |
672 | } else if (Res.isNonNegative()) { |
673 | // Only overflow scenario is Pos - Neg. |
674 | MayNegClamp = false; |
675 | // Pos - Neg will overflow without extra signbit. |
676 | if (LHS.isNonNegative() && RHS.isNegative()) |
677 | Overflow = true; |
678 | } |
679 | // We will never clamp to the opposite sign of N-bit result. |
680 | if (LHS.isNegative() || RHS.isNonNegative()) |
681 | MayPosClamp = false; |
682 | if (LHS.isNonNegative() || RHS.isNegative()) |
683 | MayNegClamp = false; |
684 | } |
685 | } |
686 | // If we have ruled out all clamping, we will never overflow. |
687 | if (!MayNegClamp && !MayPosClamp) |
688 | Overflow = false; |
689 | } else if (Add) { |
690 | // uadd.sat |
691 | bool Of; |
692 | (void)LHS.getMaxValue().uadd_ov(RHS: RHS.getMaxValue(), Overflow&: Of); |
693 | if (!Of) { |
694 | Overflow = false; |
695 | } else { |
696 | (void)LHS.getMinValue().uadd_ov(RHS: RHS.getMinValue(), Overflow&: Of); |
697 | if (Of) |
698 | Overflow = true; |
699 | } |
700 | } else { |
701 | // usub.sat |
702 | bool Of; |
703 | (void)LHS.getMinValue().usub_ov(RHS: RHS.getMaxValue(), Overflow&: Of); |
704 | if (!Of) { |
705 | Overflow = false; |
706 | } else { |
707 | (void)LHS.getMaxValue().usub_ov(RHS: RHS.getMinValue(), Overflow&: Of); |
708 | if (Of) |
709 | Overflow = true; |
710 | } |
711 | } |
712 | |
713 | KnownBits Res = KnownBits::computeForAddSub(Add, /*NSW=*/Signed, |
714 | /*NUW=*/!Signed, LHS, RHS); |
715 | |
716 | if (Overflow) { |
717 | // We know whether or not we overflowed. |
718 | if (!(*Overflow)) { |
719 | // No overflow. |
720 | return Res; |
721 | } |
722 | |
723 | // We overflowed |
724 | APInt C; |
725 | if (Signed) { |
726 | // sadd.sat / ssub.sat |
727 | assert(!LHS.isSignUnknown() && |
728 | "We somehow know overflow without knowing input sign" ); |
729 | C = LHS.isNegative() ? APInt::getSignedMinValue(numBits: BitWidth) |
730 | : APInt::getSignedMaxValue(numBits: BitWidth); |
731 | } else if (Add) { |
732 | // uadd.sat |
733 | C = APInt::getMaxValue(numBits: BitWidth); |
734 | } else { |
735 | // uadd.sat |
736 | C = APInt::getMinValue(numBits: BitWidth); |
737 | } |
738 | |
739 | Res.One = C; |
740 | Res.Zero = ~C; |
741 | return Res; |
742 | } |
743 | |
744 | // We don't know if we overflowed. |
745 | if (Signed) { |
746 | // sadd.sat/ssub.sat |
747 | // We can keep our information about the sign bits. |
748 | if (MayPosClamp) |
749 | Res.Zero.clearLowBits(loBits: BitWidth - 1); |
750 | if (MayNegClamp) |
751 | Res.One.clearLowBits(loBits: BitWidth - 1); |
752 | } else if (Add) { |
753 | // uadd.sat |
754 | // We need to clear all the known zeros as we can only use the leading ones. |
755 | Res.Zero.clearAllBits(); |
756 | } else { |
757 | // usub.sat |
758 | // We need to clear all the known ones as we can only use the leading zero. |
759 | Res.One.clearAllBits(); |
760 | } |
761 | |
762 | return Res; |
763 | } |
764 | |
765 | KnownBits KnownBits::sadd_sat(const KnownBits &LHS, const KnownBits &RHS) { |
766 | return computeForSatAddSub(/*Add*/ true, /*Signed*/ true, LHS, RHS); |
767 | } |
768 | KnownBits KnownBits::ssub_sat(const KnownBits &LHS, const KnownBits &RHS) { |
769 | return computeForSatAddSub(/*Add*/ false, /*Signed*/ true, LHS, RHS); |
770 | } |
771 | KnownBits KnownBits::uadd_sat(const KnownBits &LHS, const KnownBits &RHS) { |
772 | return computeForSatAddSub(/*Add*/ true, /*Signed*/ false, LHS, RHS); |
773 | } |
774 | KnownBits KnownBits::usub_sat(const KnownBits &LHS, const KnownBits &RHS) { |
775 | return computeForSatAddSub(/*Add*/ false, /*Signed*/ false, LHS, RHS); |
776 | } |
777 | |
778 | static KnownBits avgComputeU(KnownBits LHS, KnownBits RHS, bool IsCeil) { |
779 | unsigned BitWidth = LHS.getBitWidth(); |
780 | LHS = LHS.zext(BitWidth: BitWidth + 1); |
781 | RHS = RHS.zext(BitWidth: BitWidth + 1); |
782 | LHS = |
783 | computeForAddCarry(LHS, RHS, /*CarryZero*/ !IsCeil, /*CarryOne*/ IsCeil); |
784 | LHS = LHS.extractBits(NumBits: BitWidth, BitPosition: 1); |
785 | return LHS; |
786 | } |
787 | |
788 | KnownBits KnownBits::avgFloorS(const KnownBits &LHS, const KnownBits &RHS) { |
789 | return flipSignBit(Val: avgFloorU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS))); |
790 | } |
791 | |
792 | KnownBits KnownBits::avgFloorU(const KnownBits &LHS, const KnownBits &RHS) { |
793 | return avgComputeU(LHS, RHS, /*IsCeil=*/false); |
794 | } |
795 | |
796 | KnownBits KnownBits::avgCeilS(const KnownBits &LHS, const KnownBits &RHS) { |
797 | return flipSignBit(Val: avgCeilU(LHS: flipSignBit(Val: LHS), RHS: flipSignBit(Val: RHS))); |
798 | } |
799 | |
800 | KnownBits KnownBits::avgCeilU(const KnownBits &LHS, const KnownBits &RHS) { |
801 | return avgComputeU(LHS, RHS, /*IsCeil=*/true); |
802 | } |
803 | |
/// Compute the known bits of the product LHS * RHS (wrapping at BitWidth).
///
/// Two independent sources of information are combined:
///  * Leading zeros: if the product of the two unsigned maxima does not
///    overflow, every leading zero of that bound is a leading zero of the
///    result.
///  * Low bits: the fully-known low bits of each operand determine a number
///    of low bits of the product (derivation below).
/// When \p NoUndefSelfMultiply is true, the caller promises this is a
/// self-multiplication; the assertion checks LHS == RHS. A square is always
/// 0 or 1 modulo 4, so bit 1 of the result is additionally known zero.
KnownBits KnownBits::mul(const KnownBits &LHS, const KnownBits &RHS,
                         bool NoUndefSelfMultiply) {
  unsigned BitWidth = LHS.getBitWidth();
  assert(BitWidth == RHS.getBitWidth() && "Operand mismatch" );
  assert((!NoUndefSelfMultiply || LHS == RHS) &&
         "Self multiplication knownbits mismatch" );

  // Compute the high known-0 bits by multiplying the unsigned max of each side.
  // Conservatively, M active bits * N active bits results in M + N bits in the
  // result. But if we know a value is a power-of-2 for example, then this
  // computes one more leading zero.
  // TODO: This could be generalized to number of sign bits (negative numbers).
  APInt UMaxLHS = LHS.getMaxValue();
  APInt UMaxRHS = RHS.getMaxValue();

  // For leading zeros in the result to be valid, the unsigned max product must
  // fit in the bitwidth (it must not overflow).
  bool HasOverflow;
  APInt UMaxResult = UMaxLHS.umul_ov(RHS: UMaxRHS, Overflow&: HasOverflow);
  // If the bound wrapped, it says nothing about the high bits.
  unsigned LeadZ = HasOverflow ? 0 : UMaxResult.countl_zero();

  // The result of the bottom bits of an integer multiply can be
  // inferred by looking at the bottom bits of both operands and
  // multiplying them together.
  // We can infer at least the minimum number of known trailing bits
  // of both operands. Depending on number of trailing zeros, we can
  // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming
  // a and b are divisible by m and n respectively.
  // We then calculate how many of those bits are inferrable and set
  // the output. For example, the i8 mul:
  //  a = XXXX1100 (12)
  //  b = XXXX1110 (14)
  // We know the bottom 3 bits are zero since the first can be divided by
  // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4).
  // Applying the multiplication to the trimmed arguments gets:
  //    XX11 (3)
  //    X111 (7)
  // -------
  //    XX11
  //   XX11
  //  XX11
  // XX11
  // -------
  // XXXXX01
  // Which allows us to infer the 2 LSBs. Since we're multiplying the result
  // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits.
  // The proof for this can be described as:
  // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) &&
  //      (C7 == (1 << (umin(countTrailingZeros(C1), C5) +
  //                    umin(countTrailingZeros(C2), C6) +
  //                    umin(C5 - umin(countTrailingZeros(C1), C5),
  //                         C6 - umin(countTrailingZeros(C2), C6)))) - 1)
  // %aa = shl i8 %a, C5
  // %bb = shl i8 %b, C6
  // %aaa = or i8 %aa, C1
  // %bbb = or i8 %bb, C2
  // %mul = mul i8 %aaa, %bbb
  // %mask = and i8 %mul, C7
  //   =>
  // %mask = i8 ((C1*C2)&C7)
  // Where C5, C6 describe the known bits of %a, %b
  // C1, C2 describe the known bottom bits of %a, %b.
  // C7 describes the mask of the known bits of the result.
  // The known-one bits of each operand. Only their fully-known low portion
  // (see getLoBits below) contributes to the product.
  const APInt &Bottom0 = LHS.One;
  const APInt &Bottom1 = RHS.One;

  // How many times we'd be able to divide each argument by 2 (shr by 1).
  // This gives us the number of trailing zeros on the multiplication result.
  // TrailBitsKnown*: length of the low run of bits whose value (0 or 1) is
  // known in each operand. TrailZero*: minimum trailing zeros of each operand.
  unsigned TrailBitsKnown0 = (LHS.Zero | LHS.One).countr_one();
  unsigned TrailBitsKnown1 = (RHS.Zero | RHS.One).countr_one();
  unsigned TrailZero0 = LHS.countMinTrailingZeros();
  unsigned TrailZero1 = RHS.countMinTrailingZeros();
  unsigned TrailZ = TrailZero0 + TrailZero1;

  // Figure out the fewest known-bits operand.
  // The product's low bits are determined for that many positions above the
  // combined trailing zeros (TrailZ), capped at the bit width.
  unsigned SmallestOperand =
      std::min(a: TrailBitsKnown0 - TrailZero0, b: TrailBitsKnown1 - TrailZero1);
  unsigned ResultBitsKnown = std::min(a: SmallestOperand + TrailZ, b: BitWidth);

  // Product of only the fully-known low bits; its low ResultBitsKnown bits
  // match those of the real product.
  APInt BottomKnown =
      Bottom0.getLoBits(numBits: TrailBitsKnown0) * Bottom1.getLoBits(numBits: TrailBitsKnown1);

  KnownBits Res(BitWidth);
  Res.Zero.setHighBits(LeadZ);
  Res.Zero |= (~BottomKnown).getLoBits(numBits: ResultBitsKnown);
  Res.One = BottomKnown.getLoBits(numBits: ResultBitsKnown);

  // If we're self-multiplying then bit[1] is guaranteed to be zero.
  // (x*x mod 4 is 0 or 1.) The BitWidth check ensures bit 1 exists.
  if (NoUndefSelfMultiply && BitWidth > 1) {
    assert(Res.One[1] == 0 &&
           "Self-multiplication failed Quadratic Reciprocity!" );
    Res.Zero.setBit(1);
  }

  return Res;
}
900 | |
901 | KnownBits KnownBits::mulhs(const KnownBits &LHS, const KnownBits &RHS) { |
902 | unsigned BitWidth = LHS.getBitWidth(); |
903 | assert(BitWidth == RHS.getBitWidth() && "Operand mismatch" ); |
904 | KnownBits WideLHS = LHS.sext(BitWidth: 2 * BitWidth); |
905 | KnownBits WideRHS = RHS.sext(BitWidth: 2 * BitWidth); |
906 | return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth); |
907 | } |
908 | |
909 | KnownBits KnownBits::mulhu(const KnownBits &LHS, const KnownBits &RHS) { |
910 | unsigned BitWidth = LHS.getBitWidth(); |
911 | assert(BitWidth == RHS.getBitWidth() && "Operand mismatch" ); |
912 | KnownBits WideLHS = LHS.zext(BitWidth: 2 * BitWidth); |
913 | KnownBits WideRHS = RHS.zext(BitWidth: 2 * BitWidth); |
914 | return mul(LHS: WideLHS, RHS: WideRHS).extractBits(NumBits: BitWidth, BitPosition: BitWidth); |
915 | } |
916 | |
917 | static KnownBits divComputeLowBit(KnownBits Known, const KnownBits &LHS, |
918 | const KnownBits &RHS, bool Exact) { |
919 | |
920 | if (!Exact) |
921 | return Known; |
922 | |
923 | // If LHS is Odd, the result is Odd no matter what. |
924 | // Odd / Odd -> Odd |
925 | // Odd / Even -> Impossible (because its exact division) |
926 | if (LHS.One[0]) |
927 | Known.One.setBit(0); |
928 | |
929 | int MinTZ = |
930 | (int)LHS.countMinTrailingZeros() - (int)RHS.countMaxTrailingZeros(); |
931 | int MaxTZ = |
932 | (int)LHS.countMaxTrailingZeros() - (int)RHS.countMinTrailingZeros(); |
933 | if (MinTZ >= 0) { |
934 | // Result has at least MinTZ trailing zeros. |
935 | Known.Zero.setLowBits(MinTZ); |
936 | if (MinTZ == MaxTZ) { |
937 | // Result has exactly MinTZ trailing zeros. |
938 | Known.One.setBit(MinTZ); |
939 | } |
940 | } else if (MaxTZ < 0) { |
941 | // Poison Result |
942 | Known.setAllZero(); |
943 | } |
944 | |
945 | // In the KnownBits exhaustive tests, we have poison inputs for exact values |
946 | // a LOT. If we have a conflict, just return all zeros. |
947 | if (Known.hasConflict()) |
948 | Known.setAllZero(); |
949 | |
950 | return Known; |
951 | } |
952 | |
/// Compute the known bits of the signed quotient LHS / RHS.
///
/// Non-negative / non-negative is delegated to udiv. Otherwise, when the sign
/// of the quotient can be determined from the operand signs, a bounding
/// quotient is built from the operands' signed extremes. Its leading zeros
/// (non-negative bound) or leading ones (negative bound) are transferred to
/// the result. \p Exact additionally enables the low-bit reasoning in
/// divComputeLowBit.
KnownBits KnownBits::sdiv(const KnownBits &LHS, const KnownBits &RHS,
                          bool Exact) {
  // Equivalent of `udiv`. We must have caught this before it was folded.
  if (LHS.isNonNegative() && RHS.isNonNegative())
    return udiv(LHS, RHS, Exact);

  unsigned BitWidth = LHS.getBitWidth();
  KnownBits Known(BitWidth);

  if (LHS.isZero() || RHS.isZero()) {
    // Result is either known Zero or UB. Return Zero either way.
    // Checking this earlier saves us a lot of special cases later on.
    Known.setAllZero();
    return Known;
  }

  // Bounding quotient, computed only when the result's sign is known.
  std::optional<APInt> Res;
  if (LHS.isNegative() && RHS.isNegative()) {
    // Result non-negative.
    // Most negative numerator over the negative denominator closest to zero
    // gives the largest possible quotient.
    APInt Denom = RHS.getSignedMaxValue();
    APInt Num = LHS.getSignedMinValue();
    // INT_MIN/-1 would be a poison result (impossible). Estimate the division
    // as signed max (we will only set sign bit in the result).
    Res = (Num.isMinSignedValue() && Denom.isAllOnes())
              ? APInt::getSignedMaxValue(numBits: BitWidth)
              : Num.sdiv(RHS: Denom);
  } else if (LHS.isNegative() && RHS.isNonNegative()) {
    // Result is negative if Exact OR -LHS u>= RHS.
    if (Exact || (-LHS.getSignedMaxValue()).uge(RHS: RHS.getSignedMaxValue())) {
      // Most negative numerator over the smallest non-negative denominator
      // gives the most negative quotient. A zero denominator bound falls back
      // to the numerator itself.
      APInt Denom = RHS.getSignedMinValue();
      APInt Num = LHS.getSignedMinValue();
      Res = Denom.isZero() ? Num : Num.sdiv(RHS: Denom);
    }
  } else if (LHS.isStrictlyPositive() && RHS.isNegative()) {
    // Result is negative if Exact OR LHS u>= -RHS.
    if (Exact || LHS.getSignedMinValue().uge(RHS: -RHS.getSignedMinValue())) {
      // Largest numerator over the negative denominator closest to zero gives
      // the most negative quotient.
      APInt Denom = RHS.getSignedMaxValue();
      APInt Num = LHS.getSignedMaxValue();
      Res = Num.sdiv(RHS: Denom);
    }
  }

  // Transfer the bound's leading sign-bit run to the result.
  if (Res) {
    if (Res->isNonNegative()) {
      unsigned LeadZ = Res->countLeadingZeros();
      Known.Zero.setHighBits(LeadZ);
    } else {
      unsigned LeadO = Res->countLeadingOnes();
      Known.One.setHighBits(LeadO);
    }
  }

  Known = divComputeLowBit(Known, LHS, RHS, Exact);
  return Known;
}
1008 | |
1009 | KnownBits KnownBits::udiv(const KnownBits &LHS, const KnownBits &RHS, |
1010 | bool Exact) { |
1011 | unsigned BitWidth = LHS.getBitWidth(); |
1012 | KnownBits Known(BitWidth); |
1013 | |
1014 | if (LHS.isZero() || RHS.isZero()) { |
1015 | // Result is either known Zero or UB. Return Zero either way. |
1016 | // Checking this earlier saves us a lot of special cases later on. |
1017 | Known.setAllZero(); |
1018 | return Known; |
1019 | } |
1020 | |
1021 | // We can figure out the minimum number of upper zero bits by doing |
1022 | // MaxNumerator / MinDenominator. If the Numerator gets smaller or Denominator |
1023 | // gets larger, the number of upper zero bits increases. |
1024 | APInt MinDenom = RHS.getMinValue(); |
1025 | APInt MaxNum = LHS.getMaxValue(); |
1026 | APInt MaxRes = MinDenom.isZero() ? MaxNum : MaxNum.udiv(RHS: MinDenom); |
1027 | |
1028 | unsigned LeadZ = MaxRes.countLeadingZeros(); |
1029 | |
1030 | Known.Zero.setHighBits(LeadZ); |
1031 | Known = divComputeLowBit(Known, LHS, RHS, Exact); |
1032 | |
1033 | return Known; |
1034 | } |
1035 | |
1036 | KnownBits KnownBits::remGetLowBits(const KnownBits &LHS, const KnownBits &RHS) { |
1037 | unsigned BitWidth = LHS.getBitWidth(); |
1038 | if (!RHS.isZero() && RHS.Zero[0]) { |
1039 | // rem X, Y where Y[0:N] is zero will preserve X[0:N] in the result. |
1040 | unsigned RHSZeros = RHS.countMinTrailingZeros(); |
1041 | APInt Mask = APInt::getLowBitsSet(numBits: BitWidth, loBitsSet: RHSZeros); |
1042 | APInt OnesMask = LHS.One & Mask; |
1043 | APInt ZerosMask = LHS.Zero & Mask; |
1044 | return KnownBits(ZerosMask, OnesMask); |
1045 | } |
1046 | return KnownBits(BitWidth); |
1047 | } |
1048 | |
1049 | KnownBits KnownBits::urem(const KnownBits &LHS, const KnownBits &RHS) { |
1050 | KnownBits Known = remGetLowBits(LHS, RHS); |
1051 | if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { |
1052 | // NB: Low bits set in `remGetLowBits`. |
1053 | APInt HighBits = ~(RHS.getConstant() - 1); |
1054 | Known.Zero |= HighBits; |
1055 | return Known; |
1056 | } |
1057 | |
1058 | // Since the result is less than or equal to either operand, any leading |
1059 | // zero bits in either operand must also exist in the result. |
1060 | uint32_t Leaders = |
1061 | std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinLeadingZeros()); |
1062 | Known.Zero.setHighBits(Leaders); |
1063 | return Known; |
1064 | } |
1065 | |
1066 | KnownBits KnownBits::srem(const KnownBits &LHS, const KnownBits &RHS) { |
1067 | KnownBits Known = remGetLowBits(LHS, RHS); |
1068 | if (RHS.isConstant() && RHS.getConstant().isPowerOf2()) { |
1069 | // NB: Low bits are set in `remGetLowBits`. |
1070 | APInt LowBits = RHS.getConstant() - 1; |
1071 | // If the first operand is non-negative or has all low bits zero, then |
1072 | // the upper bits are all zero. |
1073 | if (LHS.isNonNegative() || LowBits.isSubsetOf(RHS: LHS.Zero)) |
1074 | Known.Zero |= ~LowBits; |
1075 | |
1076 | // If the first operand is negative and not all low bits are zero, then |
1077 | // the upper bits are all one. |
1078 | if (LHS.isNegative() && LowBits.intersects(RHS: LHS.One)) |
1079 | Known.One |= ~LowBits; |
1080 | return Known; |
1081 | } |
1082 | |
1083 | // The sign bit is the LHS's sign bit, except when the result of the |
1084 | // remainder is zero. The magnitude of the result should be less than or |
1085 | // equal to the magnitude of either operand. |
1086 | if (LHS.isNegative() && Known.isNonZero()) |
1087 | Known.One.setHighBits( |
1088 | std::max(a: LHS.countMinLeadingOnes(), b: RHS.countMinSignBits())); |
1089 | else if (LHS.isNonNegative()) |
1090 | Known.Zero.setHighBits( |
1091 | std::max(a: LHS.countMinLeadingZeros(), b: RHS.countMinSignBits())); |
1092 | return Known; |
1093 | } |
1094 | |
1095 | KnownBits &KnownBits::operator&=(const KnownBits &RHS) { |
1096 | // Result bit is 0 if either operand bit is 0. |
1097 | Zero |= RHS.Zero; |
1098 | // Result bit is 1 if both operand bits are 1. |
1099 | One &= RHS.One; |
1100 | return *this; |
1101 | } |
1102 | |
1103 | KnownBits &KnownBits::operator|=(const KnownBits &RHS) { |
1104 | // Result bit is 0 if both operand bits are 0. |
1105 | Zero &= RHS.Zero; |
1106 | // Result bit is 1 if either operand bit is 1. |
1107 | One |= RHS.One; |
1108 | return *this; |
1109 | } |
1110 | |
1111 | KnownBits &KnownBits::operator^=(const KnownBits &RHS) { |
1112 | // Result bit is 0 if both operand bits are 0 or both are 1. |
1113 | APInt Z = (Zero & RHS.Zero) | (One & RHS.One); |
1114 | // Result bit is 1 if one operand bit is 0 and the other is 1. |
1115 | One = (Zero & RHS.One) | (One & RHS.Zero); |
1116 | Zero = std::move(Z); |
1117 | return *this; |
1118 | } |
1119 | |
1120 | KnownBits KnownBits::blsi() const { |
1121 | unsigned BitWidth = getBitWidth(); |
1122 | KnownBits Known(Zero, APInt(BitWidth, 0)); |
1123 | unsigned Max = countMaxTrailingZeros(); |
1124 | Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth)); |
1125 | unsigned Min = countMinTrailingZeros(); |
1126 | if (Max == Min && Max < BitWidth) |
1127 | Known.One.setBit(Max); |
1128 | return Known; |
1129 | } |
1130 | |
1131 | KnownBits KnownBits::blsmsk() const { |
1132 | unsigned BitWidth = getBitWidth(); |
1133 | KnownBits Known(BitWidth); |
1134 | unsigned Max = countMaxTrailingZeros(); |
1135 | Known.Zero.setBitsFrom(std::min(a: Max + 1, b: BitWidth)); |
1136 | unsigned Min = countMinTrailingZeros(); |
1137 | Known.One.setLowBits(std::min(a: Min + 1, b: BitWidth)); |
1138 | return Known; |
1139 | } |
1140 | |
1141 | void KnownBits::print(raw_ostream &OS) const { |
1142 | unsigned BitWidth = getBitWidth(); |
1143 | for (unsigned I = 0; I < BitWidth; ++I) { |
1144 | unsigned N = BitWidth - I - 1; |
1145 | if (Zero[N] && One[N]) |
1146 | OS << "!" ; |
1147 | else if (Zero[N]) |
1148 | OS << "0" ; |
1149 | else if (One[N]) |
1150 | OS << "1" ; |
1151 | else |
1152 | OS << "?" ; |
1153 | } |
1154 | } |
1155 | |
1156 | #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
1157 | LLVM_DUMP_METHOD void KnownBits::dump() const { |
1158 | print(dbgs()); |
1159 | dbgs() << "\n" ; |
1160 | } |
1161 | #endif |
1162 | |