1//===- ValueTracking.cpp - Walk computations to compute properties --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains routines that help analyze properties that chains of
10// computations have.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/ValueTracking.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/APInt.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/FloatingPointMode.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/ADT/iterator_range.h"
25#include "llvm/Analysis/AliasAnalysis.h"
26#include "llvm/Analysis/AssumeBundleQueries.h"
27#include "llvm/Analysis/AssumptionCache.h"
28#include "llvm/Analysis/ConstantFolding.h"
29#include "llvm/Analysis/DomConditionCache.h"
30#include "llvm/Analysis/FloatingPointPredicateUtils.h"
31#include "llvm/Analysis/GuardUtils.h"
32#include "llvm/Analysis/InstructionSimplify.h"
33#include "llvm/Analysis/Loads.h"
34#include "llvm/Analysis/LoopInfo.h"
35#include "llvm/Analysis/TargetLibraryInfo.h"
36#include "llvm/Analysis/VectorUtils.h"
37#include "llvm/Analysis/WithCache.h"
38#include "llvm/IR/Argument.h"
39#include "llvm/IR/Attributes.h"
40#include "llvm/IR/BasicBlock.h"
41#include "llvm/IR/Constant.h"
42#include "llvm/IR/ConstantFPRange.h"
43#include "llvm/IR/ConstantRange.h"
44#include "llvm/IR/Constants.h"
45#include "llvm/IR/DerivedTypes.h"
46#include "llvm/IR/DiagnosticInfo.h"
47#include "llvm/IR/Dominators.h"
48#include "llvm/IR/EHPersonalities.h"
49#include "llvm/IR/Function.h"
50#include "llvm/IR/GetElementPtrTypeIterator.h"
51#include "llvm/IR/GlobalAlias.h"
52#include "llvm/IR/GlobalValue.h"
53#include "llvm/IR/GlobalVariable.h"
54#include "llvm/IR/InstrTypes.h"
55#include "llvm/IR/Instruction.h"
56#include "llvm/IR/Instructions.h"
57#include "llvm/IR/IntrinsicInst.h"
58#include "llvm/IR/Intrinsics.h"
59#include "llvm/IR/IntrinsicsAArch64.h"
60#include "llvm/IR/IntrinsicsAMDGPU.h"
61#include "llvm/IR/IntrinsicsRISCV.h"
62#include "llvm/IR/IntrinsicsX86.h"
63#include "llvm/IR/LLVMContext.h"
64#include "llvm/IR/Metadata.h"
65#include "llvm/IR/Module.h"
66#include "llvm/IR/Operator.h"
67#include "llvm/IR/PatternMatch.h"
68#include "llvm/IR/Type.h"
69#include "llvm/IR/User.h"
70#include "llvm/IR/Value.h"
71#include "llvm/Support/Casting.h"
72#include "llvm/Support/CommandLine.h"
73#include "llvm/Support/Compiler.h"
74#include "llvm/Support/ErrorHandling.h"
75#include "llvm/Support/KnownBits.h"
76#include "llvm/Support/KnownFPClass.h"
77#include "llvm/Support/MathExtras.h"
78#include "llvm/TargetParser/RISCVTargetParser.h"
79#include <algorithm>
80#include <cassert>
81#include <cstdint>
82#include <optional>
83#include <utility>
84
85using namespace llvm;
86using namespace llvm::PatternMatch;
87
// Controls the number of uses of the value searched for possible
// dominating comparisons. Bounds the per-value work done when scanning
// users for conditions that dominate the context instruction.
static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
                                              cl::Hidden, cl::init(Val: 20));
92
/// Maximum number of instructions to check between assume and context
/// instruction. Used by willNotFreeBetween to bound its linear scan.
static constexpr unsigned MaxInstrsToCheckForFree = 16;
96
97/// Returns the bitwidth of the given scalar or pointer type. For vector types,
98/// returns the element type's bitwidth.
99static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
100 if (unsigned BitWidth = Ty->getScalarSizeInBits())
101 return BitWidth;
102
103 return DL.getPointerTypeSizeInBits(Ty);
104}
105
106// Given the provided Value and, potentially, a context instruction, return
107// the preferred context instruction (if any).
108static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
109 // If we've been provided with a context instruction, then use that (provided
110 // it has been inserted).
111 if (CxtI && CxtI->getParent())
112 return CxtI;
113
114 // If the value is really an already-inserted instruction, then use that.
115 CxtI = dyn_cast<Instruction>(Val: V);
116 if (CxtI && CxtI->getParent())
117 return CxtI;
118
119 return nullptr;
120}
121
122static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
123 const APInt &DemandedElts,
124 APInt &DemandedLHS, APInt &DemandedRHS) {
125 if (isa<ScalableVectorType>(Val: Shuf->getType())) {
126 assert(DemandedElts == APInt(1,1));
127 DemandedLHS = DemandedRHS = DemandedElts;
128 return true;
129 }
130
131 int NumElts =
132 cast<FixedVectorType>(Val: Shuf->getOperand(i_nocapture: 0)->getType())->getNumElements();
133 return llvm::getShuffleDemandedElts(SrcWidth: NumElts, Mask: Shuf->getShuffleMask(),
134 DemandedElts, DemandedLHS, DemandedRHS);
135}
136
137static void computeKnownBits(const Value *V, const APInt &DemandedElts,
138 KnownBits &Known, const SimplifyQuery &Q,
139 unsigned Depth);
140
141void llvm::computeKnownBits(const Value *V, KnownBits &Known,
142 const SimplifyQuery &Q, unsigned Depth) {
143 // Since the number of lanes in a scalable vector is unknown at compile time,
144 // we track one bit which is implicitly broadcast to all lanes. This means
145 // that all lanes in a scalable vector are considered demanded.
146 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
147 APInt DemandedElts =
148 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
149 ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
150}
151
152void llvm::computeKnownBits(const Value *V, KnownBits &Known,
153 const DataLayout &DL, AssumptionCache *AC,
154 const Instruction *CxtI, const DominatorTree *DT,
155 bool UseInstrInfo, unsigned Depth) {
156 computeKnownBits(V, Known,
157 Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
158 Depth);
159}
160
161KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
162 AssumptionCache *AC, const Instruction *CxtI,
163 const DominatorTree *DT, bool UseInstrInfo,
164 unsigned Depth) {
165 return computeKnownBits(
166 V, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
167}
168
169KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
170 const DataLayout &DL, AssumptionCache *AC,
171 const Instruction *CxtI,
172 const DominatorTree *DT, bool UseInstrInfo,
173 unsigned Depth) {
174 return computeKnownBits(
175 V, DemandedElts,
176 Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
177}
178
/// Structural special cases for proving LHS and RHS have no common set bits,
/// without computing known bits. Each pattern below guarantees disjointness
/// by construction; the isGuaranteedNotToBeUndef checks are required because
/// an undef shared value could take different values at its two uses, which
/// would break the pattern-based reasoning.
static bool haveNoCommonBitsSetSpecialCases(const Value *LHS, const Value *RHS,
                                            const SimplifyQuery &SQ) {
  // Look for an inverted mask: (X & ~M) op (Y & M).
  {
    Value *M;
    if (match(V: LHS, P: m_c_And(L: m_Not(V: m_Value(V&: M)), R: m_Value())) &&
        match(V: RHS, P: m_c_And(L: m_Specific(V: M), R: m_Value())) &&
        isGuaranteedNotToBeUndef(V: M, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
      return true;
  }

  // X op (Y & ~X)
  if (match(V: RHS, P: m_c_And(L: m_Not(V: m_Specific(V: LHS)), R: m_Value())) &&
      isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
    return true;

  // X op ((X & Y) ^ Y) -- this is the canonical form of the previous pattern
  // for constant Y.
  Value *Y;
  if (match(V: RHS,
            P: m_c_Xor(L: m_c_And(L: m_Specific(V: LHS), R: m_Value(V&: Y)), R: m_Deferred(V: Y))) &&
      isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT) &&
      isGuaranteedNotToBeUndef(V: Y, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
    return true;

  // Peek through extends to find a 'not' of the other side:
  // (ext Y) op ext(~Y)
  if (match(V: LHS, P: m_ZExtOrSExt(Op: m_Value(V&: Y))) &&
      match(V: RHS, P: m_ZExtOrSExt(Op: m_Not(V: m_Specific(V: Y)))) &&
      isGuaranteedNotToBeUndef(V: Y, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
    return true;

  // Look for: (A & B) op ~(A | B)
  {
    Value *A, *B;
    if (match(V: LHS, P: m_And(L: m_Value(V&: A), R: m_Value(V&: B))) &&
        match(V: RHS, P: m_Not(V: m_c_Or(L: m_Specific(V: A), R: m_Specific(V: B)))) &&
        isGuaranteedNotToBeUndef(V: A, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT) &&
        isGuaranteedNotToBeUndef(V: B, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
      return true;
  }

  // Look for: (X << V) op (Y >> (BitWidth - V))
  // or (X >> V) op (Y << (BitWidth - V))
  // When the shift amounts sum to at least the bit width, the two shifted
  // values occupy disjoint bit ranges.
  {
    const Value *V;
    const APInt *R;
    if (((match(V: RHS, P: m_Shl(L: m_Value(), R: m_Sub(L: m_APInt(Res&: R), R: m_Value(V)))) &&
          match(V: LHS, P: m_LShr(L: m_Value(), R: m_Specific(V)))) ||
         (match(V: RHS, P: m_LShr(L: m_Value(), R: m_Sub(L: m_APInt(Res&: R), R: m_Value(V)))) &&
          match(V: LHS, P: m_Shl(L: m_Value(), R: m_Specific(V))))) &&
        R->uge(RHS: LHS->getType()->getScalarSizeInBits()))
      return true;
  }

  return false;
}
236
237bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
238 const WithCache<const Value *> &RHSCache,
239 const SimplifyQuery &SQ) {
240 const Value *LHS = LHSCache.getValue();
241 const Value *RHS = RHSCache.getValue();
242
243 assert(LHS->getType() == RHS->getType() &&
244 "LHS and RHS should have the same type");
245 assert(LHS->getType()->isIntOrIntVectorTy() &&
246 "LHS and RHS should be integers");
247
248 if (haveNoCommonBitsSetSpecialCases(LHS, RHS, SQ) ||
249 haveNoCommonBitsSetSpecialCases(LHS: RHS, RHS: LHS, SQ))
250 return true;
251
252 return KnownBits::haveNoCommonBitsSet(LHS: LHSCache.getKnownBits(Q: SQ),
253 RHS: RHSCache.getKnownBits(Q: SQ));
254}
255
256bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
257 return !I->user_empty() &&
258 all_of(Range: I->users(), P: match_fn(P: m_ICmp(L: m_Value(), R: m_Zero())));
259}
260
261bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
262 return !I->user_empty() && all_of(Range: I->users(), P: [](const User *U) {
263 CmpPredicate P;
264 return match(V: U, P: m_ICmp(Pred&: P, L: m_Value(), R: m_Zero())) && ICmpInst::isEquality(P);
265 });
266}
267
268bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
269 bool OrZero, AssumptionCache *AC,
270 const Instruction *CxtI,
271 const DominatorTree *DT, bool UseInstrInfo,
272 unsigned Depth) {
273 return ::isKnownToBeAPowerOfTwo(
274 V, OrZero, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
275 Depth);
276}
277
278static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
279 const SimplifyQuery &Q, unsigned Depth);
280
281bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
282 unsigned Depth) {
283 return computeKnownBits(V, Q: SQ, Depth).isNonNegative();
284}
285
286bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
287 unsigned Depth) {
288 if (auto *CI = dyn_cast<ConstantInt>(Val: V))
289 return CI->getValue().isStrictlyPositive();
290
291 // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
292 // this updated.
293 KnownBits Known = computeKnownBits(V, Q: SQ, Depth);
294 return Known.isNonNegative() &&
295 (Known.isNonZero() || isKnownNonZero(V, Q: SQ, Depth));
296}
297
298bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
299 unsigned Depth) {
300 return computeKnownBits(V, Q: SQ, Depth).isNegative();
301}
302
303static bool isKnownNonEqual(const Value *V1, const Value *V2,
304 const APInt &DemandedElts, const SimplifyQuery &Q,
305 unsigned Depth);
306
307bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
308 const SimplifyQuery &Q, unsigned Depth) {
309 // We don't support looking through casts.
310 if (V1 == V2 || V1->getType() != V2->getType())
311 return false;
312 auto *FVTy = dyn_cast<FixedVectorType>(Val: V1->getType());
313 APInt DemandedElts =
314 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
315 return ::isKnownNonEqual(V1, V2, DemandedElts, Q, Depth);
316}
317
318bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
319 const SimplifyQuery &SQ, unsigned Depth) {
320 KnownBits Known(Mask.getBitWidth());
321 computeKnownBits(V, Known, Q: SQ, Depth);
322 return Mask.isSubsetOf(RHS: Known.Zero);
323}
324
325static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
326 const SimplifyQuery &Q, unsigned Depth);
327
328static unsigned ComputeNumSignBits(const Value *V, const SimplifyQuery &Q,
329 unsigned Depth = 0) {
330 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
331 APInt DemandedElts =
332 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
333 return ComputeNumSignBits(V, DemandedElts, Q, Depth);
334}
335
336unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
337 AssumptionCache *AC, const Instruction *CxtI,
338 const DominatorTree *DT, bool UseInstrInfo,
339 unsigned Depth) {
340 return ::ComputeNumSignBits(
341 V, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
342}
343
344unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
345 AssumptionCache *AC,
346 const Instruction *CxtI,
347 const DominatorTree *DT,
348 unsigned Depth) {
349 unsigned SignBits = ComputeNumSignBits(V, DL, AC, CxtI, DT, UseInstrInfo: Depth);
350 return V->getType()->getScalarSizeInBits() - SignBits + 1;
351}
352
/// Try to detect the lerp pattern: a * (b - c) + c * d
/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
///
/// In that particular case, we can use the following chain of reasoning:
///
/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
///
/// Since that is true for arbitrary a, b, c and d within our constraints, we
/// can conclude that:
///
/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
///
/// Considering that any result of the lerp would be less or equal to U, it
/// would have at least the number of leading 0s as in U.
///
/// While being quite a specific situation, it is fairly common in computer
/// graphics in the shape of alpha blending.
///
/// Modifies given KnownOut in-place with the inferred information.
static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
                                            const APInt &DemandedElts,
                                            KnownBits &KnownOut,
                                            const SimplifyQuery &Q,
                                            unsigned Depth) {

  Type *Ty = Op0->getType();
  const unsigned BitWidth = Ty->getScalarSizeInBits();

  // Only handle scalar types for now
  if (Ty->isVectorTy())
    return;

  // Try to match: a * (b - c) + c * d.
  // When a == 1 => A == nullptr, the same applies to d/D as well.
  const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
  const Instruction *SubBC = nullptr;

  const auto MatchSubBC = [&]() {
    // (b - c) can have two forms that interest us:
    //
    // 1. sub nuw %b, %c
    // 2. xor %c, %b
    //
    // For the first case, nuw flag guarantees our requirement b >= c.
    //
    // The second case might happen when the analysis can infer that b is a mask
    // for c and we can transform sub operation into xor (that is usually true
    // for constant b's). Even though xor is symmetrical, canonicalization
    // ensures that the constant will be the RHS. We have additional checks
    // later on to ensure that this xor operation is equivalent to subtraction.
    return m_Instruction(I&: SubBC, Match: m_CombineOr(L: m_NUWSub(L: m_Value(V&: B), R: m_Value(V&: C)),
                                                R: m_Xor(L: m_Value(V&: C), R: m_Value(V&: B))));
  };

  const auto MatchASubBC = [&]() {
    // Cases:
    // - a * (b - c)
    // - (b - c) * a
    // - (b - c) <- a implicitly equals 1
    return m_CombineOr(L: m_c_Mul(L: m_Value(V&: A), R: MatchSubBC()), R: MatchSubBC());
  };

  const auto MatchCD = [&]() {
    // Cases:
    // - d * c
    // - c * d
    // - c <- d implicitly equals 1
    return m_CombineOr(L: m_c_Mul(L: m_Value(V&: D), R: m_Specific(V: C)), R: m_Specific(V: C));
  };

  const auto Match = [&](const Value *LHS, const Value *RHS) {
    // We do use m_Specific(C) in MatchCD, so we have to make sure that
    // it's bound to anything and match(LHS, MatchASubBC()) absolutely
    // has to evaluate first and return true.
    //
    // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
    return match(V: LHS, P: MatchASubBC()) && match(V: RHS, P: MatchCD());
  };

  // The add is commutative, so try both operand orders.
  if (!Match(Op0, Op1) && !Match(Op1, Op0))
    return;

  const auto ComputeKnownBitsOrOne = [&](const Value *V) {
    // For some of the values we use the convention of leaving
    // it nullptr to signify an implicit constant 1.
    return V ? computeKnownBits(V, DemandedElts, Q, Depth: Depth + 1)
             : KnownBits::makeConstant(C: APInt(BitWidth, 1));
  };

  // Check that all operands are non-negative; the bound below is only valid
  // under that precondition.
  const KnownBits KnownA = ComputeKnownBitsOrOne(A);
  if (!KnownA.isNonNegative())
    return;

  const KnownBits KnownD = ComputeKnownBitsOrOne(D);
  if (!KnownD.isNonNegative())
    return;

  const KnownBits KnownB = computeKnownBits(V: B, DemandedElts, Q, Depth: Depth + 1);
  if (!KnownB.isNonNegative())
    return;

  const KnownBits KnownC = computeKnownBits(V: C, DemandedElts, Q, Depth: Depth + 1);
  if (!KnownC.isNonNegative())
    return;

  // If we matched subtraction as xor, we need to actually check that xor
  // is semantically equivalent to subtraction.
  //
  // For that to be true, b has to be a mask for c or that b's known
  // ones cover all known and possible ones of c.
  if (SubBC->getOpcode() == Instruction::Xor &&
      !KnownC.getMaxValue().isSubsetOf(RHS: KnownB.getMinValue()))
    return;

  // Compute the upper-bound estimate U = max(max(a), max(d)) * max(b).
  const APInt MaxA = KnownA.getMaxValue();
  const APInt MaxD = KnownD.getMaxValue();
  const APInt MaxAD = APIntOps::umax(A: MaxA, B: MaxD);
  const APInt MaxB = KnownB.getMaxValue();

  // We can't infer leading zeros info if the upper-bound estimate wraps.
  bool Overflow;
  const APInt UpperBound = MaxAD.umul_ov(RHS: MaxB, Overflow);

  if (Overflow)
    return;

  // If we know that x <= y and both are positive than x has at least the same
  // number of leading zeros as y.
  const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
  KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
}
485
/// Compute known bits for an add/sub of Op0 and Op1. KnownOut receives the
/// result; Known2 is scratch storage for Op0's known bits (reused by the
/// caller to avoid reallocation).
static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
                                   bool NSW, bool NUW,
                                   const APInt &DemandedElts,
                                   KnownBits &KnownOut, KnownBits &Known2,
                                   const SimplifyQuery &Q, unsigned Depth) {
  computeKnownBits(V: Op1, DemandedElts, Known&: KnownOut, Q, Depth: Depth + 1);

  // If one operand is unknown and we have no nowrap information,
  // the result will be unknown independently of the second operand.
  if (KnownOut.isUnknown() && !NSW && !NUW)
    return;

  computeKnownBits(V: Op0, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
  KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, LHS: Known2, RHS: KnownOut);

  // For "sub nsw Op0, Op1": if we can show Op1 <=s Op0 (via a dominating
  // condition or because Op1 is smin(Op0, ...)), the result cannot be
  // negative. Only bother when the generic computation didn't already prove
  // non-negativity.
  if (!Add && NSW && !KnownOut.isNonNegative() &&
      (isImpliedByDomCondition(Pred: ICmpInst::ICMP_SLE, LHS: Op1, RHS: Op0, ContextI: Q.CxtI, DL: Q.DL)
           .value_or(u: false) ||
       match(V: Op1, P: m_c_SMin(L: m_Specific(V: Op0), R: m_Value()))))
    KnownOut.makeNonNegative();

  if (Add)
    // Try to match lerp pattern and combine results
    computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
}
511
/// Compute known bits for a multiply of Op0 and Op1. Known receives the
/// result; Known2 is scratch storage for Op0's known bits.
static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
                                bool NUW, const APInt &DemandedElts,
                                KnownBits &Known, KnownBits &Known2,
                                const SimplifyQuery &Q, unsigned Depth) {
  computeKnownBits(V: Op1, DemandedElts, Known, Q, Depth: Depth + 1);
  computeKnownBits(V: Op0, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);

  bool isKnownNegative = false;
  bool isKnownNonNegative = false;
  // If the multiplication is known not to overflow, compute the sign bit.
  if (NSW) {
    if (Op0 == Op1) {
      // The product of a number with itself is non-negative.
      isKnownNonNegative = true;
    } else {
      bool isKnownNonNegativeOp1 = Known.isNonNegative();
      bool isKnownNonNegativeOp0 = Known2.isNonNegative();
      bool isKnownNegativeOp1 = Known.isNegative();
      bool isKnownNegativeOp0 = Known2.isNegative();
      // The product of two numbers with the same sign is non-negative.
      isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
                           (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
      if (!isKnownNonNegative && NUW) {
        // mul nuw nsw with a factor > 1 is non-negative.
        KnownBits One = KnownBits::makeConstant(C: APInt(Known.getBitWidth(), 1));
        isKnownNonNegative = KnownBits::sgt(LHS: Known, RHS: One).value_or(u: false) ||
                             KnownBits::sgt(LHS: Known2, RHS: One).value_or(u: false);
      }

      // The product of a negative number and a non-negative number is either
      // negative or zero.
      if (!isKnownNonNegative)
        isKnownNegative =
            (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
             Known2.isNonZero()) ||
            (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known.isNonZero());
    }
  }

  // Self-multiply reasoning in KnownBits::mul is only sound if the operand
  // is not undef (undef could take different values at its two uses).
  bool SelfMultiply = Op0 == Op1;
  if (SelfMultiply)
    SelfMultiply &=
        isGuaranteedNotToBeUndef(V: Op0, AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT, Depth: Depth + 1);
  Known = KnownBits::mul(LHS: Known, RHS: Known2, NoUndefSelfMultiply: SelfMultiply);

  // For x * x: if x has S sign bits then |x| <= 2^(TyBits - S), so the square
  // fits in 2 * (TyBits - S + 1) bits; any bits above that are known zero.
  if (SelfMultiply) {
    unsigned SignBits = ComputeNumSignBits(V: Op0, DemandedElts, Q, Depth: Depth + 1);
    unsigned TyBits = Op0->getType()->getScalarSizeInBits();
    unsigned OutValidBits = 2 * (TyBits - SignBits + 1);

    if (OutValidBits < TyBits) {
      APInt KnownZeroMask =
          APInt::getHighBitsSet(numBits: TyBits, hiBitsSet: TyBits - OutValidBits + 1);
      Known.Zero |= KnownZeroMask;
    }
  }

  // Only make use of no-wrap flags if we failed to compute the sign bit
  // directly. This matters if the multiplication always overflows, in
  // which case we prefer to follow the result of the direct computation,
  // though as the program is invoking undefined behaviour we can choose
  // whatever we like here.
  if (isKnownNonNegative && !Known.isNegative())
    Known.makeNonNegative();
  else if (isKnownNegative && !Known.isNonNegative())
    Known.makeNegative();
}
579
/// Refine Known using !range metadata: intersect the common-prefix bit
/// information contributed by each [Lower, Upper) range in the node.
void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
                                             KnownBits &Known) {
  unsigned BitWidth = Known.getBitWidth();
  unsigned NumRanges = Ranges.getNumOperands() / 2;
  assert(NumRanges >= 1);

  // Start from the "conflict" state (all bits both zero and one) so that the
  // per-range intersections below compute the bits common to every range.
  Known.setAllConflict();

  for (unsigned i = 0; i < NumRanges; ++i) {
    ConstantInt *Lower =
        mdconst::extract<ConstantInt>(MD: Ranges.getOperand(I: 2 * i + 0));
    ConstantInt *Upper =
        mdconst::extract<ConstantInt>(MD: Ranges.getOperand(I: 2 * i + 1));
    ConstantRange Range(Lower->getValue(), Upper->getValue());
    // BitWidth must equal the Ranges BitWidth for the correct number of high
    // bits to be set.
    assert(BitWidth == Range.getBitWidth() &&
           "Known bit width must match range bit width!");

    // The first CommonPrefixBits of all values in Range are equal.
    unsigned CommonPrefixBits =
        (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
    APInt Mask = APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: CommonPrefixBits);
    APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(width: BitWidth);
    // Keep only the prefix bits that agree across every range seen so far.
    Known.One &= UnsignedMax & Mask;
    Known.Zero &= ~UnsignedMax & Mask;
  }
}
608
/// Return true if E is "ephemeral" to the assumption I: it exists only to
/// feed I's condition (directly or transitively) and has no other
/// non-ephemeral users, so using I to reason about E would be circular.
static bool isEphemeralValueOf(const Instruction *I, const Value *E) {
  // Worklist-based backward walk from I over operand chains.
  SmallVector<const Instruction *, 16> WorkSet(1, I);
  SmallPtrSet<const Instruction *, 32> Visited;
  SmallPtrSet<const Instruction *, 16> EphValues;

  // The instruction defining an assumption's condition itself is always
  // considered ephemeral to that assumption (even if it has other
  // non-ephemeral users). See r246696's test case for an example.
  if (is_contained(Range: I->operands(), Element: E))
    return true;

  while (!WorkSet.empty()) {
    const Instruction *V = WorkSet.pop_back_val();
    if (!Visited.insert(Ptr: V).second)
      continue;

    // If all uses of this value are ephemeral, then so is this value.
    if (all_of(Range: V->users(), P: [&](const User *U) {
          return EphValues.count(Ptr: cast<Instruction>(Val: U));
        })) {
      if (V == E)
        return true;

      // Side-effecting or terminator instructions are never ephemeral (they
      // matter beyond the assumption), except for the assume itself.
      if (V == I || (!V->mayHaveSideEffects() && !V->isTerminator())) {
        EphValues.insert(Ptr: V);

        for (const Use &U : V->operands()) {
          if (const auto *I = dyn_cast<Instruction>(Val: U.get()))
            WorkSet.push_back(Elt: I);
        }
      }
    }
  }

  return false;
}
645
646// Is this an intrinsic that cannot be speculated but also cannot trap?
647bool llvm::isAssumeLikeIntrinsic(const Instruction *I) {
648 if (const IntrinsicInst *CI = dyn_cast<IntrinsicInst>(Val: I))
649 return CI->isAssumeLikeIntrinsic();
650
651 return false;
652}
653
/// Return true if the assumption Inv may be used to refine facts at the
/// context instruction CxtI. AllowEphemerals relaxes the self-reference
/// restriction (used by callers that have already handled circularity).
bool llvm::isValidAssumeForContext(const Instruction *Inv,
                                   const Instruction *CxtI,
                                   const DominatorTree *DT,
                                   bool AllowEphemerals) {
  // There are two restrictions on the use of an assume:
  // 1. The assume must dominate the context (or the control flow must
  // reach the assume whenever it reaches the context).
  // 2. The context must not be in the assume's set of ephemeral values
  // (otherwise we will use the assume to prove that the condition
  // feeding the assume is trivially true, thus causing the removal of
  // the assume).

  if (Inv->getParent() == CxtI->getParent()) {
    // If Inv and CtxI are in the same block, check if the assume (Inv) is first
    // in the BB.
    if (Inv->comesBefore(Other: CxtI))
      return true;

    // Don't let an assume affect itself - this would cause the problems
    // `isEphemeralValueOf` is trying to prevent, and it would also make
    // the loop below go out of bounds.
    if (!AllowEphemerals && Inv == CxtI)
      return false;

    // The context comes first, but they're both in the same block.
    // Make sure there is nothing in between that might interrupt
    // the control flow, not even CxtI itself.
    // We limit the scan distance between the assume and its context instruction
    // to avoid a compile-time explosion. This limit is chosen arbitrarily, so
    // it can be adjusted if needed (could be turned into a cl::opt).
    auto Range = make_range(x: CxtI->getIterator(), y: Inv->getIterator());
    if (!isGuaranteedToTransferExecutionToSuccessor(Range, ScanLimit: 15))
      return false;

    return AllowEphemerals || !isEphemeralValueOf(I: Inv, E: CxtI);
  }

  // Inv and CxtI are in different blocks.
  if (DT) {
    if (DT->dominates(Def: Inv, User: CxtI))
      return true;
  } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor() ||
             Inv->getParent()->isEntryBlock()) {
    // We don't have a DT, but this trivially dominates.
    return true;
  }

  return false;
}
703
/// Return true if no call between Assume and CtxI may free memory, so a
/// dereferenceability fact established at Assume still holds at CtxI.
/// Only handles CtxI in the same block as Assume, or in its unique successor;
/// the scan is bounded by MaxInstrsToCheckForFree.
bool llvm::willNotFreeBetween(const Instruction *Assume,
                              const Instruction *CtxI) {
  // Helper to check if there are any calls in the range that may free memory.
  auto hasNoFreeCalls = [](auto Range) {
    for (const auto &[Idx, I] : enumerate(Range)) {
      // Give up once the scan limit is exceeded.
      if (Idx > MaxInstrsToCheckForFree)
        return false;
      if (const auto *CB = dyn_cast<CallBase>(&I))
        if (!CB->hasFnAttr(Attribute::NoFree))
          return false;
    }
    return true;
  };

  // Make sure the current function cannot arrange for another thread to free on
  // its behalf.
  if (!CtxI->getFunction()->hasNoSync())
    return false;

  // Handle cross-block case: CtxI in a successor of Assume's block.
  const BasicBlock *CtxBB = CtxI->getParent();
  const BasicBlock *AssumeBB = Assume->getParent();
  BasicBlock::const_iterator CtxIter = CtxI->getIterator();
  if (CtxBB != AssumeBB) {
    if (CtxBB->getSinglePredecessor() != AssumeBB)
      return false;

    // Check the successor block up to CtxI...
    if (!hasNoFreeCalls(make_range(x: CtxBB->begin(), y: CtxIter)))
      return false;

    // ...then fall through to scan Assume's block to its end.
    CtxIter = AssumeBB->end();
  } else {
    // Same block case: check that Assume comes before CtxI.
    if (!Assume->comesBefore(Other: CtxI))
      return false;
  }

  // Check if there are any calls between Assume and CtxIter that may free
  // memory.
  return hasNoFreeCalls(make_range(x: Assume->getIterator(), y: CtxIter));
}
745
// TODO: cmpExcludesZero misses many cases where `RHS` is non-constant but
// we still have enough information about `RHS` to conclude non-zero. For
// example Pred=EQ, RHS=isKnownNonZero. cmpExcludesZero is called in loops
// so the extra compile time may not be worth it, but possibly a second API
// should be created for use outside of loops.
/// Return true if "V Pred RHS" being true implies V != 0.
static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) {
  // v u> y implies v != 0.
  if (Pred == ICmpInst::ICMP_UGT)
    return true;

  // Special-case v != 0 to also handle v != null.
  if (Pred == ICmpInst::ICMP_NE)
    return match(V: RHS, P: m_Zero());

  // All other predicates - rely on generic ConstantRange handling.
  const APInt *C;
  auto Zero = APInt::getZero(numBits: RHS->getType()->getScalarSizeInBits());
  if (match(V: RHS, P: m_APInt(Res&: C))) {
    // V is non-zero iff zero lies outside the region where the predicate
    // holds.
    ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(Pred, Other: *C);
    return !TrueValues.contains(Val: Zero);
  }

  auto *VC = dyn_cast<ConstantDataVector>(Val: RHS);
  if (VC == nullptr)
    return false;

  // For a constant vector, every lane's true-region must exclude zero.
  for (unsigned ElemIdx = 0, NElem = VC->getNumElements(); ElemIdx < NElem;
       ++ElemIdx) {
    ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(
        Pred, Other: VC->getElementAsAPInt(i: ElemIdx));
    if (TrueValues.contains(Val: Zero))
      return false;
  }
  return true;
}
781
/// Given an incoming use U of the phi PHI, pick a value/context pair to
/// analyze that breaks self-recursion through the phi. ValOut receives the
/// value to analyze, CtxIOut the terminator of the corresponding incoming
/// block, and PhiOut (if non-null) the phi the chosen value came from.
static void breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
                                  Value *&ValOut, Instruction *&CtxIOut,
                                  const PHINode **PhiOut = nullptr) {
  ValOut = U->get();
  // A direct self-reference carries no information; leave CtxIOut untouched.
  if (ValOut == PHI)
    return;
  CtxIOut = PHI->getIncomingBlock(U: *U)->getTerminator();
  if (PhiOut)
    *PhiOut = PHI;
  Value *V;
  // If the Use is a select of this phi, compute analysis on other arm to break
  // recursion.
  // TODO: Min/Max
  if (match(V: ValOut, P: m_Select(C: m_Value(), L: m_Specific(V: PHI), R: m_Value(V))) ||
      match(V: ValOut, P: m_Select(C: m_Value(), L: m_Value(V), R: m_Specific(V: PHI))))
    ValOut = V;

  // Same for select, if this phi is 2-operand phi, compute analysis on other
  // incoming value to break recursion.
  // TODO: We could handle any number of incoming edges as long as we only have
  // two unique values.
  if (auto *IncPhi = dyn_cast<PHINode>(Val: ValOut);
      IncPhi && IncPhi->getNumIncomingValues() == 2) {
    for (int Idx = 0; Idx < 2; ++Idx) {
      if (IncPhi->getIncomingValue(i: Idx) == PHI) {
        // Analyze the non-recursive arm, with its edge's terminator as
        // context.
        ValOut = IncPhi->getIncomingValue(i: 1 - Idx);
        if (PhiOut)
          *PhiOut = IncPhi;
        CtxIOut = IncPhi->getIncomingBlock(i: 1 - Idx)->getTerminator();
        break;
      }
    }
  }
}
816
/// Return true if an llvm.assume valid at Q.CxtI establishes that V is
/// non-zero, either via an operand-bundle attribute (nonnull /
/// dereferenceable on a pointer) or via an assumed icmp whose truth excludes
/// zero.
static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
  // Use of assumptions is context-sensitive. If we don't have a context, we
  // cannot use them!
  if (!Q.AC || !Q.CxtI)
    return false;

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Val&: Elem.Assume);
    assert(I->getFunction() == Q.CxtI->getFunction() &&
           "Got assumption for the wrong function!");

    // A non-ExprResultIdx entry refers to an operand bundle of the assume.
    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              Assume&: *I, BOI: I->bundle_op_info_begin()[Elem.Index])) {
        if (RK.WasOn != V)
          continue;
        bool AssumeImpliesNonNull = [&]() {
          if (RK.AttrKind == Attribute::NonNull)
            return true;

          if (RK.AttrKind == Attribute::Dereferenceable) {
            // dereferenceable(N) only implies non-null when null is not a
            // valid pointer in this address space.
            if (NullPointerIsDefined(F: Q.CxtI->getFunction(),
                                     AS: V->getType()->getPointerAddressSpace()))
              return false;
            assert(RK.IRArgValue &&
                   "Dereferenceable attribute without IR argument?");

            auto *CI = dyn_cast<ConstantInt>(Val: RK.IRArgValue);
            return CI && !CI->isZero();
          }

          return false;
        }();
        if (AssumeImpliesNonNull && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
          return true;
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop for once for each value queried resulting in a
    // runtime of ~O(#assumes * #values).

    Value *RHS;
    CmpPredicate Pred;
    // Also look through a ptrtoint of V, so pointer assumes apply to the
    // integer form and vice versa.
    auto m_V = m_CombineOr(L: m_Specific(V), R: m_PtrToInt(Op: m_Specific(V)));
    if (!match(V: I->getArgOperand(i: 0), P: m_c_ICmp(Pred, L: m_V, R: m_Value(V&: RHS))))
      continue;

    if (cmpExcludesZero(Pred, RHS) && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
      return true;
  }

  return false;
}
877
/// Refine \p Known for \p V given that the comparison "icmp Pred LHS, RHS"
/// is known to evaluate to true (e.g. from an assume or a dominating branch).
static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
                                    Value *LHS, Value *RHS, KnownBits &Known,
                                    const SimplifyQuery &Q) {
  if (RHS->getType()->isPointerTy()) {
    // Handle comparison of pointer to null explicitly, as it will not be
    // covered by the m_APInt() logic below.
    if (LHS == V && match(V: RHS, P: m_Zero())) {
      switch (Pred) {
      case ICmpInst::ICMP_EQ:
        Known.setAllZero();
        break;
      case ICmpInst::ICMP_SGE:
      case ICmpInst::ICMP_SGT:
        Known.makeNonNegative();
        break;
      case ICmpInst::ICMP_SLT:
        Known.makeNegative();
        break;
      default:
        break;
      }
    }
    return;
  }

  unsigned BitWidth = Known.getBitWidth();
  // Also look through a same-size ptrtoint of V on the compared side.
  auto m_V =
      m_CombineOr(L: m_Specific(V), R: m_PtrToIntSameSize(DL: Q.DL, Op: m_Specific(V)));

  Value *Y;
  const APInt *Mask, *C;
  // All patterns below require a constant RHS.
  if (!match(V: RHS, P: m_APInt(Res&: C)))
    return;

  uint64_t ShAmt;
  switch (Pred) {
  case ICmpInst::ICMP_EQ:
    // assume(V = C)
    if (match(V: LHS, P: m_V)) {
      Known = Known.unionWith(RHS: KnownBits::makeConstant(C: *C));
      // assume(V & Mask = C)
    } else if (match(V: LHS, P: m_c_And(L: m_V, R: m_Value(V&: Y)))) {
      // For one bits in Mask, we can propagate bits from C to V.
      Known.One |= *C;
      if (match(V: Y, P: m_APInt(Res&: Mask)))
        Known.Zero |= ~*C & *Mask;
      // assume(V | Mask = C)
    } else if (match(V: LHS, P: m_c_Or(L: m_V, R: m_Value(V&: Y)))) {
      // For zero bits in Mask, we can propagate bits from C to V.
      Known.Zero |= ~*C;
      if (match(V: Y, P: m_APInt(Res&: Mask)))
        Known.One |= *C & ~*Mask;
      // assume(V << ShAmt = C)
    } else if (match(V: LHS, P: m_Shl(L: m_V, R: m_ConstantInt(V&: ShAmt))) &&
               ShAmt < BitWidth) {
      // For those bits in C that are known, we can propagate them to known
      // bits in V shifted to the right by ShAmt.
      KnownBits RHSKnown = KnownBits::makeConstant(C: *C);
      RHSKnown >>= ShAmt;
      Known = Known.unionWith(RHS: RHSKnown);
      // assume(V >> ShAmt = C)
    } else if (match(V: LHS, P: m_Shr(L: m_V, R: m_ConstantInt(V&: ShAmt))) &&
               ShAmt < BitWidth) {
      // For those bits in RHS that are known, we can propagate them to known
      // bits in V shifted to the right by C.
      KnownBits RHSKnown = KnownBits::makeConstant(C: *C);
      RHSKnown <<= ShAmt;
      Known = Known.unionWith(RHS: RHSKnown);
    }
    break;
  case ICmpInst::ICMP_NE: {
    // assume (V & B != 0) where B is a power of 2
    const APInt *BPow2;
    if (C->isZero() && match(V: LHS, P: m_And(L: m_V, R: m_Power2(V&: BPow2))))
      Known.One |= *BPow2;
    break;
  }
  default: {
    // Relational predicates: derive a constant range for the LHS from the
    // comparison region and convert it to known bits.
    const APInt *Offset = nullptr;
    if (match(V: LHS, P: m_CombineOr(L: m_V, R: m_AddLike(L: m_V, R: m_APInt(Res&: Offset))))) {
      ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, Other: *C);
      if (Offset)
        LHSRange = LHSRange.sub(Other: *Offset);
      Known = Known.unionWith(RHS: LHSRange.toKnownBits());
    }
    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
      // X & Y u> C -> X u> C && Y u> C
      // X nuw- Y u> C -> X u> C
      if (match(V: LHS, P: m_c_And(L: m_V, R: m_Value())) ||
          match(V: LHS, P: m_NUWSub(L: m_V, R: m_Value())))
        Known.One.setHighBits(
            (*C + (Pred == ICmpInst::ICMP_UGT)).countLeadingOnes());
    }
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
      // X | Y u< C -> X u< C && Y u< C
      // X nuw+ Y u< C -> X u< C && Y u< C
      if (match(V: LHS, P: m_c_Or(L: m_V, R: m_Value())) ||
          match(V: LHS, P: m_c_NUWAdd(L: m_V, R: m_Value()))) {
        Known.Zero.setHighBits(
            (*C - (Pred == ICmpInst::ICMP_ULT)).countLeadingZeros());
      }
    }
  } break;
  }
}
983
/// Refine \p Known for \p V from the icmp condition \p Cmp, which is known to
/// hold when \p Invert is false, and whose negation holds otherwise.
static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
                                         KnownBits &Known,
                                         const SimplifyQuery &SQ, bool Invert) {
  ICmpInst::Predicate Pred =
      Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
  Value *LHS = Cmp->getOperand(i_nocapture: 0);
  Value *RHS = Cmp->getOperand(i_nocapture: 1);

  // Handle icmp pred (trunc V), C
  if (match(V: LHS, P: m_Trunc(Op: m_Specific(V)))) {
    // Compute what the comparison implies about the truncated value, then
    // extend that back to the width of V: with nuw the truncated-away high
    // bits of V are known zero (zext); otherwise they stay unknown (anyext).
    KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
    computeKnownBitsFromCmp(V: LHS, Pred, LHS, RHS, Known&: DstKnown, Q: SQ);
    if (cast<TruncInst>(Val: LHS)->hasNoUnsignedWrap())
      Known = Known.unionWith(RHS: DstKnown.zext(BitWidth: Known.getBitWidth()));
    else
      Known = Known.unionWith(RHS: DstKnown.anyext(BitWidth: Known.getBitWidth()));
    return;
  }

  computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, Q: SQ);
}
1005
/// Refine \p Known for \p V assuming the boolean \p Cond evaluates to true
/// (or to false when \p Invert is set). Recurses through logical and/or and
/// 'not' up to MaxAnalysisRecursionDepth.
static void computeKnownBitsFromCond(const Value *V, Value *Cond,
                                     KnownBits &Known, const SimplifyQuery &SQ,
                                     bool Invert, unsigned Depth) {
  Value *A, *B;
  if (Depth < MaxAnalysisRecursionDepth &&
      match(V: Cond, P: m_LogicalOp(L: m_Value(V&: A), R: m_Value(V&: B)))) {
    KnownBits Known2(Known.getBitWidth());
    KnownBits Known3(Known.getBitWidth());
    computeKnownBitsFromCond(V, Cond: A, Known&: Known2, SQ, Invert, Depth: Depth + 1);
    computeKnownBitsFromCond(V, Cond: B, Known&: Known3, SQ, Invert, Depth: Depth + 1);
    // If both sub-conditions must hold (a true logical-and, or an inverted
    // logical-or), facts from both sides combine (union); otherwise only
    // facts common to both sides are valid (intersection).
    if (Invert ? match(V: Cond, P: m_LogicalOr(L: m_Value(), R: m_Value()))
               : match(V: Cond, P: m_LogicalAnd(L: m_Value(), R: m_Value())))
      Known2 = Known2.unionWith(RHS: Known3);
    else
      Known2 = Known2.intersectWith(RHS: Known3);
    Known = Known.unionWith(RHS: Known2);
    return;
  }

  if (auto *Cmp = dyn_cast<ICmpInst>(Val: Cond)) {
    computeKnownBitsFromICmpCond(V, Cmp, Known, SQ, Invert);
    return;
  }

  // Cond is (trunc V): V's low bit equals the condition's value; with nuw,
  // the truncated-away bits of V are additionally known zero.
  if (match(V: Cond, P: m_Trunc(Op: m_Specific(V)))) {
    KnownBits DstKnown(1);
    if (Invert) {
      DstKnown.setAllZero();
    } else {
      DstKnown.setAllOnes();
    }
    if (cast<TruncInst>(Val: Cond)->hasNoUnsignedWrap()) {
      Known = Known.unionWith(RHS: DstKnown.zext(BitWidth: Known.getBitWidth()));
      return;
    }
    Known = Known.unionWith(RHS: DstKnown.anyext(BitWidth: Known.getBitWidth()));
    return;
  }

  // not(X) true is just X with the sense flipped.
  if (Depth < MaxAnalysisRecursionDepth && match(V: Cond, P: m_Not(V: m_Value(V&: A))))
    computeKnownBitsFromCond(V, Cond: A, Known, SQ, Invert: !Invert, Depth: Depth + 1);
}
1048
/// Refine \p Known for \p V using contextual information: an injected
/// condition (Q.CC), dominating branch conditions (Q.DC + Q.DT), and
/// @llvm.assume calls (Q.AC) that are valid at Q.CxtI.
void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                       const SimplifyQuery &Q, unsigned Depth) {
  // Handle injected condition.
  if (Q.CC && Q.CC->AffectedValues.contains(Ptr: V))
    computeKnownBitsFromCond(V, Cond: Q.CC->Cond, Known, SQ: Q, Invert: Q.CC->Invert, Depth);

  if (!Q.CxtI)
    return;

  if (Q.DC && Q.DT) {
    // Handle dominating conditions.
    for (CondBrInst *BI : Q.DC->conditionsFor(V)) {
      // The true successor edge implies the condition holds...
      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
      if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, Cond: BI->getCondition(), Known, SQ: Q,
                                 /*Invert*/ false, Depth);

      // ...and the false successor edge implies its negation.
      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
      if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, Cond: BI->getCondition(), Known, SQ: Q,
                                 /*Invert*/ true, Depth);
    }

    if (Known.hasConflict())
      Known.resetAll();
  }

  if (!Q.AC)
    return;

  unsigned BitWidth = Known.getBitWidth();

  // Note that the patterns below need to be kept in sync with the code
  // in AssumptionCache::updateAffectedValues.

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Val&: Elem.Assume);
    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
           "Got assumption for the wrong function!");

    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      // V appears in an operand bundle, not the condition; only alignment
      // bundles refine known bits here.
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              Assume&: *I, BOI: I->bundle_op_info_begin()[Elem.Index])) {
        // Allow AllowEphemerals in isValidAssumeForContext, as the CxtI might
        // be the producer of the pointer in the bundle. At the moment, align
        // assumptions aren't optimized away.
        if (RK.WasOn == V && RK.AttrKind == Attribute::Alignment &&
            isPowerOf2_64(Value: RK.ArgValue) &&
            isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT, /*AllowEphemerals*/ true))
          Known.Zero.setLowBits(Log2_64(Value: RK.ArgValue));
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop for once for each value queried resulting in a
    // runtime of ~O(#assumes * #values).

    Value *Arg = I->getArgOperand(i: 0);

    // assume(V) on an i1 V: V is all-ones.
    if (Arg == V && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllOnes();
      return;
    }
    // assume(!V) on an i1 V: V is all-zero.
    if (match(V: Arg, P: m_Not(V: m_Specific(V))) &&
        isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllZero();
      return;
    }
    // assume(trunc V): the low bit of V is one; with nuw, V is exactly 1.
    auto *Trunc = dyn_cast<TruncInst>(Val: Arg);
    if (Trunc && Trunc->getOperand(i_nocapture: 0) == V &&
        isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
      if (Trunc->hasNoUnsignedWrap()) {
        Known = KnownBits::makeConstant(C: APInt(BitWidth, 1));
        return;
      }
      Known.One.setBit(0);
      return;
    }

    // The remaining tests are all recursive, so bail out if we hit the limit.
    if (Depth == MaxAnalysisRecursionDepth)
      continue;

    ICmpInst *Cmp = dyn_cast<ICmpInst>(Val: Arg);
    if (!Cmp)
      continue;

    if (!isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
      continue;

    computeKnownBitsFromICmpCond(V, Cmp, Known, SQ: Q, /*Invert=*/false);
  }

  // Conflicting assumption: Undefined behavior will occur on this execution
  // path.
  if (Known.hasConflict())
    Known.resetAll();
}
1157
1158/// Compute known bits from a shift operator, including those with a
1159/// non-constant shift amount. Known is the output of this function. Known2 is a
1160/// pre-allocated temporary with the same bit width as Known and on return
1161/// contains the known bit of the shift value source. KF is an
1162/// operator-specific function that, given the known-bits and a shift amount,
1163/// compute the implied known-bits of the shift operator's result respectively
1164/// for that shift amount. The results from calling KF are conservatively
1165/// combined for all permitted shift amounts.
1166static void computeKnownBitsFromShiftOperator(
1167 const Operator *I, const APInt &DemandedElts, KnownBits &Known,
1168 KnownBits &Known2, const SimplifyQuery &Q, unsigned Depth,
1169 function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) {
1170 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1171 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1172 // To limit compile-time impact, only query isKnownNonZero() if we know at
1173 // least something about the shift amount.
1174 bool ShAmtNonZero =
1175 Known.isNonZero() ||
1176 (Known.getMaxValue().ult(RHS: Known.getBitWidth()) &&
1177 isKnownNonZero(V: I->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1));
1178 Known = KF(Known2, Known, ShAmtNonZero);
1179}
1180
/// Compute the known bits of an and/or/xor from the known bits of its two
/// operands, additionally exploiting low-bit-isolating idioms such as
/// and(x, -x) and xor(x, x - 1).
static KnownBits
getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
                         const KnownBits &KnownLHS, const KnownBits &KnownRHS,
                         const SimplifyQuery &Q, unsigned Depth) {
  unsigned BitWidth = KnownLHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  bool IsAnd = false;
  bool HasKnownOne = !KnownLHS.One.isZero() || !KnownRHS.One.isZero();
  Value *X = nullptr, *Y = nullptr;

  switch (I->getOpcode()) {
  case Instruction::And:
    KnownOut = KnownLHS & KnownRHS;
    IsAnd = true;
    // and(x, -x) is common idioms that will clear all but lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: instcombine often reassociates independent `and` which can hide
    // this pattern. Try to match and(x, and(-x, y)) / and(and(x, y), -x).
    if (HasKnownOne && match(V: I, P: m_c_And(L: m_Value(V&: X), R: m_Neg(V: m_Deferred(V: X))))) {
      // -(-x) == x so using whichever (LHS/RHS) gets us a better result.
      if (KnownLHS.countMaxTrailingZeros() <= KnownRHS.countMaxTrailingZeros())
        KnownOut = KnownLHS.blsi();
      else
        KnownOut = KnownRHS.blsi();
    }
    break;
  case Instruction::Or:
    KnownOut = KnownLHS | KnownRHS;
    break;
  case Instruction::Xor:
    KnownOut = KnownLHS ^ KnownRHS;
    // xor(x, x-1) is common idioms that will clear all but lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: xor(x, x-1) is often rewritting as xor(x, x-C) where C !=
    // -1 but for the purpose of demanded bits (xor(x, x-C) &
    // Demanded) == (xor(x, x-1) & Demanded). Extend the xor pattern
    // to use arbitrary C if xor(x, x-C) as the same as xor(x, x-1).
    if (HasKnownOne &&
        match(V: I, P: m_c_Xor(L: m_Value(V&: X), R: m_Add(L: m_Deferred(V: X), R: m_AllOnes())))) {
      const KnownBits &XBits = I->getOperand(i: 0) == X ? KnownLHS : KnownRHS;
      KnownOut = XBits.blsmsk();
    }
    break;
  default:
    llvm_unreachable("Invalid Op used in 'analyzeKnownBitsFromAndXorOr'");
  }

  // and(x, add (x, -1)) is a common idiom that always clears the low bit;
  // xor/or(x, add (x, -1)) is an idiom that will always set the low bit.
  // here we handle the more general case of adding any odd number by
  // matching the form and/xor/or(x, add(x, y)) where y is odd.
  // TODO: This could be generalized to clearing any bit set in y where the
  // following bit is known to be unset in y.
  if (!KnownOut.Zero[0] && !KnownOut.One[0] &&
      (match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_c_Add(L: m_Deferred(V: X), R: m_Value(V&: Y)))) ||
       match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_Sub(L: m_Deferred(V: X), R: m_Value(V&: Y)))) ||
       match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_Sub(L: m_Value(V&: Y), R: m_Deferred(V: X)))))) {
    KnownBits KnownY(BitWidth);
    computeKnownBits(V: Y, DemandedElts, Known&: KnownY, Q, Depth: Depth + 1);
    if (KnownY.countMinTrailingOnes() > 0) {
      // Y is known odd, so the result's low bit is fully determined.
      if (IsAnd)
        KnownOut.Zero.setBit(0);
      else
        KnownOut.One.setBit(0);
    }
  }
  return KnownOut;
}
1251
/// Compute known bits for a horizontal (pairwise) vector operation by
/// translating the demanded output elements into demanded elements of each
/// input operand and combining the per-pair known bits via \p KnownBitsFunc.
static KnownBits computeKnownBitsForHorizontalOperation(
    const Operator *I, const APInt &DemandedElts, const SimplifyQuery &Q,
    unsigned Depth,
    const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
        KnownBitsFunc) {
  APInt DemandedEltsLHS, DemandedEltsRHS;
  getHorizDemandedEltsForFirstOperand(VectorBitWidth: Q.DL.getTypeSizeInBits(Ty: I->getType()),
                                      DemandedElts, DemandedLHS&: DemandedEltsLHS,
                                      DemandedRHS&: DemandedEltsRHS);

  // NOTE(review): the shifted mask (DemandedEltsOp << 1) appears to select
  // each pair's second source element — confirm against
  // getHorizDemandedEltsForFirstOperand's mask layout.
  const auto ComputeForSingleOpFunc =
      [Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
        return KnownBitsFunc(
            computeKnownBits(V: Op, DemandedElts: DemandedEltsOp, Q, Depth: Depth + 1),
            computeKnownBits(V: Op, DemandedElts: DemandedEltsOp << 1, Q, Depth: Depth + 1));
      };

  // If only one operand contributes to the demanded output elements, the
  // other can be skipped entirely.
  if (DemandedEltsRHS.isZero())
    return ComputeForSingleOpFunc(I->getOperand(i: 0), DemandedEltsLHS);
  if (DemandedEltsLHS.isZero())
    return ComputeForSingleOpFunc(I->getOperand(i: 1), DemandedEltsRHS);

  // Otherwise only bits known for both operands' results are valid.
  return ComputeForSingleOpFunc(I->getOperand(i: 0), DemandedEltsLHS)
      .intersectWith(RHS: ComputeForSingleOpFunc(I->getOperand(i: 1), DemandedEltsRHS));
}
1277
1278// Public so this can be used in `SimplifyDemandedUseBits`.
1279KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
1280 const KnownBits &KnownLHS,
1281 const KnownBits &KnownRHS,
1282 const SimplifyQuery &SQ,
1283 unsigned Depth) {
1284 auto *FVTy = dyn_cast<FixedVectorType>(Val: I->getType());
1285 APInt DemandedElts =
1286 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
1287
1288 return getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS, KnownRHS, Q: SQ,
1289 Depth);
1290}
1291
1292ConstantRange llvm::getVScaleRange(const Function *F, unsigned BitWidth) {
1293 Attribute Attr = F->getFnAttribute(Kind: Attribute::VScaleRange);
1294 // Without vscale_range, we only know that vscale is non-zero.
1295 if (!Attr.isValid())
1296 return ConstantRange(APInt(BitWidth, 1), APInt::getZero(numBits: BitWidth));
1297
1298 unsigned AttrMin = Attr.getVScaleRangeMin();
1299 // Minimum is larger than vscale width, result is always poison.
1300 if ((unsigned)llvm::bit_width(Value: AttrMin) > BitWidth)
1301 return ConstantRange::getEmpty(BitWidth);
1302
1303 APInt Min(BitWidth, AttrMin);
1304 std::optional<unsigned> AttrMax = Attr.getVScaleRangeMax();
1305 if (!AttrMax || (unsigned)llvm::bit_width(Value: *AttrMax) > BitWidth)
1306 return ConstantRange(Min, APInt::getZero(numBits: BitWidth));
1307
1308 return ConstantRange(Min, APInt(BitWidth, *AttrMax) + 1);
1309}
1310
1311void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
1312 Value *Arm, bool Invert,
1313 const SimplifyQuery &Q, unsigned Depth) {
1314 // If we have a constant arm, we are done.
1315 if (Known.isConstant())
1316 return;
1317
1318 // See what condition implies about the bits of the select arm.
1319 KnownBits CondRes(Known.getBitWidth());
1320 computeKnownBitsFromCond(V: Arm, Cond, Known&: CondRes, SQ: Q, Invert, Depth: Depth + 1);
1321 // If we don't get any information from the condition, no reason to
1322 // proceed.
1323 if (CondRes.isUnknown())
1324 return;
1325
1326 // We can have conflict if the condition is dead. I.e if we have
1327 // (x | 64) < 32 ? (x | 64) : y
1328 // we will have conflict at bit 6 from the condition/the `or`.
1329 // In that case just return. Its not particularly important
1330 // what we do, as this select is going to be simplified soon.
1331 CondRes = CondRes.unionWith(RHS: Known);
1332 if (CondRes.hasConflict())
1333 return;
1334
1335 // Finally make sure the information we found is valid. This is relatively
1336 // expensive so it's left for the very end.
1337 if (!isGuaranteedNotToBeUndef(V: Arm, AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT, Depth: Depth + 1))
1338 return;
1339
1340 // Finally, we know we get information from the condition and its valid,
1341 // so return it.
1342 Known = std::move(CondRes);
1343}
1344
// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
// Returns the input and lower/upper bounds through the out-parameters, and
// returns true only when the bounds form a non-empty signed range
// (CLow <= CHigh).
static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
                                const APInt *&CLow, const APInt *&CHigh) {
  assert(isa<Operator>(Select) &&
         cast<Operator>(Select)->getOpcode() == Instruction::Select &&
         "Input should be a Select!");

  // The outer select must be a signed min or max with a constant operand.
  const Value *LHS = nullptr, *RHS = nullptr;
  SelectPatternFlavor SPF = matchSelectPattern(V: Select, LHS, RHS).Flavor;
  if (SPF != SPF_SMAX && SPF != SPF_SMIN)
    return false;

  if (!match(V: RHS, P: m_APInt(Res&: CLow)))
    return false;

  // The inner select must be the inverse flavor (smin under smax, or vice
  // versa), also against a constant.
  const Value *LHS2 = nullptr, *RHS2 = nullptr;
  SelectPatternFlavor SPF2 = matchSelectPattern(V: LHS, LHS&: LHS2, RHS&: RHS2).Flavor;
  if (getInverseMinMaxFlavor(SPF) != SPF2)
    return false;

  if (!match(V: RHS2, P: m_APInt(Res&: CHigh)))
    return false;

  // For smin(smax(...)) the outer constant is the upper bound, so swap.
  if (SPF == SPF_SMIN)
    std::swap(a&: CLow, b&: CHigh);

  In = LHS2;
  return CLow->sle(RHS: *CHigh);
}
1375
1376static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
1377 const APInt *&CLow,
1378 const APInt *&CHigh) {
1379 assert((II->getIntrinsicID() == Intrinsic::smin ||
1380 II->getIntrinsicID() == Intrinsic::smax) &&
1381 "Must be smin/smax");
1382
1383 Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(MinMaxID: II->getIntrinsicID());
1384 auto *InnerII = dyn_cast<IntrinsicInst>(Val: II->getArgOperand(i: 0));
1385 if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
1386 !match(V: II->getArgOperand(i: 1), P: m_APInt(Res&: CLow)) ||
1387 !match(V: InnerII->getArgOperand(i: 1), P: m_APInt(Res&: CHigh)))
1388 return false;
1389
1390 if (II->getIntrinsicID() == Intrinsic::smin)
1391 std::swap(a&: CLow, b&: CHigh);
1392 return CLow->sle(RHS: *CHigh);
1393}
1394
1395static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
1396 KnownBits &Known) {
1397 const APInt *CLow, *CHigh;
1398 if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
1399 Known = Known.unionWith(
1400 RHS: ConstantRange::getNonEmpty(Lower: *CLow, Upper: *CHigh + 1).toKnownBits());
1401}
1402
1403static void computeKnownBitsFromOperator(const Operator *I,
1404 const APInt &DemandedElts,
1405 KnownBits &Known,
1406 const SimplifyQuery &Q,
1407 unsigned Depth) {
1408 unsigned BitWidth = Known.getBitWidth();
1409
1410 KnownBits Known2(BitWidth);
1411 switch (I->getOpcode()) {
1412 default: break;
1413 case Instruction::Load:
1414 if (MDNode *MD =
1415 Q.IIQ.getMetadata(I: cast<LoadInst>(Val: I), KindID: LLVMContext::MD_range))
1416 computeKnownBitsFromRangeMetadata(Ranges: *MD, Known);
1417 break;
1418 case Instruction::And:
1419 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1420 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1421
1422 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1423 break;
1424 case Instruction::Or:
1425 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1426 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1427
1428 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1429 break;
1430 case Instruction::Xor:
1431 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1432 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1433
1434 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1435 break;
1436 case Instruction::Mul: {
1437 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1438 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1439 computeKnownBitsMul(Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1440 DemandedElts, Known, Known2, Q, Depth);
1441 break;
1442 }
1443 case Instruction::UDiv: {
1444 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1445 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1446 Known =
1447 KnownBits::udiv(LHS: Known, RHS: Known2, Exact: Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)));
1448 break;
1449 }
1450 case Instruction::SDiv: {
1451 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1452 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1453 Known =
1454 KnownBits::sdiv(LHS: Known, RHS: Known2, Exact: Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)));
1455 break;
1456 }
1457 case Instruction::Select: {
1458 auto ComputeForArm = [&](Value *Arm, bool Invert) {
1459 KnownBits Res(Known.getBitWidth());
1460 computeKnownBits(V: Arm, DemandedElts, Known&: Res, Q, Depth: Depth + 1);
1461 adjustKnownBitsForSelectArm(Known&: Res, Cond: I->getOperand(i: 0), Arm, Invert, Q, Depth);
1462 return Res;
1463 };
1464 // Only known if known in both the LHS and RHS.
1465 Known =
1466 ComputeForArm(I->getOperand(i: 1), /*Invert=*/false)
1467 .intersectWith(RHS: ComputeForArm(I->getOperand(i: 2), /*Invert=*/true));
1468 break;
1469 }
1470 case Instruction::FPTrunc:
1471 case Instruction::FPExt:
1472 case Instruction::FPToUI:
1473 case Instruction::FPToSI:
1474 case Instruction::SIToFP:
1475 case Instruction::UIToFP:
1476 break; // Can't work with floating point.
1477 case Instruction::PtrToInt:
1478 case Instruction::PtrToAddr:
1479 case Instruction::IntToPtr:
1480 // Fall through and handle them the same as zext/trunc.
1481 [[fallthrough]];
1482 case Instruction::ZExt:
1483 case Instruction::Trunc: {
1484 Type *SrcTy = I->getOperand(i: 0)->getType();
1485
1486 unsigned SrcBitWidth;
1487 // Note that we handle pointer operands here because of inttoptr/ptrtoint
1488 // which fall through here.
1489 Type *ScalarTy = SrcTy->getScalarType();
1490 SrcBitWidth = ScalarTy->isPointerTy() ?
1491 Q.DL.getPointerTypeSizeInBits(ScalarTy) :
1492 Q.DL.getTypeSizeInBits(Ty: ScalarTy);
1493
1494 assert(SrcBitWidth && "SrcBitWidth can't be zero");
1495 Known = Known.anyextOrTrunc(BitWidth: SrcBitWidth);
1496 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1497 if (auto *Inst = dyn_cast<PossiblyNonNegInst>(Val: I);
1498 Inst && Inst->hasNonNeg() && !Known.isNegative())
1499 Known.makeNonNegative();
1500 Known = Known.zextOrTrunc(BitWidth);
1501 break;
1502 }
1503 case Instruction::BitCast: {
1504 Type *SrcTy = I->getOperand(i: 0)->getType();
1505 if (SrcTy->isIntOrPtrTy() &&
1506 // TODO: For now, not handling conversions like:
1507 // (bitcast i64 %x to <2 x i32>)
1508 !I->getType()->isVectorTy()) {
1509 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
1510 break;
1511 }
1512
1513 const Value *V;
1514 // Handle bitcast from floating point to integer.
1515 if (match(V: I, P: m_ElementWiseBitCast(Op: m_Value(V))) &&
1516 V->getType()->isFPOrFPVectorTy()) {
1517 Type *FPType = V->getType()->getScalarType();
1518 KnownFPClass Result =
1519 computeKnownFPClass(V, DemandedElts, InterestedClasses: fcAllFlags, SQ: Q, Depth: Depth + 1);
1520 FPClassTest FPClasses = Result.KnownFPClasses;
1521
1522 // TODO: Treat it as zero/poison if the use of I is unreachable.
1523 if (FPClasses == fcNone)
1524 break;
1525
1526 if (Result.isKnownNever(Mask: fcNormal | fcSubnormal | fcNan)) {
1527 Known.setAllConflict();
1528
1529 if (FPClasses & fcInf)
1530 Known = Known.intersectWith(RHS: KnownBits::makeConstant(
1531 C: APFloat::getInf(Sem: FPType->getFltSemantics()).bitcastToAPInt()));
1532
1533 if (FPClasses & fcZero)
1534 Known = Known.intersectWith(RHS: KnownBits::makeConstant(
1535 C: APInt::getZero(numBits: FPType->getScalarSizeInBits())));
1536
1537 Known.Zero.clearSignBit();
1538 Known.One.clearSignBit();
1539 }
1540
1541 if (Result.SignBit) {
1542 if (*Result.SignBit)
1543 Known.makeNegative();
1544 else
1545 Known.makeNonNegative();
1546 }
1547
1548 break;
1549 }
1550
1551 // Handle cast from vector integer type to scalar or vector integer.
1552 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcTy);
1553 if (!SrcVecTy || !SrcVecTy->getElementType()->isIntegerTy() ||
1554 !I->getType()->isIntOrIntVectorTy() ||
1555 isa<ScalableVectorType>(Val: I->getType()))
1556 break;
1557
1558 unsigned NumElts = DemandedElts.getBitWidth();
1559 bool IsLE = Q.DL.isLittleEndian();
1560 // Look through a cast from narrow vector elements to wider type.
1561 // Examples: v4i32 -> v2i64, v3i8 -> v24
1562 unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
1563 if (BitWidth % SubBitWidth == 0) {
1564 // Known bits are automatically intersected across demanded elements of a
1565 // vector. So for example, if a bit is computed as known zero, it must be
1566 // zero across all demanded elements of the vector.
1567 //
1568 // For this bitcast, each demanded element of the output is sub-divided
1569 // across a set of smaller vector elements in the source vector. To get
1570 // the known bits for an entire element of the output, compute the known
1571 // bits for each sub-element sequentially. This is done by shifting the
1572 // one-set-bit demanded elements parameter across the sub-elements for
1573 // consecutive calls to computeKnownBits. We are using the demanded
1574 // elements parameter as a mask operator.
1575 //
1576 // The known bits of each sub-element are then inserted into place
1577 // (dependent on endian) to form the full result of known bits.
1578 unsigned SubScale = BitWidth / SubBitWidth;
1579 APInt SubDemandedElts = APInt::getZero(numBits: NumElts * SubScale);
1580 for (unsigned i = 0; i != NumElts; ++i) {
1581 if (DemandedElts[i])
1582 SubDemandedElts.setBit(i * SubScale);
1583 }
1584
1585 KnownBits KnownSrc(SubBitWidth);
1586 for (unsigned i = 0; i != SubScale; ++i) {
1587 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: SubDemandedElts.shl(shiftAmt: i), Known&: KnownSrc, Q,
1588 Depth: Depth + 1);
1589 unsigned ShiftElt = IsLE ? i : SubScale - 1 - i;
1590 Known.insertBits(SubBits: KnownSrc, BitPosition: ShiftElt * SubBitWidth);
1591 }
1592 }
1593 // Look through a cast from wider vector elements to narrow type.
1594 // Examples: v2i64 -> v4i32
1595 if (SubBitWidth % BitWidth == 0) {
1596 unsigned SubScale = SubBitWidth / BitWidth;
1597 KnownBits KnownSrc(SubBitWidth);
1598 APInt SubDemandedElts =
1599 APIntOps::ScaleBitMask(A: DemandedElts, NewBitWidth: NumElts / SubScale);
1600 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: SubDemandedElts, Known&: KnownSrc, Q,
1601 Depth: Depth + 1);
1602
1603 Known.setAllConflict();
1604 for (unsigned i = 0; i != NumElts; ++i) {
1605 if (DemandedElts[i]) {
1606 unsigned Shifts = IsLE ? i : NumElts - 1 - i;
1607 unsigned Offset = (Shifts % SubScale) * BitWidth;
1608 Known = Known.intersectWith(RHS: KnownSrc.extractBits(NumBits: BitWidth, BitPosition: Offset));
1609 if (Known.isUnknown())
1610 break;
1611 }
1612 }
1613 }
1614 break;
1615 }
1616 case Instruction::SExt: {
1617 // Compute the bits in the result that are not present in the input.
1618 unsigned SrcBitWidth = I->getOperand(i: 0)->getType()->getScalarSizeInBits();
1619
1620 Known = Known.trunc(BitWidth: SrcBitWidth);
1621 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1622 // If the sign bit of the input is known set or clear, then we know the
1623 // top bits of the result.
1624 Known = Known.sext(BitWidth);
1625 break;
1626 }
1627 case Instruction::Shl: {
1628 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1629 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1630 auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1631 bool ShAmtNonZero) {
1632 return KnownBits::shl(LHS: KnownVal, RHS: KnownAmt, NUW, NSW, ShAmtNonZero);
1633 };
1634 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1635 KF);
1636 // Trailing zeros of a right-shifted constant never decrease.
1637 const APInt *C;
1638 if (match(V: I->getOperand(i: 0), P: m_APInt(Res&: C)))
1639 Known.Zero.setLowBits(C->countr_zero());
1640 break;
1641 }
1642 case Instruction::LShr: {
1643 bool Exact = Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I));
1644 auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1645 bool ShAmtNonZero) {
1646 return KnownBits::lshr(LHS: KnownVal, RHS: KnownAmt, ShAmtNonZero, Exact);
1647 };
1648 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1649 KF);
1650 // Leading zeros of a left-shifted constant never decrease.
1651 const APInt *C;
1652 if (match(V: I->getOperand(i: 0), P: m_APInt(Res&: C)))
1653 Known.Zero.setHighBits(C->countl_zero());
1654 break;
1655 }
1656 case Instruction::AShr: {
1657 bool Exact = Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I));
1658 auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1659 bool ShAmtNonZero) {
1660 return KnownBits::ashr(LHS: KnownVal, RHS: KnownAmt, ShAmtNonZero, Exact);
1661 };
1662 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1663 KF);
1664 break;
1665 }
1666 case Instruction::Sub: {
1667 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1668 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1669 computeKnownBitsAddSub(Add: false, Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1670 DemandedElts, KnownOut&: Known, Known2, Q, Depth);
1671 break;
1672 }
1673 case Instruction::Add: {
1674 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1675 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1676 computeKnownBitsAddSub(Add: true, Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1677 DemandedElts, KnownOut&: Known, Known2, Q, Depth);
1678 break;
1679 }
1680 case Instruction::SRem:
1681 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1682 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1683 Known = KnownBits::srem(LHS: Known, RHS: Known2);
1684 break;
1685
1686 case Instruction::URem:
1687 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1688 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1689 Known = KnownBits::urem(LHS: Known, RHS: Known2);
1690 break;
1691 case Instruction::Alloca:
1692 Known.Zero.setLowBits(Log2(A: cast<AllocaInst>(Val: I)->getAlign()));
1693 break;
1694 case Instruction::GetElementPtr: {
1695 // Analyze all of the subscripts of this getelementptr instruction
1696 // to determine if we can prove known low zero bits.
1697 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
1698 // Accumulate the constant indices in a separate variable
1699 // to minimize the number of calls to computeForAddSub.
1700 unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Ty: I->getType());
1701 APInt AccConstIndices(IndexWidth, 0);
1702
1703 auto AddIndexToKnown = [&](KnownBits IndexBits) {
1704 if (IndexWidth == BitWidth) {
1705 // Note that inbounds does *not* guarantee nsw for the addition, as only
1706 // the offset is signed, while the base address is unsigned.
1707 Known = KnownBits::add(LHS: Known, RHS: IndexBits);
1708 } else {
1709 // If the index width is smaller than the pointer width, only add the
1710 // value to the low bits.
1711 assert(IndexWidth < BitWidth &&
1712 "Index width can't be larger than pointer width");
1713 Known.insertBits(SubBits: KnownBits::add(LHS: Known.trunc(BitWidth: IndexWidth), RHS: IndexBits), BitPosition: 0);
1714 }
1715 };
1716
1717 gep_type_iterator GTI = gep_type_begin(GEP: I);
1718 for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
1719 // TrailZ can only become smaller, short-circuit if we hit zero.
1720 if (Known.isUnknown())
1721 break;
1722
1723 Value *Index = I->getOperand(i);
1724
1725 // Handle case when index is zero.
1726 Constant *CIndex = dyn_cast<Constant>(Val: Index);
1727 if (CIndex && CIndex->isNullValue())
1728 continue;
1729
1730 if (StructType *STy = GTI.getStructTypeOrNull()) {
1731 // Handle struct member offset arithmetic.
1732
1733 assert(CIndex &&
1734 "Access to structure field must be known at compile time");
1735
1736 if (CIndex->getType()->isVectorTy())
1737 Index = CIndex->getSplatValue();
1738
1739 unsigned Idx = cast<ConstantInt>(Val: Index)->getZExtValue();
1740 const StructLayout *SL = Q.DL.getStructLayout(Ty: STy);
1741 uint64_t Offset = SL->getElementOffset(Idx);
1742 AccConstIndices += Offset;
1743 continue;
1744 }
1745
1746 // Handle array index arithmetic.
1747 Type *IndexedTy = GTI.getIndexedType();
1748 if (!IndexedTy->isSized()) {
1749 Known.resetAll();
1750 break;
1751 }
1752
1753 TypeSize Stride = GTI.getSequentialElementStride(DL: Q.DL);
1754 uint64_t StrideInBytes = Stride.getKnownMinValue();
1755 if (!Stride.isScalable()) {
1756 // Fast path for constant offset.
1757 if (auto *CI = dyn_cast<ConstantInt>(Val: Index)) {
1758 AccConstIndices +=
1759 CI->getValue().sextOrTrunc(width: IndexWidth) * StrideInBytes;
1760 continue;
1761 }
1762 }
1763
1764 KnownBits IndexBits =
1765 computeKnownBits(V: Index, Q, Depth: Depth + 1).sextOrTrunc(BitWidth: IndexWidth);
1766 KnownBits ScalingFactor(IndexWidth);
1767 // Multiply by current sizeof type.
1768 // &A[i] == A + i * sizeof(*A[i]).
1769 if (Stride.isScalable()) {
1770 // For scalable types the only thing we know about sizeof is
1771 // that this is a multiple of the minimum size.
1772 ScalingFactor.Zero.setLowBits(llvm::countr_zero(Val: StrideInBytes));
1773 } else {
1774 ScalingFactor =
1775 KnownBits::makeConstant(C: APInt(IndexWidth, StrideInBytes));
1776 }
1777 AddIndexToKnown(KnownBits::mul(LHS: IndexBits, RHS: ScalingFactor));
1778 }
1779 if (!Known.isUnknown() && !AccConstIndices.isZero())
1780 AddIndexToKnown(KnownBits::makeConstant(C: AccConstIndices));
1781 break;
1782 }
1783 case Instruction::PHI: {
1784 const PHINode *P = cast<PHINode>(Val: I);
1785 BinaryOperator *BO = nullptr;
1786 Value *R = nullptr, *L = nullptr;
1787 if (matchSimpleRecurrence(P, BO, Start&: R, Step&: L)) {
1788 // Handle the case of a simple two-predecessor recurrence PHI.
1789 // There's a lot more that could theoretically be done here, but
1790 // this is sufficient to catch some interesting cases.
1791 unsigned Opcode = BO->getOpcode();
1792
1793 switch (Opcode) {
1794 // If this is a shift recurrence, we know the bits being shifted in. We
1795 // can combine that with information about the start value of the
1796 // recurrence to conclude facts about the result. If this is a udiv
1797 // recurrence, we know that the result can never exceed either the
1798 // numerator or the start value, whichever is greater.
1799 case Instruction::LShr:
1800 case Instruction::AShr:
1801 case Instruction::Shl:
1802 case Instruction::UDiv:
1803 if (BO->getOperand(i_nocapture: 0) != I)
1804 break;
1805 [[fallthrough]];
1806
1807 // For a urem recurrence, the result can never exceed the start value. The
1808 // phi could either be the numerator or the denominator.
1809 case Instruction::URem: {
1810 // We have matched a recurrence of the form:
1811 // %iv = [R, %entry], [%iv.next, %backedge]
1812 // %iv.next = shift_op %iv, L
1813
1814 // Recurse with the phi context to avoid concern about whether facts
1815 // inferred hold at original context instruction. TODO: It may be
1816 // correct to use the original context. IF warranted, explore and
1817 // add sufficient tests to cover.
1818 SimplifyQuery RecQ = Q.getWithoutCondContext();
1819 RecQ.CxtI = P;
1820 computeKnownBits(V: R, DemandedElts, Known&: Known2, Q: RecQ, Depth: Depth + 1);
1821 switch (Opcode) {
1822 case Instruction::Shl:
          // A shl recurrence will only increase the trailing zeros
1824 Known.Zero.setLowBits(Known2.countMinTrailingZeros());
1825 break;
1826 case Instruction::LShr:
1827 case Instruction::UDiv:
1828 case Instruction::URem:
1829 // lshr, udiv, and urem recurrences will preserve the leading zeros of
1830 // the start value.
1831 Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1832 break;
1833 case Instruction::AShr:
1834 // An ashr recurrence will extend the initial sign bit
1835 Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1836 Known.One.setHighBits(Known2.countMinLeadingOnes());
1837 break;
1838 }
1839 break;
1840 }
1841
1842 // Check for operations that have the property that if
1843 // both their operands have low zero bits, the result
1844 // will have low zero bits.
1845 case Instruction::Add:
1846 case Instruction::Sub:
1847 case Instruction::And:
1848 case Instruction::Or:
1849 case Instruction::Mul: {
1850 // Change the context instruction to the "edge" that flows into the
1851 // phi. This is important because that is where the value is actually
1852 // "evaluated" even though it is used later somewhere else. (see also
1853 // D69571).
1854 SimplifyQuery RecQ = Q.getWithoutCondContext();
1855
1856 unsigned OpNum = P->getOperand(i_nocapture: 0) == R ? 0 : 1;
1857 Instruction *RInst = P->getIncomingBlock(i: OpNum)->getTerminator();
1858 Instruction *LInst = P->getIncomingBlock(i: 1 - OpNum)->getTerminator();
1859
1860 // Ok, we have a PHI of the form L op= R. Check for low
1861 // zero bits.
1862 RecQ.CxtI = RInst;
1863 computeKnownBits(V: R, DemandedElts, Known&: Known2, Q: RecQ, Depth: Depth + 1);
1864
1865 // We need to take the minimum number of known bits
1866 KnownBits Known3(BitWidth);
1867 RecQ.CxtI = LInst;
1868 computeKnownBits(V: L, DemandedElts, Known&: Known3, Q: RecQ, Depth: Depth + 1);
1869
1870 Known.Zero.setLowBits(std::min(a: Known2.countMinTrailingZeros(),
1871 b: Known3.countMinTrailingZeros()));
1872
1873 auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(Val: BO);
1874 if (!OverflowOp || !Q.IIQ.hasNoSignedWrap(Op: OverflowOp))
1875 break;
1876
1877 switch (Opcode) {
1878 // If initial value of recurrence is nonnegative, and we are adding
1879 // a nonnegative number with nsw, the result can only be nonnegative
1880 // or poison value regardless of the number of times we execute the
1881 // add in phi recurrence. If initial value is negative and we are
1882 // adding a negative number with nsw, the result can only be
1883 // negative or poison value. Similar arguments apply to sub and mul.
1884 //
1885 // (add non-negative, non-negative) --> non-negative
1886 // (add negative, negative) --> negative
1887 case Instruction::Add: {
1888 if (Known2.isNonNegative() && Known3.isNonNegative())
1889 Known.makeNonNegative();
1890 else if (Known2.isNegative() && Known3.isNegative())
1891 Known.makeNegative();
1892 break;
1893 }
1894
1895 // (sub nsw non-negative, negative) --> non-negative
1896 // (sub nsw negative, non-negative) --> negative
1897 case Instruction::Sub: {
1898 if (BO->getOperand(i_nocapture: 0) != I)
1899 break;
1900 if (Known2.isNonNegative() && Known3.isNegative())
1901 Known.makeNonNegative();
1902 else if (Known2.isNegative() && Known3.isNonNegative())
1903 Known.makeNegative();
1904 break;
1905 }
1906
1907 // (mul nsw non-negative, non-negative) --> non-negative
1908 case Instruction::Mul:
1909 if (Known2.isNonNegative() && Known3.isNonNegative())
1910 Known.makeNonNegative();
1911 break;
1912
1913 default:
1914 break;
1915 }
1916 break;
1917 }
1918
1919 default:
1920 break;
1921 }
1922 }
1923
1924 // Unreachable blocks may have zero-operand PHI nodes.
1925 if (P->getNumIncomingValues() == 0)
1926 break;
1927
1928 // Otherwise take the unions of the known bit sets of the operands,
1929 // taking conservative care to avoid excessive recursion.
1930 if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) {
1931 // Skip if every incoming value references to ourself.
1932 if (isa_and_nonnull<UndefValue>(Val: P->hasConstantValue()))
1933 break;
1934
1935 Known.setAllConflict();
1936 for (const Use &U : P->operands()) {
1937 Value *IncValue;
1938 const PHINode *CxtPhi;
1939 Instruction *CxtI;
1940 breakSelfRecursivePHI(U: &U, PHI: P, ValOut&: IncValue, CtxIOut&: CxtI, PhiOut: &CxtPhi);
1941 // Skip direct self references.
1942 if (IncValue == P)
1943 continue;
1944
1945 // Change the context instruction to the "edge" that flows into the
1946 // phi. This is important because that is where the value is actually
1947 // "evaluated" even though it is used later somewhere else. (see also
1948 // D69571).
1949 SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(I: CxtI);
1950
1951 Known2 = KnownBits(BitWidth);
1952
1953 // Recurse, but cap the recursion to one level, because we don't
1954 // want to waste time spinning around in loops.
1955 // TODO: See if we can base recursion limiter on number of incoming phi
1956 // edges so we don't overly clamp analysis.
1957 computeKnownBits(V: IncValue, DemandedElts, Known&: Known2, Q: RecQ,
1958 Depth: MaxAnalysisRecursionDepth - 1);
1959
1960 // See if we can further use a conditional branch into the phi
1961 // to help us determine the range of the value.
1962 if (!Known2.isConstant()) {
1963 CmpPredicate Pred;
1964 const APInt *RHSC;
1965 BasicBlock *TrueSucc, *FalseSucc;
1966 // TODO: Use RHS Value and compute range from its known bits.
1967 if (match(V: RecQ.CxtI,
1968 P: m_Br(C: m_c_ICmp(Pred, L: m_Specific(V: IncValue), R: m_APInt(Res&: RHSC)),
1969 T: m_BasicBlock(V&: TrueSucc), F: m_BasicBlock(V&: FalseSucc)))) {
1970 // Check for cases of duplicate successors.
1971 if ((TrueSucc == CxtPhi->getParent()) !=
1972 (FalseSucc == CxtPhi->getParent())) {
1973 // If we're using the false successor, invert the predicate.
1974 if (FalseSucc == CxtPhi->getParent())
1975 Pred = CmpInst::getInversePredicate(pred: Pred);
1976 // Get the knownbits implied by the incoming phi condition.
1977 auto CR = ConstantRange::makeExactICmpRegion(Pred, Other: *RHSC);
1978 KnownBits KnownUnion = Known2.unionWith(RHS: CR.toKnownBits());
              // We can have conflicts here if we are analyzing dead code (it's
              // impossible for us to reach this BB based on the icmp).
1981 if (KnownUnion.hasConflict()) {
1982 // No reason to continue analyzing in a known dead region, so
1983 // just resetAll and break. This will cause us to also exit the
1984 // outer loop.
1985 Known.resetAll();
1986 break;
1987 }
1988 Known2 = KnownUnion;
1989 }
1990 }
1991 }
1992
1993 Known = Known.intersectWith(RHS: Known2);
1994 // If all bits have been ruled out, there's no need to check
1995 // more operands.
1996 if (Known.isUnknown())
1997 break;
1998 }
1999 }
2000 break;
2001 }
2002 case Instruction::Call:
2003 case Instruction::Invoke: {
2004 // If range metadata is attached to this call, set known bits from that,
2005 // and then intersect with known bits based on other properties of the
2006 // function.
2007 if (MDNode *MD =
2008 Q.IIQ.getMetadata(I: cast<Instruction>(Val: I), KindID: LLVMContext::MD_range))
2009 computeKnownBitsFromRangeMetadata(Ranges: *MD, Known);
2010
2011 const auto *CB = cast<CallBase>(Val: I);
2012
2013 if (std::optional<ConstantRange> Range = CB->getRange())
2014 Known = Known.unionWith(RHS: Range->toKnownBits());
2015
2016 if (const Value *RV = CB->getReturnedArgOperand()) {
2017 if (RV->getType() == I->getType()) {
2018 computeKnownBits(V: RV, Known&: Known2, Q, Depth: Depth + 1);
2019 Known = Known.unionWith(RHS: Known2);
2020 // If the function doesn't return properly for all input values
2021 // (e.g. unreachable exits) then there might be conflicts between the
2022 // argument value and the range metadata. Simply discard the known bits
2023 // in case of conflicts.
2024 if (Known.hasConflict())
2025 Known.resetAll();
2026 }
2027 }
2028 if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I)) {
2029 switch (II->getIntrinsicID()) {
2030 default:
2031 break;
2032 case Intrinsic::abs: {
2033 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2034 bool IntMinIsPoison = match(V: II->getArgOperand(i: 1), P: m_One());
2035 Known = Known.unionWith(RHS: Known2.abs(IntMinIsPoison));
2036 break;
2037 }
2038 case Intrinsic::bitreverse:
2039 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2040 Known = Known.unionWith(RHS: Known2.reverseBits());
2041 break;
2042 case Intrinsic::bswap:
2043 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2044 Known = Known.unionWith(RHS: Known2.byteSwap());
2045 break;
2046 case Intrinsic::ctlz: {
2047 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2048 // If we have a known 1, its position is our upper bound.
2049 unsigned PossibleLZ = Known2.countMaxLeadingZeros();
2050 // If this call is poison for 0 input, the result will be less than 2^n.
2051 if (II->getArgOperand(i: 1) == ConstantInt::getTrue(Context&: II->getContext()))
2052 PossibleLZ = std::min(a: PossibleLZ, b: BitWidth - 1);
2053 unsigned LowBits = llvm::bit_width(Value: PossibleLZ);
2054 Known.Zero.setBitsFrom(LowBits);
2055 break;
2056 }
2057 case Intrinsic::cttz: {
2058 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2059 // If we have a known 1, its position is our upper bound.
2060 unsigned PossibleTZ = Known2.countMaxTrailingZeros();
2061 // If this call is poison for 0 input, the result will be less than 2^n.
2062 if (II->getArgOperand(i: 1) == ConstantInt::getTrue(Context&: II->getContext()))
2063 PossibleTZ = std::min(a: PossibleTZ, b: BitWidth - 1);
2064 unsigned LowBits = llvm::bit_width(Value: PossibleTZ);
2065 Known.Zero.setBitsFrom(LowBits);
2066 break;
2067 }
2068 case Intrinsic::ctpop: {
2069 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2070 // We can bound the space the count needs. Also, bits known to be zero
2071 // can't contribute to the population.
2072 unsigned BitsPossiblySet = Known2.countMaxPopulation();
2073 unsigned LowBits = llvm::bit_width(Value: BitsPossiblySet);
2074 Known.Zero.setBitsFrom(LowBits);
2075 // TODO: we could bound KnownOne using the lower bound on the number
2076 // of bits which might be set provided by popcnt KnownOne2.
2077 break;
2078 }
2079 case Intrinsic::fshr:
2080 case Intrinsic::fshl: {
2081 const APInt *SA;
2082 if (!match(V: I->getOperand(i: 2), P: m_APInt(Res&: SA)))
2083 break;
2084
2085 // Normalize to funnel shift left.
2086 uint64_t ShiftAmt = SA->urem(RHS: BitWidth);
2087 if (II->getIntrinsicID() == Intrinsic::fshr)
2088 ShiftAmt = BitWidth - ShiftAmt;
2089
2090 KnownBits Known3(BitWidth);
2091 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2092 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known3, Q, Depth: Depth + 1);
2093
2094 Known2 <<= ShiftAmt;
2095 Known3 >>= BitWidth - ShiftAmt;
2096 Known = Known2.unionWith(RHS: Known3);
2097 break;
2098 }
2099 case Intrinsic::clmul:
2100 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2101 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2102 Known = KnownBits::clmul(LHS: Known, RHS: Known2);
2103 break;
2104 case Intrinsic::uadd_sat:
2105 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2106 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2107 Known = KnownBits::uadd_sat(LHS: Known, RHS: Known2);
2108 break;
2109 case Intrinsic::usub_sat:
2110 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2111 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2112 Known = KnownBits::usub_sat(LHS: Known, RHS: Known2);
2113 break;
2114 case Intrinsic::sadd_sat:
2115 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2116 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2117 Known = KnownBits::sadd_sat(LHS: Known, RHS: Known2);
2118 break;
2119 case Intrinsic::ssub_sat:
2120 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2121 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2122 Known = KnownBits::ssub_sat(LHS: Known, RHS: Known2);
2123 break;
2124 // Vec reverse preserves bits from input vec.
2125 case Intrinsic::vector_reverse:
2126 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: DemandedElts.reverseBits(), Known, Q,
2127 Depth: Depth + 1);
2128 break;
2129 // for min/max/and/or reduce, any bit common to each element in the
2130 // input vec is set in the output.
2131 case Intrinsic::vector_reduce_and:
2132 case Intrinsic::vector_reduce_or:
2133 case Intrinsic::vector_reduce_umax:
2134 case Intrinsic::vector_reduce_umin:
2135 case Intrinsic::vector_reduce_smax:
2136 case Intrinsic::vector_reduce_smin:
2137 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2138 break;
2139 case Intrinsic::vector_reduce_xor: {
2140 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2141 // The zeros common to all vecs are zero in the output.
2142 // If the number of elements is odd, then the common ones remain. If the
2143 // number of elements is even, then the common ones becomes zeros.
2144 auto *VecTy = cast<VectorType>(Val: I->getOperand(i: 0)->getType());
2145 // Even, so the ones become zeros.
2146 bool EvenCnt = VecTy->getElementCount().isKnownEven();
2147 if (EvenCnt)
2148 Known.Zero |= Known.One;
2149 // Maybe even element count so need to clear ones.
2150 if (VecTy->isScalableTy() || EvenCnt)
2151 Known.One.clearAllBits();
2152 break;
2153 }
2154 case Intrinsic::vector_reduce_add: {
2155 auto *VecTy = dyn_cast<FixedVectorType>(Val: I->getOperand(i: 0)->getType());
2156 if (!VecTy)
2157 break;
2158 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2159 Known = Known.reduceAdd(NumElts: VecTy->getNumElements());
2160 break;
2161 }
2162 case Intrinsic::umin:
2163 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2164 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2165 Known = KnownBits::umin(LHS: Known, RHS: Known2);
2166 break;
2167 case Intrinsic::umax:
2168 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2169 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2170 Known = KnownBits::umax(LHS: Known, RHS: Known2);
2171 break;
2172 case Intrinsic::smin:
2173 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2174 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2175 Known = KnownBits::smin(LHS: Known, RHS: Known2);
2176 unionWithMinMaxIntrinsicClamp(II, Known);
2177 break;
2178 case Intrinsic::smax:
2179 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2180 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2181 Known = KnownBits::smax(LHS: Known, RHS: Known2);
2182 unionWithMinMaxIntrinsicClamp(II, Known);
2183 break;
2184 case Intrinsic::ptrmask: {
2185 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2186
2187 const Value *Mask = I->getOperand(i: 1);
2188 Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
2189 computeKnownBits(V: Mask, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2190 // TODO: 1-extend would be more precise.
2191 Known &= Known2.anyextOrTrunc(BitWidth);
2192 break;
2193 }
2194 case Intrinsic::x86_sse2_pmulh_w:
2195 case Intrinsic::x86_avx2_pmulh_w:
2196 case Intrinsic::x86_avx512_pmulh_w_512:
2197 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2198 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2199 Known = KnownBits::mulhs(LHS: Known, RHS: Known2);
2200 break;
2201 case Intrinsic::x86_sse2_pmulhu_w:
2202 case Intrinsic::x86_avx2_pmulhu_w:
2203 case Intrinsic::x86_avx512_pmulhu_w_512:
2204 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2205 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2206 Known = KnownBits::mulhu(LHS: Known, RHS: Known2);
2207 break;
2208 case Intrinsic::x86_sse42_crc32_64_64:
2209 Known.Zero.setBitsFrom(32);
2210 break;
2211 case Intrinsic::x86_ssse3_phadd_d_128:
2212 case Intrinsic::x86_ssse3_phadd_w_128:
2213 case Intrinsic::x86_avx2_phadd_d:
2214 case Intrinsic::x86_avx2_phadd_w: {
2215 Known = computeKnownBitsForHorizontalOperation(
2216 I, DemandedElts, Q, Depth,
2217 KnownBitsFunc: [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
2218 return KnownBits::add(LHS: KnownLHS, RHS: KnownRHS);
2219 });
2220 break;
2221 }
2222 case Intrinsic::x86_ssse3_phadd_sw_128:
2223 case Intrinsic::x86_avx2_phadd_sw: {
2224 Known = computeKnownBitsForHorizontalOperation(
2225 I, DemandedElts, Q, Depth, KnownBitsFunc: KnownBits::sadd_sat);
2226 break;
2227 }
2228 case Intrinsic::x86_ssse3_phsub_d_128:
2229 case Intrinsic::x86_ssse3_phsub_w_128:
2230 case Intrinsic::x86_avx2_phsub_d:
2231 case Intrinsic::x86_avx2_phsub_w: {
2232 Known = computeKnownBitsForHorizontalOperation(
2233 I, DemandedElts, Q, Depth,
2234 KnownBitsFunc: [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
2235 return KnownBits::sub(LHS: KnownLHS, RHS: KnownRHS);
2236 });
2237 break;
2238 }
2239 case Intrinsic::x86_ssse3_phsub_sw_128:
2240 case Intrinsic::x86_avx2_phsub_sw: {
2241 Known = computeKnownBitsForHorizontalOperation(
2242 I, DemandedElts, Q, Depth, KnownBitsFunc: KnownBits::ssub_sat);
2243 break;
2244 }
2245 case Intrinsic::riscv_vsetvli:
2246 case Intrinsic::riscv_vsetvlimax: {
2247 bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;
2248 const ConstantRange Range = getVScaleRange(F: II->getFunction(), BitWidth);
2249 uint64_t SEW = RISCVVType::decodeVSEW(
2250 VSEW: cast<ConstantInt>(Val: II->getArgOperand(i: HasAVL))->getZExtValue());
2251 RISCVVType::VLMUL VLMUL = static_cast<RISCVVType::VLMUL>(
2252 cast<ConstantInt>(Val: II->getArgOperand(i: 1 + HasAVL))->getZExtValue());
2253 uint64_t MaxVLEN =
2254 Range.getUnsignedMax().getZExtValue() * RISCV::RVVBitsPerBlock;
2255 uint64_t MaxVL = MaxVLEN / RISCVVType::getSEWLMULRatio(SEW, VLMul: VLMUL);
2256
2257 // Result of vsetvli must be not larger than AVL.
2258 if (HasAVL)
2259 if (auto *CI = dyn_cast<ConstantInt>(Val: II->getArgOperand(i: 0)))
2260 MaxVL = std::min(a: MaxVL, b: CI->getZExtValue());
2261
2262 unsigned KnownZeroFirstBit = Log2_32(Value: MaxVL) + 1;
2263 if (BitWidth > KnownZeroFirstBit)
2264 Known.Zero.setBitsFrom(KnownZeroFirstBit);
2265 break;
2266 }
2267 case Intrinsic::amdgcn_mbcnt_hi:
2268 case Intrinsic::amdgcn_mbcnt_lo: {
2269 // Wave64 mbcnt_lo returns at most 32 + src1. Otherwise these return at
2270 // most 31 + src1.
2271 Known.Zero.setBitsFrom(
2272 II->getIntrinsicID() == Intrinsic::amdgcn_mbcnt_lo ? 6 : 5);
2273 computeKnownBits(V: I->getOperand(i: 1), Known&: Known2, Q, Depth: Depth + 1);
2274 Known = KnownBits::add(LHS: Known, RHS: Known2);
2275 break;
2276 }
2277 case Intrinsic::vscale: {
2278 if (!II->getParent() || !II->getFunction())
2279 break;
2280
2281 Known = getVScaleRange(F: II->getFunction(), BitWidth).toKnownBits();
2282 break;
2283 }
2284 }
2285 }
2286 break;
2287 }
2288 case Instruction::ShuffleVector: {
2289 if (auto *Splat = getSplatValue(V: I)) {
2290 computeKnownBits(V: Splat, Known, Q, Depth: Depth + 1);
2291 break;
2292 }
2293
2294 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: I);
2295 // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
2296 if (!Shuf) {
2297 Known.resetAll();
2298 return;
2299 }
2300 // For undef elements, we don't know anything about the common state of
2301 // the shuffle result.
2302 APInt DemandedLHS, DemandedRHS;
2303 if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
2304 Known.resetAll();
2305 return;
2306 }
2307 Known.setAllConflict();
2308 if (!!DemandedLHS) {
2309 const Value *LHS = Shuf->getOperand(i_nocapture: 0);
2310 computeKnownBits(V: LHS, DemandedElts: DemandedLHS, Known, Q, Depth: Depth + 1);
2311 // If we don't know any bits, early out.
2312 if (Known.isUnknown())
2313 break;
2314 }
2315 if (!!DemandedRHS) {
2316 const Value *RHS = Shuf->getOperand(i_nocapture: 1);
2317 computeKnownBits(V: RHS, DemandedElts: DemandedRHS, Known&: Known2, Q, Depth: Depth + 1);
2318 Known = Known.intersectWith(RHS: Known2);
2319 }
2320 break;
2321 }
2322 case Instruction::InsertElement: {
2323 if (isa<ScalableVectorType>(Val: I->getType())) {
2324 Known.resetAll();
2325 return;
2326 }
2327 const Value *Vec = I->getOperand(i: 0);
2328 const Value *Elt = I->getOperand(i: 1);
2329 auto *CIdx = dyn_cast<ConstantInt>(Val: I->getOperand(i: 2));
2330 unsigned NumElts = DemandedElts.getBitWidth();
2331 APInt DemandedVecElts = DemandedElts;
2332 bool NeedsElt = true;
    // If we know the index we are inserting to, clear it from Vec check.
2334 if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
2335 DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
2336 NeedsElt = DemandedElts[CIdx->getZExtValue()];
2337 }
2338
2339 Known.setAllConflict();
2340 if (NeedsElt) {
2341 computeKnownBits(V: Elt, Known, Q, Depth: Depth + 1);
2342 // If we don't know any bits, early out.
2343 if (Known.isUnknown())
2344 break;
2345 }
2346
2347 if (!DemandedVecElts.isZero()) {
2348 computeKnownBits(V: Vec, DemandedElts: DemandedVecElts, Known&: Known2, Q, Depth: Depth + 1);
2349 Known = Known.intersectWith(RHS: Known2);
2350 }
2351 break;
2352 }
2353 case Instruction::ExtractElement: {
2354 // Look through extract element. If the index is non-constant or
2355 // out-of-range demand all elements, otherwise just the extracted element.
2356 const Value *Vec = I->getOperand(i: 0);
2357 const Value *Idx = I->getOperand(i: 1);
2358 auto *CIdx = dyn_cast<ConstantInt>(Val: Idx);
2359 if (isa<ScalableVectorType>(Val: Vec->getType())) {
2360 // FIXME: there's probably *something* we can do with scalable vectors
2361 Known.resetAll();
2362 break;
2363 }
2364 unsigned NumElts = cast<FixedVectorType>(Val: Vec->getType())->getNumElements();
2365 APInt DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
2366 if (CIdx && CIdx->getValue().ult(RHS: NumElts))
2367 DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
2368 computeKnownBits(V: Vec, DemandedElts: DemandedVecElts, Known, Q, Depth: Depth + 1);
2369 break;
2370 }
2371 case Instruction::ExtractValue:
2372 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I->getOperand(i: 0))) {
2373 const ExtractValueInst *EVI = cast<ExtractValueInst>(Val: I);
2374 if (EVI->getNumIndices() != 1) break;
2375 if (EVI->getIndices()[0] == 0) {
2376 switch (II->getIntrinsicID()) {
2377 default: break;
2378 case Intrinsic::uadd_with_overflow:
2379 case Intrinsic::sadd_with_overflow:
2380 computeKnownBitsAddSub(
2381 Add: true, Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), /*NSW=*/false,
2382 /* NUW=*/false, DemandedElts, KnownOut&: Known, Known2, Q, Depth);
2383 break;
2384 case Intrinsic::usub_with_overflow:
2385 case Intrinsic::ssub_with_overflow:
2386 computeKnownBitsAddSub(
2387 Add: false, Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), /*NSW=*/false,
2388 /* NUW=*/false, DemandedElts, KnownOut&: Known, Known2, Q, Depth);
2389 break;
2390 case Intrinsic::umul_with_overflow:
2391 case Intrinsic::smul_with_overflow:
2392 computeKnownBitsMul(Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), NSW: false,
2393 NUW: false, DemandedElts, Known, Known2, Q, Depth);
2394 break;
2395 }
2396 }
2397 }
2398 break;
2399 case Instruction::Freeze:
2400 if (isGuaranteedNotToBePoison(V: I->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
2401 Depth: Depth + 1))
2402 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2403 break;
2404 }
2405}
2406
2407/// Determine which bits of V are known to be either zero or one and return
2408/// them.
2409KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
2410 const SimplifyQuery &Q, unsigned Depth) {
2411 KnownBits Known(getBitWidth(Ty: V->getType(), DL: Q.DL));
2412 ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
2413 return Known;
2414}
2415
2416/// Determine which bits of V are known to be either zero or one and return
2417/// them.
2418KnownBits llvm::computeKnownBits(const Value *V, const SimplifyQuery &Q,
2419 unsigned Depth) {
2420 KnownBits Known(getBitWidth(Ty: V->getType(), DL: Q.DL));
2421 computeKnownBits(V, Known, Q, Depth);
2422 return Known;
2423}
2424
2425/// Determine which bits of V are known to be either zero or one and return
2426/// them in the Known bit set.
2427///
2428/// NOTE: we cannot consider 'undef' to be "IsZero" here. The problem is that
2429/// we cannot optimize based on the assumption that it is zero without changing
2430/// it to be an explicit zero. If we don't change it to zero, other code could
/// be optimized based on the contradictory assumption that it is non-zero.
2432/// Because instcombine aggressively folds operations with undef args anyway,
2433/// this won't lose us code quality.
2434///
2435/// This function is defined on values with integer type, values with pointer
/// type, and vectors of integers. In the case where V is a vector, the known
/// zero and known one values are the same width as the vector element, and a
/// bit is set only if it is true for all of the demanded elements in the
/// vector specified by DemandedElts.
void computeKnownBits(const Value *V, const APInt &DemandedElts,
                      KnownBits &Known, const SimplifyQuery &Q,
                      unsigned Depth) {
  if (!DemandedElts) {
    // No demanded elts, better to assume we don't know anything.
    Known.resetAll();
    return;
  }

  assert(V && "No Value?");
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

#ifndef NDEBUG
  // Debug-only sanity checks: the caller must supply a Known of exactly V's
  // scalar bit width and a DemandedElts mask matching V's element count.
  Type *Ty = V->getType();
  unsigned BitWidth = Known.getBitWidth();

  assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
         "Not integer or pointer type!");

  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
    assert(
        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
        "DemandedElt width should equal the fixed vector number of elements");
  } else {
    assert(DemandedElts == APInt(1, 1) &&
           "DemandedElt width should be 1 for scalars or scalable vectors");
  }

  Type *ScalarTy = Ty->getScalarType();
  if (ScalarTy->isPointerTy()) {
    assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
           "V and Known should have same BitWidth");
  } else {
    assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
           "V and Known should have same BitWidth");
  }
#endif

  const APInt *C;
  if (match(V, P: m_APInt(Res&: C))) {
    // We know all of the bits for a scalar constant or a splat vector constant!
    Known = KnownBits::makeConstant(C: *C);
    return;
  }
  // Null and aggregate-zero are all-zeros.
  if (isa<ConstantPointerNull>(Val: V) || isa<ConstantAggregateZero>(Val: V)) {
    Known.setAllZero();
    return;
  }
  // Handle a constant vector by taking the intersection of the known bits of
  // each element.
  if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(Val: V)) {
    assert(!isa<ScalableVectorType>(V->getType()));
    // We know that CDV must be a vector of integers. Take the intersection of
    // each element.
    Known.setAllConflict();
    for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
      if (!DemandedElts[i])
        continue;
      // A bit stays known-zero/known-one only if it has that value in every
      // demanded element.
      APInt Elt = CDV->getElementAsAPInt(i);
      Known.Zero &= ~Elt;
      Known.One &= Elt;
    }
    if (Known.hasConflict())
      Known.resetAll();
    return;
  }

  if (const auto *CV = dyn_cast<ConstantVector>(Val: V)) {
    assert(!isa<ScalableVectorType>(V->getType()));
    // We know that CV must be a vector of integers. Take the intersection of
    // each element.
    Known.setAllConflict();
    for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
      if (!DemandedElts[i])
        continue;
      Constant *Element = CV->getAggregateElement(Elt: i);
      if (isa<PoisonValue>(Val: Element))
        continue;
      auto *ElementCI = dyn_cast_or_null<ConstantInt>(Val: Element);
      if (!ElementCI) {
        // Non-integer element (e.g. constant expression): give up entirely.
        Known.resetAll();
        return;
      }
      const APInt &Elt = ElementCI->getValue();
      Known.Zero &= ~Elt;
      Known.One &= Elt;
    }
    if (Known.hasConflict())
      Known.resetAll();
    return;
  }

  // Start out not knowing anything.
  Known.resetAll();

  // We can't imply anything about undefs.
  if (isa<UndefValue>(Val: V))
    return;

  // There's no point in looking through other users of ConstantData for
  // assumptions. Confirm that we've handled them all.
  assert(!isa<ConstantData>(V) && "Unhandled constant data!");

  // A function argument's range attribute seeds the known bits.
  if (const auto *A = dyn_cast<Argument>(Val: V))
    if (std::optional<ConstantRange> Range = A->getRange())
      Known = Range->toKnownBits();

  // All recursive calls that increase depth must come after this.
  if (Depth == MaxAnalysisRecursionDepth)
    return;

  // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
  // the bits of its aliasee.
  if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(Val: V)) {
    if (!GA->isInterposable())
      computeKnownBits(V: GA->getAliasee(), Known, Q, Depth: Depth + 1);
    return;
  }

  if (const Operator *I = dyn_cast<Operator>(Val: V))
    computeKnownBitsFromOperator(I, DemandedElts, Known, Q, Depth);
  else if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: V)) {
    if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange())
      Known = CR->toKnownBits();
  }

  // Aligned pointers have trailing zeros - refine Known.Zero set
  if (isa<PointerType>(Val: V->getType())) {
    Align Alignment = V->getPointerAlignment(DL: Q.DL);
    Known.Zero.setLowBits(Log2(A: Alignment));
  }

  // computeKnownBitsFromContext strictly refines Known.
  // Therefore, we run them after computeKnownBitsFromOperator.

  // Check whether we can determine known bits from context such as assumes.
  computeKnownBitsFromContext(V, Known, Q, Depth);
}
2579
2580/// Try to detect a recurrence that the value of the induction variable is
2581/// always a power of two (or zero).
static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
                                   SimplifyQuery &Q, unsigned Depth) {
  // NOTE: Q is taken by mutable reference; Q.CxtI is retargeted below so that
  // recursive queries are evaluated at the point where each operand is
  // available. Callers are expected to pass a disposable copy.
  BinaryOperator *BO = nullptr;
  Value *Start = nullptr, *Step = nullptr;
  if (!matchSimpleRecurrence(P: PN, BO, Start, Step))
    return false;

  // Initial value must be a power of two.
  for (const Use &U : PN->operands()) {
    if (U.get() == Start) {
      // Initial value comes from a different BB, need to adjust context
      // instruction for analysis.
      Q.CxtI = PN->getIncomingBlock(U)->getTerminator();
      if (!isKnownToBeAPowerOfTwo(V: Start, OrZero, Q, Depth))
        return false;
    }
  }

  // Except for Mul, the induction variable must be on the left side of the
  // increment expression, otherwise its value can be arbitrary.
  if (BO->getOpcode() != Instruction::Mul && BO->getOperand(i_nocapture: 1) != Step)
    return false;

  // Evaluate the step in the block that computes it.
  Q.CxtI = BO->getParent()->getTerminator();
  switch (BO->getOpcode()) {
  case Instruction::Mul:
    // Power of two is closed under multiplication.
    return (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: BO) ||
            Q.IIQ.hasNoSignedWrap(Op: BO)) &&
           isKnownToBeAPowerOfTwo(V: Step, OrZero, Q, Depth);
  case Instruction::SDiv:
    // Start value must not be signmask for signed division, so simply being a
    // power of two is not sufficient, and it has to be a constant.
    if (!match(V: Start, P: m_Power2()) || match(V: Start, P: m_SignMask()))
      return false;
    [[fallthrough]];
  case Instruction::UDiv:
    // Divisor must be a power of two.
    // If OrZero is false, cannot guarantee induction variable is non-zero after
    // division, same for Shr, unless it is exact division.
    return (OrZero || Q.IIQ.isExact(Op: BO)) &&
           isKnownToBeAPowerOfTwo(V: Step, OrZero: false, Q, Depth);
  case Instruction::Shl:
    // Shifting a power of two left keeps it a power of two unless the set
    // bit is shifted out, which the no-wrap flags rule out.
    return OrZero || Q.IIQ.hasNoUnsignedWrap(Op: BO) || Q.IIQ.hasNoSignedWrap(Op: BO);
  case Instruction::AShr:
    // Same signmask restriction as SDiv: the start must be a constant power
    // of two that is not the sign bit.
    if (!match(V: Start, P: m_Power2()) || match(V: Start, P: m_SignMask()))
      return false;
    [[fallthrough]];
  case Instruction::LShr:
    return OrZero || Q.IIQ.isExact(Op: BO);
  default:
    return false;
  }
}
2636
2637/// Return true if we can infer that \p V is known to be a power of 2 from
2638/// dominating condition \p Cond (e.g., ctpop(V) == 1).
2639static bool isImpliedToBeAPowerOfTwoFromCond(const Value *V, bool OrZero,
2640 const Value *Cond,
2641 bool CondIsTrue) {
2642 CmpPredicate Pred;
2643 const APInt *RHSC;
2644 if (!match(V: Cond, P: m_ICmp(Pred, L: m_Intrinsic<Intrinsic::ctpop>(Op0: m_Specific(V)),
2645 R: m_APInt(Res&: RHSC))))
2646 return false;
2647 if (!CondIsTrue)
2648 Pred = ICmpInst::getInversePredicate(pred: Pred);
2649 // ctpop(V) u< 2
2650 if (OrZero && Pred == ICmpInst::ICMP_ULT && *RHSC == 2)
2651 return true;
2652 // ctpop(V) == 1
2653 return Pred == ICmpInst::ICMP_EQ && *RHSC == 1;
2654}
2655
2656/// Return true if the given value is known to have exactly one
2657/// bit set when defined. For vectors return true if every element is known to
2658/// be a power of two when defined. Supports values with integer or pointer
2659/// types and vectors of integers.
bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero,
                                  const SimplifyQuery &Q, unsigned Depth) {
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

  // Constants (including splat vector constants) are checked directly.
  if (isa<Constant>(Val: V))
    return OrZero ? match(V, P: m_Power2OrZero()) : match(V, P: m_Power2());

  // i1 is by definition a power of 2 or zero.
  if (OrZero && V->getType()->getScalarSizeInBits() == 1)
    return true;

  // Try to infer from assumptions.
  if (Q.AC && Q.CxtI) {
    for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
      if (!AssumeVH)
        continue;
      CallInst *I = cast<CallInst>(Val&: AssumeVH);
      if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond: I->getArgOperand(i: 0),
                                           /*CondIsTrue=*/true) &&
          isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
        return true;
    }
  }

  // Handle dominating conditions.
  if (Q.DC && Q.CxtI && Q.DT) {
    for (CondBrInst *BI : Q.DC->conditionsFor(V)) {
      Value *Cond = BI->getCondition();

      // True edge: the condition holds as written.
      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
      if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
                                           /*CondIsTrue=*/true) &&
          Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
        return true;

      // False edge: the negated condition holds.
      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
      if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
                                           /*CondIsTrue=*/false) &&
          Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
        return true;
    }
  }

  auto *I = dyn_cast<Instruction>(Val: V);
  if (!I)
    return false;

  if (Q.CxtI && match(V, P: m_VScale())) {
    const Function *F = Q.CxtI->getFunction();
    // The vscale_range indicates vscale is a power-of-two.
    return F->hasFnAttribute(Kind: Attribute::VScaleRange);
  }

  // 1 << X is clearly a power of two if the one is not shifted off the end. If
  // it is shifted off the end then the result is undefined.
  if (match(V: I, P: m_Shl(L: m_One(), R: m_Value())))
    return true;

  // (signmask) >>l X is clearly a power of two if the one is not shifted off
  // the bottom. If it is shifted off the bottom then the result is undefined.
  if (match(V: I, P: m_LShr(L: m_SignMask(), R: m_Value())))
    return true;

  // The remaining tests are all recursive, so bail out if we hit the limit.
  if (Depth++ == MaxAnalysisRecursionDepth)
    return false;

  switch (I->getOpcode()) {
  case Instruction::ZExt:
    // Zero extension preserves the single set bit.
    return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
  case Instruction::Trunc:
    // Truncation may drop the only set bit, leaving zero; only valid for
    // OrZero queries.
    return OrZero && isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
  case Instruction::Shl:
    // Shifting can only move (or drop, which the flags rule out) the bit.
    if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: I) || Q.IIQ.hasNoSignedWrap(Op: I))
      return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
    return false;
  case Instruction::LShr:
    if (OrZero || Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)))
      return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
    return false;
  case Instruction::UDiv:
    // Exact division of a power of two yields a power of two.
    if (Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)))
      return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
    return false;
  case Instruction::Mul:
    // The product of two powers of two is a power of two (if non-zero /
    // non-wrapping, which isKnownNonZero establishes for the !OrZero case).
    return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth) &&
           isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth) &&
           (OrZero || isKnownNonZero(V: I, Q, Depth));
  case Instruction::And:
    // A power of two and'd with anything is a power of two or zero.
    if (OrZero &&
        (isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), /*OrZero*/ true, Q, Depth) ||
         isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), /*OrZero*/ true, Q, Depth)))
      return true;
    // X & (-X) is always a power of two or zero.
    if (match(V: I->getOperand(i: 0), P: m_Neg(V: m_Specific(V: I->getOperand(i: 1)))) ||
        match(V: I->getOperand(i: 1), P: m_Neg(V: m_Specific(V: I->getOperand(i: 0)))))
      return OrZero || isKnownNonZero(V: I->getOperand(i: 0), Q, Depth);
    return false;
  case Instruction::Add: {
    // Adding a power-of-two or zero to the same power-of-two or zero yields
    // either the original power-of-two, a larger power-of-two or zero.
    const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(Val: V);
    if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: VOBO) ||
        Q.IIQ.hasNoSignedWrap(Op: VOBO)) {
      // (X & Y) + Y with Y a power of two: X & Y is either Y or 0, so the
      // sum is 2*Y or Y -- still a power of two (or zero).
      if (match(V: I->getOperand(i: 0),
                P: m_c_And(L: m_Specific(V: I->getOperand(i: 1)), R: m_Value())) &&
          isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth))
        return true;
      if (match(V: I->getOperand(i: 1),
                P: m_c_And(L: m_Specific(V: I->getOperand(i: 0)), R: m_Value())) &&
          isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth))
        return true;

      unsigned BitWidth = V->getType()->getScalarSizeInBits();
      KnownBits LHSBits(BitWidth);
      computeKnownBits(V: I->getOperand(i: 0), Known&: LHSBits, Q, Depth);

      KnownBits RHSBits(BitWidth);
      computeKnownBits(V: I->getOperand(i: 1), Known&: RHSBits, Q, Depth);
      // If i8 V is a power of two or zero:
      // ZeroBits: 1 1 1 0 1 1 1 1
      // ~ZeroBits: 0 0 0 1 0 0 0 0
      if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2())
        // If OrZero isn't set, we cannot give back a zero result.
        // Make sure either the LHS or RHS has a bit set.
        if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue())
          return true;
    }

    // LShr(UINT_MAX, Y) + 1 is a power of two (if add is nuw) or zero.
    if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: VOBO))
      if (match(V: I, P: m_Add(L: m_LShr(L: m_AllOnes(), R: m_Value()), R: m_One())))
        return true;
    return false;
  }
  case Instruction::Select:
    // Both select arms must satisfy the query.
    return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth) &&
           isKnownToBeAPowerOfTwo(V: I->getOperand(i: 2), OrZero, Q, Depth);
  case Instruction::PHI: {
    // A PHI node is power of two if all incoming values are power of two, or if
    // it is an induction variable where in each step its value is a power of
    // two.
    auto *PN = cast<PHINode>(Val: I);
    SimplifyQuery RecQ = Q.getWithoutCondContext();

    // Check if it is an induction variable and always power of two.
    if (isPowerOfTwoRecurrence(PN, OrZero, Q&: RecQ, Depth))
      return true;

    // Recursively check all incoming values. Limit recursion to 2 levels, so
    // that search complexity is limited to number of operands^2.
    unsigned NewDepth = std::max(a: Depth, b: MaxAnalysisRecursionDepth - 1);
    return llvm::all_of(Range: PN->operands(), P: [&](const Use &U) {
      // Value is power of 2 if it is coming from PHI node itself by induction.
      if (U.get() == PN)
        return true;

      // Change the context instruction to the incoming block where it is
      // evaluated.
      RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
      return isKnownToBeAPowerOfTwo(V: U.get(), OrZero, Q: RecQ, Depth: NewDepth);
    });
  }
  case Instruction::Invoke:
  case Instruction::Call: {
    if (auto *II = dyn_cast<IntrinsicInst>(Val: I)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::umax:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::smin:
        // min/max select one of the two operands, so require both.
        return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 1), OrZero, Q, Depth) &&
               isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
      // bswap/bitreverse just move around bits, but don't change any 1s/0s
      // thus don't change pow2/non-pow2 status.
      case Intrinsic::bitreverse:
      case Intrinsic::bswap:
        return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
      case Intrinsic::fshr:
      case Intrinsic::fshl:
        // If Op0 == Op1, this is a rotate. is_pow2(rotate(x, y)) == is_pow2(x)
        if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1))
          return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
        break;
      default:
        break;
      }
    }
    return false;
  }
  default:
    return false;
  }
}
2855
2856/// Test whether a GEP's result is known to be non-null.
2857///
2858/// Uses properties inherent in a GEP to try to determine whether it is known
2859/// to be non-null.
2860///
2861/// Currently this routine does not support vector GEPs.
static bool isGEPKnownNonNull(const GEPOperator *GEP, const SimplifyQuery &Q,
                              unsigned Depth) {
  const Function *F = nullptr;
  if (const Instruction *I = dyn_cast<Instruction>(Val: GEP))
    F = I->getFunction();

  // If the gep is nuw or inbounds with invalid null pointer, then the GEP
  // may be null iff the base pointer is null and the offset is zero.
  // (That is: the reasoning below is only justified when the GEP is nuw, or
  // it is inbounds in an address space where null is not a valid address.)
  if (!GEP->hasNoUnsignedWrap() &&
      !(GEP->isInBounds() &&
        !NullPointerIsDefined(F, AS: GEP->getPointerAddressSpace())))
    return false;

  // FIXME: Support vector-GEPs.
  assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP");

  // If the base pointer is non-null, we cannot walk to a null address with an
  // inbounds GEP in address space zero.
  if (isKnownNonZero(V: GEP->getPointerOperand(), Q, Depth))
    return true;

  // Walk the GEP operands and see if any operand introduces a non-zero offset.
  // If so, then the GEP cannot produce a null pointer, as doing so would
  // inherently violate the inbounds contract within address space zero.
  for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
       GTI != GTE; ++GTI) {
    // Struct types are easy -- they must always be indexed by a constant.
    if (StructType *STy = GTI.getStructTypeOrNull()) {
      ConstantInt *OpC = cast<ConstantInt>(Val: GTI.getOperand());
      unsigned ElementIdx = OpC->getZExtValue();
      const StructLayout *SL = Q.DL.getStructLayout(Ty: STy);
      uint64_t ElementOffset = SL->getElementOffset(Idx: ElementIdx);
      if (ElementOffset > 0)
        return true;
      continue;
    }

    // If we have a zero-sized type, the index doesn't matter. Keep looping.
    if (GTI.getSequentialElementStride(DL: Q.DL).isZero())
      continue;

    // Fast path the constant operand case both for efficiency and so we don't
    // increment Depth when just zipping down an all-constant GEP.
    if (ConstantInt *OpC = dyn_cast<ConstantInt>(Val: GTI.getOperand())) {
      if (!OpC->isZero())
        return true;
      continue;
    }

    // We post-increment Depth here because while isKnownNonZero increments it
    // as well, when we pop back up that increment won't persist. We don't want
    // to recurse 10k times just because we have 10k GEP operands. We don't
    // bail completely out because we want to handle constant GEPs regardless
    // of depth.
    if (Depth++ >= MaxAnalysisRecursionDepth)
      continue;

    // A non-constant index that is known non-zero implies a non-zero offset.
    if (isKnownNonZero(V: GTI.getOperand(), Q, Depth))
      return true;
  }

  return false;
}
2925
/// Return true if some use of \p V that dominates \p CtxI proves \p V is
/// non-null: a nonnull call argument, a load/store address, a div/rem
/// divisor, or a dominating branch on a "V != null"-style comparison.
static bool isKnownNonNullFromDominatingCondition(const Value *V,
                                                  const Instruction *CtxI,
                                                  const DominatorTree *DT) {
  assert(!isa<Constant>(V) && "Called for constant?");

  if (!CtxI || !DT)
    return false;

  unsigned NumUsesExplored = 0;
  for (auto &U : V->uses()) {
    // Avoid massive lists
    if (NumUsesExplored >= DomConditionsMaxUses)
      break;
    NumUsesExplored++;

    const Instruction *UI = cast<Instruction>(Val: U.getUser());
    // If the value is used as an argument to a call or invoke, then argument
    // attributes may provide an answer about null-ness.
    if (V->getType()->isPointerTy()) {
      if (const auto *CB = dyn_cast<CallBase>(Val: UI)) {
        if (CB->isArgOperand(U: &U) &&
            CB->paramHasNonNullAttr(ArgNo: CB->getArgOperandNo(U: &U),
                                    /*AllowUndefOrPoison=*/false) &&
            DT->dominates(Def: CB, User: CtxI))
          return true;
      }
    }

    // If the value is used as a load/store, then the pointer must be non null.
    if (V == getLoadStorePointerOperand(V: UI)) {
      if (!NullPointerIsDefined(F: UI->getFunction(),
                                AS: V->getType()->getPointerAddressSpace()) &&
          DT->dominates(Def: UI, User: CtxI))
        return true;
    }

    // If V is the divisor of an integer div/rem that is valid at CtxI, V must
    // be non-zero (division by zero is immediate UB).
    if ((match(V: UI, P: m_IDiv(L: m_Value(), R: m_Specific(V))) ||
         match(V: UI, P: m_IRem(L: m_Value(), R: m_Specific(V)))) &&
        isValidAssumeForContext(Inv: UI, CxtI: CtxI, DT))
      return true;

    // Consider only compare instructions uniquely controlling a branch
    Value *RHS;
    CmpPredicate Pred;
    if (!match(V: UI, P: m_c_ICmp(Pred, L: m_Specific(V), R: m_Value(V&: RHS))))
      continue;

    // Determine which branch direction of the compare implies V != 0.
    bool NonNullIfTrue;
    if (cmpExcludesZero(Pred, RHS))
      NonNullIfTrue = true;
    else if (cmpExcludesZero(Pred: CmpInst::getInversePredicate(pred: Pred), RHS))
      NonNullIfTrue = false;
    else
      continue;

    // Walk the users of the compare, looking through logical ANDs, until a
    // conditional branch (or guard) consuming it is found.
    SmallVector<const User *, 4> WorkList;
    SmallPtrSet<const User *, 4> Visited;
    for (const auto *CmpU : UI->users()) {
      assert(WorkList.empty() && "Should be!");
      if (Visited.insert(Ptr: CmpU).second)
        WorkList.push_back(Elt: CmpU);

      while (!WorkList.empty()) {
        auto *Curr = WorkList.pop_back_val();

        // If a user is an AND, add all its users to the work list. We only
        // propagate "pred != null" condition through AND because it is only
        // correct to assume that all conditions of AND are met in true branch.
        // TODO: Support similar logic of OR and EQ predicate?
        if (NonNullIfTrue)
          if (match(V: Curr, P: m_LogicalAnd(L: m_Value(), R: m_Value()))) {
            for (const auto *CurrU : Curr->users())
              if (Visited.insert(Ptr: CurrU).second)
                WorkList.push_back(Elt: CurrU);
            continue;
          }

        if (const CondBrInst *BI = dyn_cast<CondBrInst>(Val: Curr)) {
          // V is non-null on the edge implied by NonNullIfTrue; check that
          // this edge dominates the context instruction.
          BasicBlock *NonNullSuccessor =
              BI->getSuccessor(i: NonNullIfTrue ? 0 : 1);
          BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor);
          if (DT->dominates(BBE: Edge, BB: CtxI->getParent()))
            return true;
        } else if (NonNullIfTrue && isGuard(U: Curr) &&
                   DT->dominates(Def: cast<Instruction>(Val: Curr), User: CtxI)) {
          return true;
        }
      }
    }
  }

  return false;
}
3019
/// Does the 'Ranges' metadata (which must be a valid MD_range operand list)
/// ensure that the value it's attached to can never be equal to Value?
3023static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) {
3024 const unsigned NumRanges = Ranges->getNumOperands() / 2;
3025 assert(NumRanges >= 1);
3026 for (unsigned i = 0; i < NumRanges; ++i) {
3027 ConstantInt *Lower =
3028 mdconst::extract<ConstantInt>(MD: Ranges->getOperand(I: 2 * i + 0));
3029 ConstantInt *Upper =
3030 mdconst::extract<ConstantInt>(MD: Ranges->getOperand(I: 2 * i + 1));
3031 ConstantRange Range(Lower->getValue(), Upper->getValue());
3032 if (Range.contains(Val: Value))
3033 return false;
3034 }
3035 return true;
3036}
3037
3038/// Try to detect a recurrence that monotonically increases/decreases from a
3039/// non-zero starting value. These are common as induction variables.
3040static bool isNonZeroRecurrence(const PHINode *PN) {
3041 BinaryOperator *BO = nullptr;
3042 Value *Start = nullptr, *Step = nullptr;
3043 const APInt *StartC, *StepC;
3044 if (!matchSimpleRecurrence(P: PN, BO, Start, Step) ||
3045 !match(V: Start, P: m_APInt(Res&: StartC)) || StartC->isZero())
3046 return false;
3047
3048 switch (BO->getOpcode()) {
3049 case Instruction::Add:
3050 // Starting from non-zero and stepping away from zero can never wrap back
3051 // to zero.
3052 return BO->hasNoUnsignedWrap() ||
3053 (BO->hasNoSignedWrap() && match(V: Step, P: m_APInt(Res&: StepC)) &&
3054 StartC->isNegative() == StepC->isNegative());
3055 case Instruction::Mul:
3056 return (BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) &&
3057 match(V: Step, P: m_APInt(Res&: StepC)) && !StepC->isZero();
3058 case Instruction::Shl:
3059 return BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap();
3060 case Instruction::AShr:
3061 case Instruction::LShr:
3062 return BO->isExact();
3063 default:
3064 return false;
3065 }
3066}
3067
3068static bool matchOpWithOpEqZero(Value *Op0, Value *Op1) {
3069 return match(V: Op0, P: m_ZExtOrSExt(Op: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ,
3070 L: m_Specific(V: Op1), R: m_Zero()))) ||
3071 match(V: Op1, P: m_ZExtOrSExt(Op: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ,
3072 L: m_Specific(V: Op0), R: m_Zero())));
3073}
3074
/// Return true if (X + Y) is known to be non-zero. \p NSW / \p NUW are the
/// no-signed-wrap / no-unsigned-wrap flags of the add.
static bool isNonZeroAdd(const APInt &DemandedElts, const SimplifyQuery &Q,
                         unsigned BitWidth, Value *X, Value *Y, bool NSW,
                         bool NUW, unsigned Depth) {
  // (X + (X != 0)) is non zero
  if (matchOpWithOpEqZero(Op0: X, Op1: Y))
    return true;

  // With nuw there is no wrap-around, so the sum is zero only if both
  // operands are zero.
  if (NUW)
    return isKnownNonZero(V: Y, DemandedElts, Q, Depth) ||
           isKnownNonZero(V: X, DemandedElts, Q, Depth);

  KnownBits XKnown = computeKnownBits(V: X, DemandedElts, Q, Depth);
  KnownBits YKnown = computeKnownBits(V: Y, DemandedElts, Q, Depth);

  // If X and Y are both non-negative (as signed values) then their sum is not
  // zero unless both X and Y are zero.
  if (XKnown.isNonNegative() && YKnown.isNonNegative())
    if (isKnownNonZero(V: Y, DemandedElts, Q, Depth) ||
        isKnownNonZero(V: X, DemandedElts, Q, Depth))
      return true;

  // If X and Y are both negative (as signed values) then their sum is not
  // zero unless both X and Y equal INT_MIN.
  if (XKnown.isNegative() && YKnown.isNegative()) {
    APInt Mask = APInt::getSignedMaxValue(numBits: BitWidth);
    // The sign bit of X is set. If some other bit is set then X is not equal
    // to INT_MIN.
    if (XKnown.One.intersects(RHS: Mask))
      return true;
    // The sign bit of Y is set. If some other bit is set then Y is not equal
    // to INT_MIN.
    if (YKnown.One.intersects(RHS: Mask))
      return true;
  }

  // The sum of a non-negative number and a power of two is not zero.
  if (XKnown.isNonNegative() &&
      isKnownToBeAPowerOfTwo(V: Y, /*OrZero*/ false, Q, Depth))
    return true;
  if (YKnown.isNonNegative() &&
      isKnownToBeAPowerOfTwo(V: X, /*OrZero*/ false, Q, Depth))
    return true;

  // Fall back to a plain bitwise evaluation of the add.
  return KnownBits::add(LHS: XKnown, RHS: YKnown, NSW, NUW).isNonZero();
}
3120
3121static bool isNonZeroSub(const APInt &DemandedElts, const SimplifyQuery &Q,
3122 unsigned BitWidth, Value *X, Value *Y,
3123 unsigned Depth) {
3124 // (X - (X != 0)) is non zero
3125 // ((X != 0) - X) is non zero
3126 if (matchOpWithOpEqZero(Op0: X, Op1: Y))
3127 return true;
3128
3129 // TODO: Move this case into isKnownNonEqual().
3130 if (auto *C = dyn_cast<Constant>(Val: X))
3131 if (C->isNullValue() && isKnownNonZero(V: Y, DemandedElts, Q, Depth))
3132 return true;
3133
3134 return ::isKnownNonEqual(V1: X, V2: Y, DemandedElts, Q, Depth);
3135}
3136
/// Return true if (X * Y) is known to be non-zero. \p NSW / \p NUW are the
/// no-signed-wrap / no-unsigned-wrap flags of the multiply.
static bool isNonZeroMul(const APInt &DemandedElts, const SimplifyQuery &Q,
                         unsigned BitWidth, Value *X, Value *Y, bool NSW,
                         bool NUW, unsigned Depth) {
  // If X and Y are non-zero then so is X * Y as long as the multiplication
  // does not overflow.
  if (NSW || NUW)
    return isKnownNonZero(V: X, DemandedElts, Q, Depth) &&
           isKnownNonZero(V: Y, DemandedElts, Q, Depth);

  // If either X or Y is odd, then if the other is non-zero the result can't
  // be zero.
  KnownBits XKnown = computeKnownBits(V: X, DemandedElts, Q, Depth);
  if (XKnown.One[0])
    return isKnownNonZero(V: Y, DemandedElts, Q, Depth);

  KnownBits YKnown = computeKnownBits(V: Y, DemandedElts, Q, Depth);
  if (YKnown.One[0])
    // Cheap check on the already-computed XKnown first.
    return XKnown.isNonZero() || isKnownNonZero(V: X, DemandedElts, Q, Depth);

  // If there exists any subset of X (sX) and subset of Y (sY) s.t sX * sY is
  // non-zero, then X * Y is non-zero. We can find sX and sY by just taking
  // the lowest known One of X and Y. If they are non-zero, the result
  // must be non-zero. We can check if LSB(X) * LSB(Y) != 0 by doing
  // X.CountLeadingZeros + Y.CountLeadingZeros < BitWidth.
  return (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros()) <
         BitWidth;
}
3164
/// Return true if the shift \p I (shl/lshr/ashr) is known to produce a
/// non-zero value, given \p KnownVal, the known bits of the shifted operand.
static bool isNonZeroShift(const Operator *I, const APInt &DemandedElts,
                           const SimplifyQuery &Q, const KnownBits &KnownVal,
                           unsigned Depth) {
  // Apply the same shift I performs to Lhs by Rhs bits.
  auto ShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
    switch (I->getOpcode()) {
    case Instruction::Shl:
      return Lhs.shl(ShiftAmt: Rhs);
    case Instruction::LShr:
      return Lhs.lshr(ShiftAmt: Rhs);
    case Instruction::AShr:
      return Lhs.ashr(ShiftAmt: Rhs);
    default:
      llvm_unreachable("Unknown Shift Opcode");
    }
  };

  // Apply the shift in the opposite direction of I.
  auto InvShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
    switch (I->getOpcode()) {
    case Instruction::Shl:
      return Lhs.lshr(ShiftAmt: Rhs);
    case Instruction::LShr:
    case Instruction::AShr:
      return Lhs.shl(ShiftAmt: Rhs);
    default:
      llvm_unreachable("Unknown Shift Opcode");
    }
  };

  if (KnownVal.isUnknown())
    return false;

  KnownBits KnownCnt =
      computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Q, Depth);
  APInt MaxShift = KnownCnt.getMaxValue();
  unsigned NumBits = KnownVal.getBitWidth();
  // Can't reason further if the shift amount may reach the bit width.
  if (MaxShift.uge(RHS: NumBits))
    return false;

  // If a known-one bit survives even the largest possible shift, the result
  // is certainly non-zero.
  if (!ShiftOp(KnownVal.One, MaxShift).isZero())
    return true;

  // If all of the bits shifted out are known to be zero, and Val is known
  // non-zero then at least one non-zero bit must remain.
  if (InvShiftOp(KnownVal.Zero, NumBits - MaxShift)
          .eq(RHS: InvShiftOp(APInt::getAllOnes(numBits: NumBits), NumBits - MaxShift)) &&
      isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth))
    return true;

  return false;
}
3215
/// Dispatch on the opcode of \p I to decide whether its result is known to be
/// non-zero for the demanded vector elements. Each case either proves
/// non-zero directly, defers to an opcode-specific helper, or breaks out to
/// the generic computeKnownBits fallback at the bottom.
static bool isKnownNonZeroFromOperator(const Operator *I,
                                       const APInt &DemandedElts,
                                       const SimplifyQuery &Q, unsigned Depth) {
  unsigned BitWidth = getBitWidth(I->getType()->getScalarType(), Q.DL);
  switch (I->getOpcode()) {
  case Instruction::Alloca:
    // Alloca never returns null, malloc might.
    return I->getType()->getPointerAddressSpace() == 0;
  case Instruction::GetElementPtr:
    if (I->getType()->isPointerTy())
      return isGEPKnownNonNull(cast<GEPOperator>(I), Q, Depth);
    break;
  case Instruction::BitCast: {
    // We need to be a bit careful here. We can only peek through the bitcast
    // if the scalar size of elements in the operand are smaller than and a
    // multiple of the size they are casting to. Take three cases:
    //
    // 1) Unsafe:
    //        bitcast <2 x i16> %NonZero to <4 x i8>
    //
    //    %NonZero can have 2 non-zero i16 elements, but isKnownNonZero on a
    //    <4 x i8> requires that all 4 i8 elements be non-zero which isn't
    //    guaranteed (imagine just sign bit set in the 2 i16 elements).
    //
    // 2) Unsafe:
    //        bitcast <4 x i3> %NonZero to <3 x i4>
    //
    //    Even though the scalar size of the src (`i3`) is smaller than the
    //    scalar size of the dst `i4`, because `i3` is not a multiple of `i4`
    //    it's possible for the `3 x i4` elements to be zero because there are
    //    some elements in the destination that don't contain any full src
    //    element.
    //
    // 3) Safe:
    //        bitcast <4 x i8> %NonZero to <2 x i16>
    //
    //    This is always safe as non-zero in the 4 i8 elements implies
    //    non-zero in the combination of any two adjacent ones. Since i8 is a
    //    multiple of i16, each i16 is guaranteed to have 2 full i8 elements.
    //    This all implies the 2 i16 elements are non-zero.
    Type *FromTy = I->getOperand(0)->getType();
    if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
        (BitWidth % getBitWidth(FromTy->getScalarType(), Q.DL)) == 0)
      return isKnownNonZero(I->getOperand(0), Q, Depth);
  } break;
  case Instruction::IntToPtr:
    // Note that we have to take special care to avoid looking through
    // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well
    // as casts that can alter the value, e.g., AddrSpaceCasts.
    if (!isa<ScalableVectorType>(I->getType()) &&
        Q.DL.getTypeSizeInBits(I->getOperand(0)->getType()).getFixedValue() <=
            Q.DL.getTypeSizeInBits(I->getType()).getFixedValue())
      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
    break;
  case Instruction::PtrToAddr:
    // isKnownNonZero() for pointers refers to the address bits being non-zero,
    // so we can directly forward.
    return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
  case Instruction::PtrToInt:
    // For ptrtoint, make sure the result size is >= the address size. If the
    // address is non-zero, any larger value is also non-zero.
    if (Q.DL.getAddressSizeInBits(I->getOperand(0)->getType()) <=
        I->getType()->getScalarSizeInBits())
      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
    break;
  case Instruction::Trunc:
    // nuw/nsw trunc preserves zero/non-zero status of input.
    if (auto *TI = dyn_cast<TruncInst>(I))
      if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
        return isKnownNonZero(TI->getOperand(0), DemandedElts, Q, Depth);
    break;

  // Iff x - y != 0, then x ^ y != 0
  // Therefore we can do the same exact checks
  case Instruction::Xor:
  case Instruction::Sub:
    return isNonZeroSub(DemandedElts, Q, BitWidth, I->getOperand(0),
                        I->getOperand(1), Depth);
  case Instruction::Or:
    // (X | (X != 0)) is non zero
    if (matchOpWithOpEqZero(I->getOperand(0), I->getOperand(1)))
      return true;
    // X | Y != 0 if X != Y.
    if (isKnownNonEqual(I->getOperand(0), I->getOperand(1), DemandedElts, Q,
                        Depth))
      return true;
    // X | Y != 0 if X != 0 or Y != 0.
    return isKnownNonZero(I->getOperand(1), DemandedElts, Q, Depth) ||
           isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);
  case Instruction::SExt:
  case Instruction::ZExt:
    // ext X != 0 if X != 0.
    return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);

  case Instruction::Shl: {
    // shl nsw/nuw can't remove any non-zero bits.
    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
    if (Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO))
      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);

    // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined
    // if the lowest bit is shifted off the end.
    KnownBits Known(BitWidth);
    computeKnownBits(I->getOperand(0), DemandedElts, Known, Q, Depth);
    if (Known.One[0])
      return true;

    return isNonZeroShift(I, DemandedElts, Q, Known, Depth);
  }
  case Instruction::LShr:
  case Instruction::AShr: {
    // shr exact can only shift out zero bits.
    const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(I);
    if (BO->isExact())
      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);

    // shr X, Y != 0 if X is negative. Note that the value of the shift is not
    // defined if the sign bit is shifted off the end.
    KnownBits Known =
        computeKnownBits(I->getOperand(0), DemandedElts, Q, Depth);
    if (Known.isNegative())
      return true;

    return isNonZeroShift(I, DemandedElts, Q, Known, Depth);
  }
  case Instruction::UDiv:
  case Instruction::SDiv: {
    // X / Y
    // div exact can only produce a zero if the dividend is zero.
    if (cast<PossiblyExactOperator>(I)->isExact())
      return isKnownNonZero(I->getOperand(0), DemandedElts, Q, Depth);

    KnownBits XKnown =
        computeKnownBits(I->getOperand(0), DemandedElts, Q, Depth);
    // If X is fully unknown we won't be able to figure anything out so don't
    // bother computing knownbits for Y.
    if (XKnown.isUnknown())
      return false;

    KnownBits YKnown =
        computeKnownBits(I->getOperand(1), DemandedElts, Q, Depth);
    if (I->getOpcode() == Instruction::SDiv) {
      // For signed division need to compare abs value of the operands.
      XKnown = XKnown.abs(/*IntMinIsPoison*/ false);
      YKnown = YKnown.abs(/*IntMinIsPoison*/ false);
    }
    // If X u>= Y then div is non zero (0/0 is UB).
    std::optional<bool> XUgeY = KnownBits::uge(XKnown, YKnown);
    // If X is total unknown or X u< Y we won't be able to prove non-zero
    // with compute known bits so just return early.
    return XUgeY && *XUgeY;
  }
  case Instruction::Add: {
    // X + Y.

    // If Add has nuw wrap flag, then if either X or Y is non-zero the result is
    // non-zero.
    auto *BO = cast<OverflowingBinaryOperator>(I);
    return isNonZeroAdd(DemandedElts, Q, BitWidth, I->getOperand(0),
                        I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
                        Q.IIQ.hasNoUnsignedWrap(BO), Depth);
  }
  case Instruction::Mul: {
    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(I);
    return isNonZeroMul(DemandedElts, Q, BitWidth, I->getOperand(0),
                        I->getOperand(1), Q.IIQ.hasNoSignedWrap(BO),
                        Q.IIQ.hasNoUnsignedWrap(BO), Depth);
  }
  case Instruction::Select: {
    // (C ? X : Y) != 0 if X != 0 and Y != 0.

    // First check if the arm is non-zero using `isKnownNonZero`. If that fails,
    // then see if the select condition implies the arm is non-zero. For example
    // (X != 0 ? X : Y), we know the true arm is non-zero as the `X` "return" is
    // dominated by `X != 0`.
    auto SelectArmIsNonZero = [&](bool IsTrueArm) {
      Value *Op;
      Op = IsTrueArm ? I->getOperand(1) : I->getOperand(2);
      // Op is trivially non-zero.
      if (isKnownNonZero(Op, DemandedElts, Q, Depth))
        return true;

      // The condition of the select dominates the true/false arm. Check if the
      // condition implies that a given arm is non-zero.
      Value *X;
      CmpPredicate Pred;
      if (!match(I->getOperand(0), m_c_ICmp(Pred, m_Specific(Op), m_Value(X))))
        return false;

      // For the false arm, the condition is false, so the implied predicate
      // is the inverse.
      if (!IsTrueArm)
        Pred = ICmpInst::getInversePredicate(Pred);

      return cmpExcludesZero(Pred, X);
    };

    if (SelectArmIsNonZero(/* IsTrueArm */ true) &&
        SelectArmIsNonZero(/* IsTrueArm */ false))
      return true;
    break;
  }
  case Instruction::PHI: {
    auto *PN = cast<PHINode>(I);
    if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN))
      return true;

    // Check if all incoming values are non-zero using recursion.
    // NewDepth is clamped near the limit so each incoming value is only
    // explored one more level deep.
    SimplifyQuery RecQ = Q.getWithoutCondContext();
    unsigned NewDepth = std::max(Depth, MaxAnalysisRecursionDepth - 1);
    return llvm::all_of(PN->operands(), [&](const Use &U) {
      // A self-referencing incoming value cannot make the PHI zero on its own.
      if (U.get() == PN)
        return true;
      RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
      // Check if the branch on the phi excludes zero.
      CmpPredicate Pred;
      Value *X;
      BasicBlock *TrueSucc, *FalseSucc;
      if (match(RecQ.CxtI,
                m_Br(m_c_ICmp(Pred, m_Specific(U.get()), m_Value(X)),
                     m_BasicBlock(TrueSucc), m_BasicBlock(FalseSucc)))) {
        // Check for cases of duplicate successors.
        if ((TrueSucc == PN->getParent()) != (FalseSucc == PN->getParent())) {
          // If we're using the false successor, invert the predicate.
          if (FalseSucc == PN->getParent())
            Pred = CmpInst::getInversePredicate(Pred);
          if (cmpExcludesZero(Pred, X))
            return true;
        }
      }
      // Finally recurse on the edge and check it directly.
      return isKnownNonZero(U.get(), DemandedElts, RecQ, NewDepth);
    });
  }
  case Instruction::InsertElement: {
    // Demanded-element reasoning below requires a fixed element count.
    if (isa<ScalableVectorType>(I->getType()))
      break;

    const Value *Vec = I->getOperand(0);
    const Value *Elt = I->getOperand(1);
    auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));

    unsigned NumElts = DemandedElts.getBitWidth();
    APInt DemandedVecElts = DemandedElts;
    bool SkipElt = false;
    // If we know the index we are inserting to, clear it from Vec check.
    if (CIdx && CIdx->getValue().ult(NumElts)) {
      DemandedVecElts.clearBit(CIdx->getZExtValue());
      SkipElt = !DemandedElts[CIdx->getZExtValue()];
    }

    // Result is non-zero if Elt is non-zero and the rest of the demanded elts
    // in Vec are non-zero.
    return (SkipElt || isKnownNonZero(Elt, Q, Depth)) &&
           (DemandedVecElts.isZero() ||
            isKnownNonZero(Vec, DemandedVecElts, Q, Depth));
  }
  case Instruction::ExtractElement:
    if (const auto *EEI = dyn_cast<ExtractElementInst>(I)) {
      const Value *Vec = EEI->getVectorOperand();
      const Value *Idx = EEI->getIndexOperand();
      auto *CIdx = dyn_cast<ConstantInt>(Idx);
      if (auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType())) {
        unsigned NumElts = VecTy->getNumElements();
        // With a constant in-range index, only that lane matters; otherwise
        // all lanes must be non-zero.
        APInt DemandedVecElts = APInt::getAllOnes(NumElts);
        if (CIdx && CIdx->getValue().ult(NumElts))
          DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
        return isKnownNonZero(Vec, DemandedVecElts, Q, Depth);
      }
    }
    break;
  case Instruction::ShuffleVector: {
    auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
    if (!Shuf)
      break;
    APInt DemandedLHS, DemandedRHS;
    // For undef elements, we don't know anything about the common state of
    // the shuffle result.
    if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
      break;
    // If demanded elements for both vecs are non-zero, the shuffle is non-zero.
    return (DemandedRHS.isZero() ||
            isKnownNonZero(Shuf->getOperand(1), DemandedRHS, Q, Depth)) &&
           (DemandedLHS.isZero() ||
            isKnownNonZero(Shuf->getOperand(0), DemandedLHS, Q, Depth));
  }
  case Instruction::Freeze:
    // freeze(X) == X when X is not poison, so non-zero is preserved.
    return isKnownNonZero(I->getOperand(0), Q, Depth) &&
           isGuaranteedNotToBePoison(I->getOperand(0), Q.AC, Q.CxtI, Q.DT,
                                     Depth);
  case Instruction::Load: {
    auto *LI = cast<LoadInst>(I);
    // A Load tagged with nonnull or dereferenceable with null pointer undefined
    // is never null.
    if (auto *PtrT = dyn_cast<PointerType>(I->getType())) {
      if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull) ||
          (Q.IIQ.getMetadata(LI, LLVMContext::MD_dereferenceable) &&
           !NullPointerIsDefined(LI->getFunction(), PtrT->getAddressSpace())))
        return true;
    } else if (MDNode *Ranges = Q.IIQ.getMetadata(LI, LLVMContext::MD_range)) {
      return rangeMetadataExcludesValue(Ranges, APInt::getZero(BitWidth));
    }

    // No need to fall through to computeKnownBits as range metadata is already
    // handled in isKnownNonZero.
    return false;
  }
  case Instruction::ExtractValue: {
    // The value result of a with.overflow intrinsic is the wrapped arithmetic
    // result, so defer to the corresponding add/sub/mul helpers (without
    // wrap flags, since the operation may overflow).
    const WithOverflowInst *WO;
    if (match(I, m_ExtractValue<0>(m_WithOverflowInst(WO)))) {
      switch (WO->getBinaryOp()) {
      default:
        break;
      case Instruction::Add:
        return isNonZeroAdd(DemandedElts, Q, BitWidth, WO->getArgOperand(0),
                            WO->getArgOperand(1),
                            /*NSW=*/false,
                            /*NUW=*/false, Depth);
      case Instruction::Sub:
        return isNonZeroSub(DemandedElts, Q, BitWidth, WO->getArgOperand(0),
                            WO->getArgOperand(1), Depth);
      case Instruction::Mul:
        return isNonZeroMul(DemandedElts, Q, BitWidth, WO->getArgOperand(0),
                            WO->getArgOperand(1),
                            /*NSW=*/false, /*NUW=*/false, Depth);
        break;
      }
    }
    break;
  }
  case Instruction::Call:
  case Instruction::Invoke: {
    const auto *Call = cast<CallBase>(I);
    if (I->getType()->isPointerTy()) {
      if (Call->isReturnNonNull())
        return true;
      if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, true))
        return isKnownNonZero(RP, Q, Depth);
    } else {
      if (MDNode *Ranges = Q.IIQ.getMetadata(Call, LLVMContext::MD_range))
        return rangeMetadataExcludesValue(Ranges, APInt::getZero(BitWidth));
      if (std::optional<ConstantRange> Range = Call->getRange()) {
        const APInt ZeroValue(Range->getBitWidth(), 0);
        if (!Range->contains(ZeroValue))
          return true;
      }
      if (const Value *RV = Call->getReturnedArgOperand())
        if (RV->getType() == I->getType() && isKnownNonZero(RV, Q, Depth))
          return true;
    }

    if (auto *II = dyn_cast<IntrinsicInst>(I)) {
      switch (II->getIntrinsicID()) {
      // These intrinsics are zero exactly when their (first) operand is zero.
      case Intrinsic::sshl_sat:
      case Intrinsic::ushl_sat:
      case Intrinsic::abs:
      case Intrinsic::bitreverse:
      case Intrinsic::bswap:
      case Intrinsic::ctpop:
        return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
        // NB: We don't do usub_sat here as in any case we can prove its
        // non-zero, we will fold it to `sub nuw` in InstCombine.
      case Intrinsic::ssub_sat:
        // For most types, if x != y then ssub.sat x, y != 0. But
        // ssub.sat.i1 0, -1 = 0, because 1 saturates to 0. This means
        // isNonZeroSub will do the wrong thing for ssub.sat.i1.
        if (BitWidth == 1)
          return false;
        return isNonZeroSub(DemandedElts, Q, BitWidth, II->getArgOperand(0),
                            II->getArgOperand(1), Depth);
      case Intrinsic::sadd_sat:
        return isNonZeroAdd(DemandedElts, Q, BitWidth, II->getArgOperand(0),
                            II->getArgOperand(1),
                            /*NSW=*/true, /* NUW=*/false, Depth);
        // Vec reverse preserves zero/non-zero status from input vec.
      case Intrinsic::vector_reverse:
        return isKnownNonZero(II->getArgOperand(0), DemandedElts.reverseBits(),
                              Q, Depth);
        // umax/umin/smax/smin/or of all non-zero elements is always non-zero.
      case Intrinsic::vector_reduce_or:
      case Intrinsic::vector_reduce_umax:
      case Intrinsic::vector_reduce_umin:
      case Intrinsic::vector_reduce_smax:
      case Intrinsic::vector_reduce_smin:
        return isKnownNonZero(II->getArgOperand(0), Q, Depth);
      case Intrinsic::umax:
      case Intrinsic::uadd_sat:
        // umax(X, (X != 0)) is non zero
        // X +usat (X != 0) is non zero
        if (matchOpWithOpEqZero(II->getArgOperand(0), II->getArgOperand(1)))
          return true;

        return isKnownNonZero(II->getArgOperand(1), DemandedElts, Q, Depth) ||
               isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
      case Intrinsic::smax: {
        // If either arg is strictly positive the result is non-zero. Otherwise
        // the result is non-zero if both ops are non-zero.
        auto IsNonZero = [&](Value *Op, std::optional<bool> &OpNonZero,
                             const KnownBits &OpKnown) {
          if (!OpNonZero.has_value())
            OpNonZero = OpKnown.isNonZero() ||
                        isKnownNonZero(Op, DemandedElts, Q, Depth);
          return *OpNonZero;
        };
        // Avoid re-computing isKnownNonZero.
        std::optional<bool> Op0NonZero, Op1NonZero;
        KnownBits Op1Known =
            computeKnownBits(II->getArgOperand(1), DemandedElts, Q, Depth);
        if (Op1Known.isNonNegative() &&
            IsNonZero(II->getArgOperand(1), Op1NonZero, Op1Known))
          return true;
        KnownBits Op0Known =
            computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth);
        if (Op0Known.isNonNegative() &&
            IsNonZero(II->getArgOperand(0), Op0NonZero, Op0Known))
          return true;
        return IsNonZero(II->getArgOperand(1), Op1NonZero, Op1Known) &&
               IsNonZero(II->getArgOperand(0), Op0NonZero, Op0Known);
      }
      case Intrinsic::smin: {
        // If either arg is negative the result is non-zero. Otherwise
        // the result is non-zero if both ops are non-zero.
        KnownBits Op1Known =
            computeKnownBits(II->getArgOperand(1), DemandedElts, Q, Depth);
        if (Op1Known.isNegative())
          return true;
        KnownBits Op0Known =
            computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth);
        if (Op0Known.isNegative())
          return true;

        if (Op1Known.isNonZero() && Op0Known.isNonZero())
          return true;
      }
        [[fallthrough]];
      case Intrinsic::umin:
        return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth) &&
               isKnownNonZero(II->getArgOperand(1), DemandedElts, Q, Depth);
      case Intrinsic::cttz:
        // cttz(X) != 0 iff the low bit of X is known zero.
        return computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth)
            .Zero[0];
      case Intrinsic::ctlz:
        // ctlz(X) != 0 iff the sign bit of X is known zero.
        return computeKnownBits(II->getArgOperand(0), DemandedElts, Q, Depth)
            .isNonNegative();
      case Intrinsic::fshr:
      case Intrinsic::fshl:
        // If Op0 == Op1, this is a rotate. rotate(x, y) != 0 iff x != 0.
        if (II->getArgOperand(0) == II->getArgOperand(1))
          return isKnownNonZero(II->getArgOperand(0), DemandedElts, Q, Depth);
        break;
      case Intrinsic::vscale:
        return true;
      case Intrinsic::experimental_get_vector_length:
        return isKnownNonZero(I->getOperand(0), Q, Depth);
      default:
        break;
      }
      break;
    }

    return false;
  }
  }

  // No opcode-specific argument applied; fall back to the known-bits of I.
  KnownBits Known(BitWidth);
  computeKnownBits(I, DemandedElts, Known, Q, Depth);
  return Known.One != 0;
}
3682
3683/// Return true if the given value is known to be non-zero when defined. For
3684/// vectors, return true if every demanded element is known to be non-zero when
3685/// defined. For pointers, if the context instruction and dominator tree are
3686/// specified, perform context-sensitive analysis and return true if the
3687/// pointer couldn't possibly be null at the specified instruction.
3688/// Supports values with integer or pointer type and vectors of integers.
bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
                    const SimplifyQuery &Q, unsigned Depth) {
  Type *Ty = V->getType();

#ifndef NDEBUG
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
    assert(
        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
        "DemandedElt width should equal the fixed vector number of elements");
  } else {
    assert(DemandedElts == APInt(1, 1) &&
           "DemandedElt width should be 1 for scalars");
  }
#endif

  // Handle constants first; most cases can be decided without recursion.
  if (auto *C = dyn_cast<Constant>(V)) {
    if (C->isNullValue())
      return false;
    if (isa<ConstantInt>(C))
      // Must be non-zero due to null test above.
      return true;

    // For constant vectors, check that all elements are poison or known
    // non-zero to determine that the whole vector is known non-zero.
    if (auto *VecTy = dyn_cast<FixedVectorType>(Ty)) {
      for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
        if (!DemandedElts[i])
          continue;
        Constant *Elt = C->getAggregateElement(i);
        if (!Elt || Elt->isNullValue())
          return false;
        if (!isa<PoisonValue>(Elt) && !isa<ConstantInt>(Elt))
          return false;
      }
      return true;
    }

    // Constant ptrauth can be null, iff the base pointer can be.
    if (auto *CPA = dyn_cast<ConstantPtrAuth>(V))
      return isKnownNonZero(CPA->getPointer(), DemandedElts, Q, Depth);

    // A global variable in address space 0 is non null unless extern weak
    // or an absolute symbol reference. Other address spaces may have null as a
    // valid address for a global, so we can't assume anything.
    if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
      if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() &&
          GV->getType()->getAddressSpace() == 0)
        return true;
    }

    // For constant expressions, fall through to the Operator code below.
    if (!isa<ConstantExpr>(V))
      return false;
  }

  // An argument whose range attribute excludes zero is non-zero.
  if (const auto *A = dyn_cast<Argument>(V))
    if (std::optional<ConstantRange> Range = A->getRange()) {
      const APInt ZeroValue(Range->getBitWidth(), 0);
      if (!Range->contains(ZeroValue))
        return true;
    }

  if (!isa<Constant>(V) && isKnownNonZeroFromAssume(V, Q))
    return true;

  // Some of the tests below are recursive, so bail out if we hit the limit.
  if (Depth++ >= MaxAnalysisRecursionDepth)
    return false;

  // Check for pointer simplifications.

  if (PointerType *PtrTy = dyn_cast<PointerType>(Ty)) {
    // A byval, inalloca may not be null in a non-default address space. A
    // nonnull argument is assumed never 0.
    if (const Argument *A = dyn_cast<Argument>(V)) {
      if (((A->hasPassPointeeByValueCopyAttr() &&
            !NullPointerIsDefined(A->getParent(), PtrTy->getAddressSpace())) ||
           A->hasNonNullAttr()))
        return true;
    }
  }

  // Opcode-specific reasoning.
  if (const auto *I = dyn_cast<Operator>(V))
    if (isKnownNonZeroFromOperator(I, DemandedElts, Q, Depth))
      return true;

  if (!isa<Constant>(V) &&
      isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT))
    return true;

  // Peel off a dominating null test around V, if any, and retry.
  if (const Value *Stripped = stripNullTest(V))
    return isKnownNonZero(Stripped, DemandedElts, Q, Depth);

  return false;
}
3786
3787bool llvm::isKnownNonZero(const Value *V, const SimplifyQuery &Q,
3788 unsigned Depth) {
3789 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
3790 APInt DemandedElts =
3791 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
3792 return ::isKnownNonZero(V, DemandedElts, Q, Depth);
3793}
3794
/// If the pair of operators are the same invertible function, return
/// the operands of the function corresponding to each input. Otherwise,
/// return std::nullopt. An invertible function is one that is 1-to-1 and maps
/// every input value to exactly one output value. This is equivalent to
/// saying that Op1 and Op2 are equal exactly when the specified pair of
/// operands are equal, (except that Op1 and Op2 may be poison more often.)
static std::optional<std::pair<Value*, Value*>>
getInvertibleOperands(const Operator *Op1,
                      const Operator *Op2) {
  if (Op1->getOpcode() != Op2->getOpcode())
    return std::nullopt;

  // Pair up the operand at position OpNum from each operator.
  auto getOperands = [&](unsigned OpNum) -> auto {
    return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum));
  };

  switch (Op1->getOpcode()) {
  default:
    break;
  case Instruction::Or:
    // `or` is only invertible (behaves like `add`/`xor`) when it is disjoint.
    if (!cast<PossiblyDisjointInst>(Op1)->isDisjoint() ||
        !cast<PossiblyDisjointInst>(Op2)->isDisjoint())
      break;
    [[fallthrough]];
  case Instruction::Xor:
  case Instruction::Add: {
    // With one operand shared (in either position), the op is invertible in
    // the remaining operand.
    Value *Other;
    if (match(Op2, m_c_BinOp(m_Specific(Op1->getOperand(0)), m_Value(Other))))
      return std::make_pair(Op1->getOperand(1), Other);
    if (match(Op2, m_c_BinOp(m_Specific(Op1->getOperand(1)), m_Value(Other))))
      return std::make_pair(Op1->getOperand(0), Other);
    break;
  }
  case Instruction::Sub:
    if (Op1->getOperand(0) == Op2->getOperand(0))
      return getOperands(1);
    if (Op1->getOperand(1) == Op2->getOperand(1))
      return getOperands(0);
    break;
  case Instruction::Mul: {
    // invertible if A * B == (A * B) mod 2^N where A, and B are integers
    // and N is the bitwidth. The nsw case is non-obvious, but proven by
    // alive2: https://alive2.llvm.org/ce/z/Z6D5qK
    auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
    auto *OBO2 = cast<OverflowingBinaryOperator>(Op2);
    if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
        (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
      break;

    // Assume operand order has been canonicalized
    if (Op1->getOperand(1) == Op2->getOperand(1) &&
        isa<ConstantInt>(Op1->getOperand(1)) &&
        !cast<ConstantInt>(Op1->getOperand(1))->isZero())
      return getOperands(0);
    break;
  }
  case Instruction::Shl: {
    // Same as multiplies, with the difference that we don't need to check
    // for a non-zero multiply. Shifts always multiply by non-zero.
    auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
    auto *OBO2 = cast<OverflowingBinaryOperator>(Op2);
    if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
        (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
      break;

    if (Op1->getOperand(1) == Op2->getOperand(1))
      return getOperands(0);
    break;
  }
  case Instruction::AShr:
  case Instruction::LShr: {
    // Exact shifts shift out only zero bits, so they are invertible in the
    // shifted value when the shift amounts agree.
    auto *PEO1 = cast<PossiblyExactOperator>(Op1);
    auto *PEO2 = cast<PossiblyExactOperator>(Op2);
    if (!PEO1->isExact() || !PEO2->isExact())
      break;

    if (Op1->getOperand(1) == Op2->getOperand(1))
      return getOperands(0);
    break;
  }
  case Instruction::SExt:
  case Instruction::ZExt:
    // Extensions are 1-to-1 as long as both extend from the same type.
    if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType())
      return getOperands(0);
    break;
  case Instruction::PHI: {
    const PHINode *PN1 = cast<PHINode>(Op1);
    const PHINode *PN2 = cast<PHINode>(Op2);

    // If PN1 and PN2 are both recurrences, can we prove the entire recurrences
    // are a single invertible function of the start values? Note that repeated
    // application of an invertible function is also invertible
    BinaryOperator *BO1 = nullptr;
    Value *Start1 = nullptr, *Step1 = nullptr;
    BinaryOperator *BO2 = nullptr;
    Value *Start2 = nullptr, *Step2 = nullptr;
    if (PN1->getParent() != PN2->getParent() ||
        !matchSimpleRecurrence(PN1, BO1, Start1, Step1) ||
        !matchSimpleRecurrence(PN2, BO2, Start2, Step2))
      break;

    auto Values = getInvertibleOperands(cast<Operator>(BO1),
                                        cast<Operator>(BO2));
    if (!Values)
      break;

    // We have to be careful of mutually defined recurrences here. Ex:
    // * X_i = X_(i-1) OP Y_(i-1), and Y_i = X_(i-1) OP V
    // * X_i = Y_i = X_(i-1) OP Y_(i-1)
    // The invertibility of these is complicated, and not worth reasoning
    // about (yet?).
    if (Values->first != PN1 || Values->second != PN2)
      break;

    return std::make_pair(Start1, Start2);
  }
  }
  return std::nullopt;
}
3914
3915/// Return true if V1 == (binop V2, X), where X is known non-zero.
3916/// Only handle a small subset of binops where (binop V2, X) with non-zero X
3917/// implies V2 != V1.
3918static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
3919 const APInt &DemandedElts,
3920 const SimplifyQuery &Q, unsigned Depth) {
3921 const BinaryOperator *BO = dyn_cast<BinaryOperator>(Val: V1);
3922 if (!BO)
3923 return false;
3924 switch (BO->getOpcode()) {
3925 default:
3926 break;
3927 case Instruction::Or:
3928 if (!cast<PossiblyDisjointInst>(Val: V1)->isDisjoint())
3929 break;
3930 [[fallthrough]];
3931 case Instruction::Xor:
3932 case Instruction::Add:
3933 Value *Op = nullptr;
3934 if (V2 == BO->getOperand(i_nocapture: 0))
3935 Op = BO->getOperand(i_nocapture: 1);
3936 else if (V2 == BO->getOperand(i_nocapture: 1))
3937 Op = BO->getOperand(i_nocapture: 0);
3938 else
3939 return false;
3940 return isKnownNonZero(V: Op, DemandedElts, Q, Depth: Depth + 1);
3941 }
3942 return false;
3943}
3944
3945/// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
3946/// the multiplication is nuw or nsw.
3947static bool isNonEqualMul(const Value *V1, const Value *V2,
3948 const APInt &DemandedElts, const SimplifyQuery &Q,
3949 unsigned Depth) {
3950 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Val: V2)) {
3951 const APInt *C;
3952 return match(V: OBO, P: m_Mul(L: m_Specific(V: V1), R: m_APInt(Res&: C))) &&
3953 (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3954 !C->isZero() && !C->isOne() &&
3955 isKnownNonZero(V: V1, DemandedElts, Q, Depth: Depth + 1);
3956 }
3957 return false;
3958}
3959
3960/// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
3961/// the shift is nuw or nsw.
3962static bool isNonEqualShl(const Value *V1, const Value *V2,
3963 const APInt &DemandedElts, const SimplifyQuery &Q,
3964 unsigned Depth) {
3965 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Val: V2)) {
3966 const APInt *C;
3967 return match(V: OBO, P: m_Shl(L: m_Specific(V: V1), R: m_APInt(Res&: C))) &&
3968 (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3969 !C->isZero() && isKnownNonZero(V: V1, DemandedElts, Q, Depth: Depth + 1);
3970 }
3971 return false;
3972}
3973
/// Return true if two PHIs in the same block are known pairwise non-equal,
/// i.e. for every incoming edge the corresponding incoming values differ.
/// Distinct integer constants are handled cheaply; at most one incoming pair
/// may require a full recursive isKnownNonEqual query.
static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
                           const APInt &DemandedElts, const SimplifyQuery &Q,
                           unsigned Depth) {
  // Check two PHIs are in same block.
  if (PN1->getParent() != PN2->getParent())
    return false;

  SmallPtrSet<const BasicBlock *, 8> VisitedBBs;
  bool UsedFullRecursion = false;
  for (const BasicBlock *IncomBB : PN1->blocks()) {
    if (!VisitedBBs.insert(Ptr: IncomBB).second)
      continue; // Don't reprocess blocks that we have dealt with already.
    const Value *IV1 = PN1->getIncomingValueForBlock(BB: IncomBB);
    const Value *IV2 = PN2->getIncomingValueForBlock(BB: IncomBB);
    const APInt *C1, *C2;
    // Two distinct integer constants are trivially non-equal; no recursion
    // is needed for this edge.
    if (match(V: IV1, P: m_APInt(Res&: C1)) && match(V: IV2, P: m_APInt(Res&: C2)) && *C1 != *C2)
      continue;

    // Only one pair of phi operands is allowed for full recursion.
    if (UsedFullRecursion)
      return false;

    // Recurse with the incoming block's terminator as the context
    // instruction (and without the caller's condition context).
    SimplifyQuery RecQ = Q.getWithoutCondContext();
    RecQ.CxtI = IncomBB->getTerminator();
    if (!isKnownNonEqual(V1: IV1, V2: IV2, DemandedElts, Q: RecQ, Depth: Depth + 1))
      return false;
    UsedFullRecursion = true;
  }
  return true;
}
4004
4005static bool isNonEqualSelect(const Value *V1, const Value *V2,
4006 const APInt &DemandedElts, const SimplifyQuery &Q,
4007 unsigned Depth) {
4008 const SelectInst *SI1 = dyn_cast<SelectInst>(Val: V1);
4009 if (!SI1)
4010 return false;
4011
4012 if (const SelectInst *SI2 = dyn_cast<SelectInst>(Val: V2)) {
4013 const Value *Cond1 = SI1->getCondition();
4014 const Value *Cond2 = SI2->getCondition();
4015 if (Cond1 == Cond2)
4016 return isKnownNonEqual(V1: SI1->getTrueValue(), V2: SI2->getTrueValue(),
4017 DemandedElts, Q, Depth: Depth + 1) &&
4018 isKnownNonEqual(V1: SI1->getFalseValue(), V2: SI2->getFalseValue(),
4019 DemandedElts, Q, Depth: Depth + 1);
4020 }
4021 return isKnownNonEqual(V1: SI1->getTrueValue(), V2, DemandedElts, Q, Depth: Depth + 1) &&
4022 isKnownNonEqual(V1: SI1->getFalseValue(), V2, DemandedElts, Q, Depth: Depth + 1);
4023}
4024
// Check to see if A is both a GEP and is the incoming value for a PHI in the
// loop, and B is either a ptr or another GEP. If the PHI has 2 incoming values,
// one of them being the recursive GEP A and the other a ptr at same base and at
// the same/higher offset than B we are only incrementing the pointer further in
// loop if offset of recursive GEP is greater than 0.
static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
                                               const SimplifyQuery &Q) {
  if (!A->getType()->isPointerTy() || !B->getType()->isPointerTy())
    return false;

  // A must be a single-index GEP with a constant index.
  auto *GEPA = dyn_cast<GEPOperator>(Val: A);
  if (!GEPA || GEPA->getNumIndices() != 1 || !isa<Constant>(Val: GEPA->idx_begin()))
    return false;

  // Handle 2 incoming PHI values with one being a recursive GEP.
  auto *PN = dyn_cast<PHINode>(Val: GEPA->getPointerOperand());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  // Search for the recursive GEP as an incoming operand, and record that as
  // Step. The other incoming value is the loop-entry pointer, Start.
  Value *Start = nullptr;
  Value *Step = const_cast<Value *>(A);
  if (PN->getIncomingValue(i: 0) == Step)
    Start = PN->getIncomingValue(i: 1);
  else if (PN->getIncomingValue(i: 1) == Step)
    Start = PN->getIncomingValue(i: 0);
  else
    return false;

  // Other incoming node base should match the B base.
  // StartOffset >= OffsetB && StepOffset > 0?
  // StartOffset <= OffsetB && StepOffset < 0?
  // Is non-equal if above are true.
  // We use stripAndAccumulateInBoundsConstantOffsets to restrict the
  // optimisation to inbounds GEPs only.
  unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Ty: Start->getType());
  APInt StartOffset(IndexWidth, 0);
  Start = Start->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: StartOffset);
  APInt StepOffset(IndexWidth, 0);
  Step = Step->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: StepOffset);

  // Check if Base Pointer of Step matches the PHI.
  if (Step != PN)
    return false;
  // Strip B down to its base and constant offset, then compare against the
  // Start/Step facts established above.
  APInt OffsetB(IndexWidth, 0);
  B = B->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: OffsetB);
  return Start == B &&
         ((StartOffset.sge(RHS: OffsetB) && StepOffset.isStrictlyPositive()) ||
          (StartOffset.sle(RHS: OffsetB) && StepOffset.isNegative()));
}
4076
/// Return true if V1 != V2 can be inferred from the context of Q.CxtI:
/// either from a dominating branch condition on V1 or V2, or from a
/// llvm.assume call whose condition implies V1 != V2.
static bool isKnownNonEqualFromContext(const Value *V1, const Value *V2,
                                       const SimplifyQuery &Q, unsigned Depth) {
  if (!Q.CxtI)
    return false;

  // Try to infer NonEqual based on information from dominating conditions.
  if (Q.DC && Q.DT) {
    auto IsKnownNonEqualFromDominatingCondition = [&](const Value *V) {
      for (CondBrInst *BI : Q.DC->conditionsFor(V)) {
        Value *Cond = BI->getCondition();
        // On the true edge Cond holds; check whether Cond implies V1 != V2.
        BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
        if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()) &&
            isImpliedCondition(LHS: Cond, RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
                               /*LHSIsTrue=*/true, Depth)
                .value_or(u: false))
          return true;

        // On the false edge !Cond holds; check the negated implication.
        BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
        if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()) &&
            isImpliedCondition(LHS: Cond, RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
                               /*LHSIsTrue=*/false, Depth)
                .value_or(u: false))
          return true;
      }

      return false;
    };

    if (IsKnownNonEqualFromDominatingCondition(V1) ||
        IsKnownNonEqualFromDominatingCondition(V2))
      return true;
  }

  if (!Q.AC)
    return false;

  // Try to infer NonEqual based on information from assumptions.
  for (auto &AssumeVH : Q.AC->assumptionsFor(V: V1)) {
    if (!AssumeVH)
      continue;
    CallInst *I = cast<CallInst>(Val&: AssumeVH);

    assert(I->getFunction() == Q.CxtI->getFunction() &&
           "Got assumption for the wrong function!");
    assert(I->getIntrinsicID() == Intrinsic::assume &&
           "must be an assume intrinsic");

    // The assumed condition must imply V1 != V2, and the assume must be
    // valid at the query's context instruction.
    if (isImpliedCondition(LHS: I->getArgOperand(i: 0), RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
                           /*LHSIsTrue=*/true, Depth)
            .value_or(u: false) &&
        isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
      return true;
  }

  return false;
}
4133
/// Return true if it is known that V1 != V2.
static bool isKnownNonEqual(const Value *V1, const Value *V2,
                            const APInt &DemandedElts, const SimplifyQuery &Q,
                            unsigned Depth) {
  // Identical values can never be proven non-equal.
  if (V1 == V2)
    return false;
  if (V1->getType() != V2->getType())
    // We can't look through casts yet.
    return false;

  if (Depth >= MaxAnalysisRecursionDepth)
    return false;

  // See if we can recurse through (exactly one of) our operands. This
  // requires our operation be 1-to-1 and map every input value to exactly
  // one output value. Such an operation is invertible.
  auto *O1 = dyn_cast<Operator>(Val: V1);
  auto *O2 = dyn_cast<Operator>(Val: V2);
  if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
    if (auto Values = getInvertibleOperands(Op1: O1, Op2: O2))
      return isKnownNonEqual(V1: Values->first, V2: Values->second, DemandedElts, Q,
                             Depth: Depth + 1);

    if (const PHINode *PN1 = dyn_cast<PHINode>(Val: V1)) {
      const PHINode *PN2 = cast<PHINode>(Val: V2);
      // FIXME: This is missing a generalization to handle the case where one is
      // a PHI and another one isn't.
      if (isNonEqualPHIs(PN1, PN2, DemandedElts, Q, Depth))
        return true;
    };
  }

  // V2 = binop(V1, X) (or vice versa), where the binop (add/xor/disjoint or)
  // is known to change any non-zero X's input.
  if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Q, Depth) ||
      isModifyingBinopOfNonZero(V1: V2, V2: V1, DemandedElts, Q, Depth))
    return true;

  // V2 = V1 * C (non-wrapping, C not 0/1) with V1 non-zero, either direction.
  if (isNonEqualMul(V1, V2, DemandedElts, Q, Depth) ||
      isNonEqualMul(V1: V2, V2: V1, DemandedElts, Q, Depth))
    return true;

  // V2 = V1 << C (non-wrapping, C != 0) with V1 non-zero, either direction.
  if (isNonEqualShl(V1, V2, DemandedElts, Q, Depth) ||
      isNonEqualShl(V1: V2, V2: V1, DemandedElts, Q, Depth))
    return true;

  if (V1->getType()->isIntOrIntVectorTy()) {
    // Are any known bits in V1 contradictory to known bits in V2? If V1
    // has a known zero where V2 has a known one, they must not be equal.
    KnownBits Known1 = computeKnownBits(V: V1, DemandedElts, Q, Depth);
    if (!Known1.isUnknown()) {
      KnownBits Known2 = computeKnownBits(V: V2, DemandedElts, Q, Depth);
      if (Known1.Zero.intersects(RHS: Known2.One) ||
          Known2.Zero.intersects(RHS: Known1.One))
        return true;
    }
  }

  // A select differs from a value if both of its arms do.
  if (isNonEqualSelect(V1, V2, DemandedElts, Q, Depth) ||
      isNonEqualSelect(V1: V2, V2: V1, DemandedElts, Q, Depth))
    return true;

  if (isNonEqualPointersWithRecursiveGEP(A: V1, B: V2, Q) ||
      isNonEqualPointersWithRecursiveGEP(A: V2, B: V1, Q))
    return true;

  Value *A, *B;
  // PtrToInts are NonEqual if their Ptrs are NonEqual.
  // Check PtrToInt type matches the pointer size.
  if (match(V: V1, P: m_PtrToIntSameSize(DL: Q.DL, Op: m_Value(V&: A))) &&
      match(V: V2, P: m_PtrToIntSameSize(DL: Q.DL, Op: m_Value(V&: B))))
    return isKnownNonEqual(V1: A, V2: B, DemandedElts, Q, Depth: Depth + 1);

  // Finally, consult dominating conditions and assumptions.
  if (isKnownNonEqualFromContext(V1, V2, Q, Depth))
    return true;

  return false;
}
4210
4211/// For vector constants, loop over the elements and find the constant with the
4212/// minimum number of sign bits. Return 0 if the value is not a vector constant
4213/// or if any element was not analyzed; otherwise, return the count for the
4214/// element with the minimum number of sign bits.
4215static unsigned computeNumSignBitsVectorConstant(const Value *V,
4216 const APInt &DemandedElts,
4217 unsigned TyBits) {
4218 const auto *CV = dyn_cast<Constant>(Val: V);
4219 if (!CV || !isa<FixedVectorType>(Val: CV->getType()))
4220 return 0;
4221
4222 unsigned MinSignBits = TyBits;
4223 unsigned NumElts = cast<FixedVectorType>(Val: CV->getType())->getNumElements();
4224 for (unsigned i = 0; i != NumElts; ++i) {
4225 if (!DemandedElts[i])
4226 continue;
4227 // If we find a non-ConstantInt, bail out.
4228 auto *Elt = dyn_cast_or_null<ConstantInt>(Val: CV->getAggregateElement(Elt: i));
4229 if (!Elt)
4230 return 0;
4231
4232 MinSignBits = std::min(a: MinSignBits, b: Elt->getValue().getNumSignBits());
4233 }
4234
4235 return MinSignBits;
4236}
4237
4238static unsigned ComputeNumSignBitsImpl(const Value *V,
4239 const APInt &DemandedElts,
4240 const SimplifyQuery &Q, unsigned Depth);
4241
4242static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
4243 const SimplifyQuery &Q, unsigned Depth) {
4244 unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Q, Depth);
4245 assert(Result > 0 && "At least one sign bit needs to be present!");
4246 return Result;
4247}
4248
/// Return the number of times the sign bit of the register is replicated into
/// the other bits. We know that at least 1 bit is always equal to the sign bit
/// (itself), but other cases can give us information. For example, immediately
/// after an "ashr X, 2", we know that the top 3 bits are all equal to each
/// other, so we return 3. For vectors, return the number of sign bits for the
/// vector element with the minimum number of known sign bits of the demanded
/// elements in the vector specified by DemandedElts.
static unsigned ComputeNumSignBitsImpl(const Value *V,
                                       const APInt &DemandedElts,
                                       const SimplifyQuery &Q, unsigned Depth) {
  Type *Ty = V->getType();
#ifndef NDEBUG
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
    assert(
        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
        "DemandedElt width should equal the fixed vector number of elements");
  } else {
    assert(DemandedElts == APInt(1, 1) &&
           "DemandedElt width should be 1 for scalars");
  }
#endif

  // We return the minimum number of sign bits that are guaranteed to be present
  // in V, so for undef we have to conservatively return 1. We don't have the
  // same behavior for poison though -- that's a FIXME today.

  Type *ScalarTy = Ty->getScalarType();
  unsigned TyBits = ScalarTy->isPointerTy() ?
    Q.DL.getPointerTypeSizeInBits(ScalarTy) :
    Q.DL.getTypeSizeInBits(Ty: ScalarTy);

  unsigned Tmp, Tmp2;
  unsigned FirstAnswer = 1;

  // Note that ConstantInt is handled by the general computeKnownBits case
  // below.

  if (Depth == MaxAnalysisRecursionDepth)
    return 1;

  if (auto *U = dyn_cast<Operator>(Val: V)) {
    switch (Operator::getOpcode(V)) {
    default: break;
    case Instruction::BitCast: {
      Value *Src = U->getOperand(i: 0);
      Type *SrcTy = Src->getType();

      // Skip if the source type is not an integer or integer vector type
      // This ensures we only process integer-like types
      if (!SrcTy->isIntOrIntVectorTy())
        break;

      unsigned SrcBits = SrcTy->getScalarSizeInBits();

      // Bitcast 'large element' scalar/vector to 'small element' vector.
      if ((SrcBits % TyBits) != 0)
        break;

      // Only proceed if the destination type is a fixed-size vector
      if (isa<FixedVectorType>(Val: Ty)) {
        // Fast case - sign splat can be simply split across the small elements.
        // This works for both vector and scalar sources
        Tmp = ComputeNumSignBits(V: Src, Q, Depth: Depth + 1);
        if (Tmp == SrcBits)
          return TyBits;
      }
      break;
    }
    case Instruction::SExt:
      // Sign extension adds exactly (dest width - src width) copies of the
      // sign bit on top of whatever the source already had.
      Tmp = TyBits - U->getOperand(i: 0)->getType()->getScalarSizeInBits();
      return ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1) +
             Tmp;

    case Instruction::SDiv: {
      const APInt *Denominator;
      // sdiv X, C -> adds log(C) sign bits.
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: Denominator))) {

        // Ignore non-positive denominator.
        if (!Denominator->isStrictlyPositive())
          break;

        // Calculate the incoming numerator bits.
        unsigned NumBits =
            ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);

        // Add floor(log(C)) bits to the numerator bits.
        return std::min(a: TyBits, b: NumBits + Denominator->logBase2());
      }
      break;
    }

    case Instruction::SRem: {
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);

      const APInt *Denominator;
      // srem X, C -> we know that the result is within [-C+1,C) when C is a
      // positive constant. This let us put a lower bound on the number of sign
      // bits.
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: Denominator))) {

        // Ignore non-positive denominator.
        if (Denominator->isStrictlyPositive()) {
          // Calculate the leading sign bit constraints by examining the
          // denominator. Given that the denominator is positive, there are two
          // cases:
          //
          //  1. The numerator is positive. The result range is [0,C) and
          //     [0,C) u< (1 << ceilLogBase2(C)).
          //
          //  2. The numerator is negative. Then the result range is (-C,0] and
          //     integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
          //
          // Thus a lower bound on the number of sign bits is `TyBits -
          // ceilLogBase2(C)`.

          unsigned ResBits = TyBits - Denominator->ceilLogBase2();
          Tmp = std::max(a: Tmp, b: ResBits);
        }
      }
      return Tmp;
    }

    case Instruction::AShr: {
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      // ashr X, C -> adds C sign bits.  Vectors too.
      const APInt *ShAmt;
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: ShAmt))) {
        if (ShAmt->uge(RHS: TyBits))
          break; // Bad shift.
        unsigned ShAmtLimited = ShAmt->getZExtValue();
        Tmp += ShAmtLimited;
        if (Tmp > TyBits) Tmp = TyBits;
      }
      return Tmp;
    }
    case Instruction::Shl: {
      const APInt *ShAmt;
      Value *X = nullptr;
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: ShAmt))) {
        // shl destroys sign bits.
        if (ShAmt->uge(RHS: TyBits))
          break; // Bad shift.
        // We can look through a zext (more or less treating it as a sext) if
        // all extended bits are shifted out.
        if (match(V: U->getOperand(i: 0), P: m_ZExt(Op: m_Value(V&: X))) &&
            ShAmt->uge(RHS: TyBits - X->getType()->getScalarSizeInBits())) {
          Tmp = ComputeNumSignBits(V: X, DemandedElts, Q, Depth: Depth + 1);
          Tmp += TyBits - X->getType()->getScalarSizeInBits();
        } else
          Tmp =
              ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
        if (ShAmt->uge(RHS: Tmp))
          break; // Shifted all sign bits out.
        Tmp2 = ShAmt->getZExtValue();
        return Tmp - Tmp2;
      }
      break;
    }
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor: // NOT is handled here.
      // Logical binary ops preserve the number of sign bits at the worst.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp != 1) {
        Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
        FirstAnswer = std::min(a: Tmp, b: Tmp2);
        // We computed what we know about the sign bits as our first
        // answer. Now proceed to the generic code that uses
        // computeKnownBits, and pick whichever answer is better.
      }
      break;

    case Instruction::Select: {
      // If we have a clamp pattern, we know that the number of sign bits will
      // be the minimum of the clamp min/max range.
      const Value *X;
      const APInt *CLow, *CHigh;
      if (isSignedMinMaxClamp(Select: U, In&: X, CLow, CHigh))
        return std::min(a: CLow->getNumSignBits(), b: CHigh->getNumSignBits());

      // Otherwise the result has at least as many sign bits as the weaker
      // of the two select arms.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp == 1)
        break;
      Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 2), DemandedElts, Q, Depth: Depth + 1);
      return std::min(a: Tmp, b: Tmp2);
    }

    case Instruction::Add:
      // Add can have at most one carry bit.  Thus we know that the output
      // is, at worst, one more bit than the inputs.
      // NOTE(review): operand 0 is queried with the no-DemandedElts overload
      // here (unlike the Sub case below); presumably this is conservative for
      // vectors — confirm against the overload's semantics.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);
      if (Tmp == 1) break;

      // Special case decrementing a value (ADD X, -1):
      if (const auto *CRHS = dyn_cast<Constant>(Val: U->getOperand(i: 1)))
        if (CRHS->isAllOnesValue()) {
          KnownBits Known(TyBits);
          computeKnownBits(V: U->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);

          // If the input is known to be 0 or 1, the output is 0/-1, which is
          // all sign bits set.
          if ((Known.Zero | 1).isAllOnes())
            return TyBits;

          // If we are subtracting one from a positive number, there is no carry
          // out of the result.
          if (Known.isNonNegative())
            return Tmp;
        }

      Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp2 == 1)
        break;
      return std::min(a: Tmp, b: Tmp2) - 1;

    case Instruction::Sub:
      Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp2 == 1)
        break;

      // Handle NEG.
      if (const auto *CLHS = dyn_cast<Constant>(Val: U->getOperand(i: 0)))
        if (CLHS->isNullValue()) {
          KnownBits Known(TyBits);
          computeKnownBits(V: U->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
          // If the input is known to be 0 or 1, the output is 0/-1, which is
          // all sign bits set.
          if ((Known.Zero | 1).isAllOnes())
            return TyBits;

          // If the input is known to be positive (the sign bit is known clear),
          // the output of the NEG has the same number of sign bits as the
          // input.
          if (Known.isNonNegative())
            return Tmp2;

          // Otherwise, we treat this like a SUB.
        }

      // Sub can have at most one carry bit.  Thus we know that the output
      // is, at worst, one more bit than the inputs.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp == 1)
        break;
      return std::min(a: Tmp, b: Tmp2) - 1;

    case Instruction::Mul: {
      // The output of the Mul can be at most twice the valid bits in the
      // inputs.
      unsigned SignBitsOp0 =
          ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      if (SignBitsOp0 == 1)
        break;
      unsigned SignBitsOp1 =
          ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (SignBitsOp1 == 1)
        break;
      unsigned OutValidBits =
          (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
      return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
    }

    case Instruction::PHI: {
      const PHINode *PN = cast<PHINode>(Val: U);
      unsigned NumIncomingValues = PN->getNumIncomingValues();
      // Don't analyze large in-degree PHIs.
      if (NumIncomingValues > 4) break;
      // Unreachable blocks may have zero-operand PHI nodes.
      if (NumIncomingValues == 0) break;

      // Take the minimum of all incoming values.  This can't infinitely loop
      // because of our depth threshold.
      SimplifyQuery RecQ = Q.getWithoutCondContext();
      Tmp = TyBits;
      for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
        if (Tmp == 1) return Tmp;
        // Use each edge's terminator as the context for that incoming value.
        RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
        Tmp = std::min(a: Tmp, b: ComputeNumSignBits(V: PN->getIncomingValue(i),
                                                DemandedElts, Q: RecQ, Depth: Depth + 1));
      }
      return Tmp;
    }

    case Instruction::Trunc: {
      // If the input contained enough sign bits that some remain after the
      // truncation, then we can make use of that. Otherwise we don't know
      // anything.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);
      unsigned OperandTyBits = U->getOperand(i: 0)->getType()->getScalarSizeInBits();
      if (Tmp > (OperandTyBits - TyBits))
        return Tmp - (OperandTyBits - TyBits);

      return 1;
    }

    case Instruction::ExtractElement:
      // Look through extract element. At the moment we keep this simple and
      // skip tracking the specific element. But at least we might find
      // information valid for all elements of the vector (for example if vector
      // is sign extended, shifted, etc).
      return ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);

    case Instruction::ShuffleVector: {
      // Collect the minimum number of sign bits that are shared by every vector
      // element referenced by the shuffle.
      auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: U);
      if (!Shuf) {
        // FIXME: Add support for shufflevector constant expressions.
        return 1;
      }
      APInt DemandedLHS, DemandedRHS;
      // For undef elements, we don't know anything about the common state of
      // the shuffle result.
      if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
        return 1;
      Tmp = std::numeric_limits<unsigned>::max();
      if (!!DemandedLHS) {
        const Value *LHS = Shuf->getOperand(i_nocapture: 0);
        Tmp = ComputeNumSignBits(V: LHS, DemandedElts: DemandedLHS, Q, Depth: Depth + 1);
      }
      // If we don't know anything, early out and try computeKnownBits
      // fall-back.
      if (Tmp == 1)
        break;
      if (!!DemandedRHS) {
        const Value *RHS = Shuf->getOperand(i_nocapture: 1);
        Tmp2 = ComputeNumSignBits(V: RHS, DemandedElts: DemandedRHS, Q, Depth: Depth + 1);
        Tmp = std::min(a: Tmp, b: Tmp2);
      }
      // If we don't know anything, early out and try computeKnownBits
      // fall-back.
      if (Tmp == 1)
        break;
      assert(Tmp <= TyBits && "Failed to determine minimum sign bits");
      return Tmp;
    }
    case Instruction::Call: {
      if (const auto *II = dyn_cast<IntrinsicInst>(Val: U)) {
        switch (II->getIntrinsicID()) {
        default:
          break;
        case Intrinsic::abs:
          Tmp =
              ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
          if (Tmp == 1)
            break;

          // Absolute value reduces number of sign bits by at most 1.
          return Tmp - 1;
        case Intrinsic::smin:
        case Intrinsic::smax: {
          // A signed min/max clamp bounds the result to the clamp constants'
          // range, so take the weaker of their sign-bit counts.
          const APInt *CLow, *CHigh;
          if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
            return std::min(a: CLow->getNumSignBits(), b: CHigh->getNumSignBits());
        }
        }
      }
    }
    }
  }

  // Finally, if we can prove that the top bits of the result are 0's or 1's,
  // use this information.

  // If we can examine all elements of a vector constant successfully, we're
  // done (we can't do any better than that). If not, keep trying.
  if (unsigned VecSignBits =
          computeNumSignBitsVectorConstant(V, DemandedElts, TyBits))
    return VecSignBits;

  KnownBits Known(TyBits);
  computeKnownBits(V, DemandedElts, Known, Q, Depth);

  // If we know that the sign bit is either zero or one, determine the number of
  // identical bits in the top of the input value.
  return std::max(a: FirstAnswer, b: Known.countMinSignBits());
}
4629
/// Map a call site to an intrinsic ID: either the callee's own intrinsic ID,
/// or the intrinsic equivalent of a recognized read-only math libcall.
/// Returns Intrinsic::not_intrinsic when no mapping applies.
Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
                                            const TargetLibraryInfo *TLI) {
  const Function *F = CB.getCalledFunction();
  if (!F)
    return Intrinsic::not_intrinsic;

  if (F->isIntrinsic())
    return F->getIntrinsicID();

  // We are going to infer semantics of a library function based on mapping it
  // to an LLVM intrinsic. Check that the library function is available from
  // this callbase and in this environment.
  LibFunc Func;
  if (F->hasLocalLinkage() || !TLI || !TLI->getLibFunc(CB, F&: Func) ||
      !CB.onlyReadsMemory())
    return Intrinsic::not_intrinsic;

  // Map each recognized libcall (float/double/long double variants) to its
  // LLVM intrinsic equivalent.
  switch (Func) {
  default:
    break;
  case LibFunc_sin:
  case LibFunc_sinf:
  case LibFunc_sinl:
    return Intrinsic::sin;
  case LibFunc_cos:
  case LibFunc_cosf:
  case LibFunc_cosl:
    return Intrinsic::cos;
  case LibFunc_tan:
  case LibFunc_tanf:
  case LibFunc_tanl:
    return Intrinsic::tan;
  case LibFunc_asin:
  case LibFunc_asinf:
  case LibFunc_asinl:
    return Intrinsic::asin;
  case LibFunc_acos:
  case LibFunc_acosf:
  case LibFunc_acosl:
    return Intrinsic::acos;
  case LibFunc_atan:
  case LibFunc_atanf:
  case LibFunc_atanl:
    return Intrinsic::atan;
  case LibFunc_atan2:
  case LibFunc_atan2f:
  case LibFunc_atan2l:
    return Intrinsic::atan2;
  case LibFunc_sinh:
  case LibFunc_sinhf:
  case LibFunc_sinhl:
    return Intrinsic::sinh;
  case LibFunc_cosh:
  case LibFunc_coshf:
  case LibFunc_coshl:
    return Intrinsic::cosh;
  case LibFunc_tanh:
  case LibFunc_tanhf:
  case LibFunc_tanhl:
    return Intrinsic::tanh;
  case LibFunc_exp:
  case LibFunc_expf:
  case LibFunc_expl:
    return Intrinsic::exp;
  case LibFunc_exp2:
  case LibFunc_exp2f:
  case LibFunc_exp2l:
    return Intrinsic::exp2;
  case LibFunc_exp10:
  case LibFunc_exp10f:
  case LibFunc_exp10l:
    return Intrinsic::exp10;
  case LibFunc_log:
  case LibFunc_logf:
  case LibFunc_logl:
    return Intrinsic::log;
  case LibFunc_log10:
  case LibFunc_log10f:
  case LibFunc_log10l:
    return Intrinsic::log10;
  case LibFunc_log2:
  case LibFunc_log2f:
  case LibFunc_log2l:
    return Intrinsic::log2;
  case LibFunc_fabs:
  case LibFunc_fabsf:
  case LibFunc_fabsl:
    return Intrinsic::fabs;
  case LibFunc_fmin:
  case LibFunc_fminf:
  case LibFunc_fminl:
    return Intrinsic::minnum;
  case LibFunc_fmax:
  case LibFunc_fmaxf:
  case LibFunc_fmaxl:
    return Intrinsic::maxnum;
  case LibFunc_copysign:
  case LibFunc_copysignf:
  case LibFunc_copysignl:
    return Intrinsic::copysign;
  case LibFunc_floor:
  case LibFunc_floorf:
  case LibFunc_floorl:
    return Intrinsic::floor;
  case LibFunc_ceil:
  case LibFunc_ceilf:
  case LibFunc_ceill:
    return Intrinsic::ceil;
  case LibFunc_trunc:
  case LibFunc_truncf:
  case LibFunc_truncl:
    return Intrinsic::trunc;
  case LibFunc_rint:
  case LibFunc_rintf:
  case LibFunc_rintl:
    return Intrinsic::rint;
  case LibFunc_nearbyint:
  case LibFunc_nearbyintf:
  case LibFunc_nearbyintl:
    return Intrinsic::nearbyint;
  case LibFunc_round:
  case LibFunc_roundf:
  case LibFunc_roundl:
    return Intrinsic::round;
  case LibFunc_roundeven:
  case LibFunc_roundevenf:
  case LibFunc_roundevenl:
    return Intrinsic::roundeven;
  case LibFunc_pow:
  case LibFunc_powf:
  case LibFunc_powl:
    return Intrinsic::pow;
  case LibFunc_sqrt:
  case LibFunc_sqrtf:
  case LibFunc_sqrtl:
    return Intrinsic::sqrt;
  }

  return Intrinsic::not_intrinsic;
}
4770
4771/// Given an exploded icmp instruction, return true if the comparison only
4772/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if
4773/// the result of the comparison is true when the input value is signed.
4774bool llvm::isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
4775 bool &TrueIfSigned) {
4776 switch (Pred) {
4777 case ICmpInst::ICMP_SLT: // True if LHS s< 0
4778 TrueIfSigned = true;
4779 return RHS.isZero();
4780 case ICmpInst::ICMP_SLE: // True if LHS s<= -1
4781 TrueIfSigned = true;
4782 return RHS.isAllOnes();
4783 case ICmpInst::ICMP_SGT: // True if LHS s> -1
4784 TrueIfSigned = false;
4785 return RHS.isAllOnes();
4786 case ICmpInst::ICMP_SGE: // True if LHS s>= 0
4787 TrueIfSigned = false;
4788 return RHS.isZero();
4789 case ICmpInst::ICMP_UGT:
4790 // True if LHS u> RHS and RHS == sign-bit-mask - 1
4791 TrueIfSigned = true;
4792 return RHS.isMaxSignedValue();
4793 case ICmpInst::ICMP_UGE:
4794 // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4795 TrueIfSigned = true;
4796 return RHS.isMinSignedValue();
4797 case ICmpInst::ICMP_ULT:
4798 // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4799 TrueIfSigned = false;
4800 return RHS.isMinSignedValue();
4801 case ICmpInst::ICMP_ULE:
4802 // True if LHS u<= RHS and RHS == sign-bit-mask - 1
4803 TrueIfSigned = false;
4804 return RHS.isMaxSignedValue();
4805 default:
4806 return false;
4807 }
4808}
4809
/// Refine \p KnownFromContext with facts about the FP classes of \p V that
/// follow from \p Cond evaluating to \p CondIsTrue. Recurses through logical
/// and/or/not conditions, and handles fcmp against a constant,
/// llvm.is.fpclass tests, and integer sign-bit checks of V's bit pattern.
static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
                                        bool CondIsTrue,
                                        const Instruction *CxtI,
                                        KnownFPClass &KnownFromContext,
                                        unsigned Depth = 0) {
  Value *A, *B;
  // (A && B) == true implies both A and B hold; (A || B) == false implies
  // both are false — recurse into the two operands either way.
  if (Depth < MaxAnalysisRecursionDepth &&
      (CondIsTrue ? match(V: Cond, P: m_LogicalAnd(L: m_Value(V&: A), R: m_Value(V&: B)))
                  : match(V: Cond, P: m_LogicalOr(L: m_Value(V&: A), R: m_Value(V&: B))))) {
    computeKnownFPClassFromCond(V, Cond: A, CondIsTrue, CxtI, KnownFromContext,
                                Depth: Depth + 1);
    computeKnownFPClassFromCond(V, Cond: B, CondIsTrue, CxtI, KnownFromContext,
                                Depth: Depth + 1);
    return;
  }
  // not(A) == CondIsTrue implies A == !CondIsTrue.
  if (Depth < MaxAnalysisRecursionDepth && match(V: Cond, P: m_Not(V: m_Value(V&: A)))) {
    computeKnownFPClassFromCond(V, Cond: A, CondIsTrue: !CondIsTrue, CxtI, KnownFromContext,
                                Depth: Depth + 1);
    return;
  }
  CmpPredicate Pred;
  Value *LHS;
  uint64_t ClassVal = 0;
  const APFloat *CRHS;
  const APInt *RHS;
  if (match(V: Cond, P: m_FCmp(Pred, L: m_Value(V&: LHS), R: m_APFloat(Res&: CRHS)))) {
    // An fcmp against a constant constrains which FP classes the compared
    // value can be in; rule out the classes the branch excludes.
    auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
        Pred, F: *cast<Instruction>(Val: Cond)->getParent()->getParent(), LHS, ConstRHS: *CRHS,
        LookThroughSrc: LHS != V);
    if (CmpVal == V)
      KnownFromContext.knownNot(RuleOut: ~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
  } else if (match(V: Cond, P: m_Intrinsic<Intrinsic::is_fpclass>(
                       Op0: m_Specific(V), Op1: m_ConstantInt(V&: ClassVal)))) {
    // llvm.is.fpclass(V, Mask) directly tests class membership.
    FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
    KnownFromContext.knownNot(RuleOut: CondIsTrue ? ~Mask : Mask);
  } else if (match(V: Cond, P: m_ICmp(Pred, L: m_ElementWiseBitCast(Op: m_Specific(V)),
                                 R: m_APInt(Res&: RHS)))) {
    // An integer sign-bit test on V's bit pattern pins down V's FP sign bit.
    bool TrueIfSigned;
    if (!isSignBitCheck(Pred, RHS: *RHS, TrueIfSigned))
      return;
    if (TrueIfSigned == CondIsTrue)
      KnownFromContext.signBitMustBeOne();
    else
      KnownFromContext.signBitMustBeZero();
  }
}
4856
4857static KnownFPClass computeKnownFPClassFromContext(const Value *V,
4858 const SimplifyQuery &Q) {
4859 KnownFPClass KnownFromContext;
4860
4861 if (Q.CC && Q.CC->AffectedValues.contains(Ptr: V))
4862 computeKnownFPClassFromCond(V, Cond: Q.CC->Cond, CondIsTrue: !Q.CC->Invert, CxtI: Q.CxtI,
4863 KnownFromContext);
4864
4865 if (!Q.CxtI)
4866 return KnownFromContext;
4867
4868 if (Q.DC && Q.DT) {
4869 // Handle dominating conditions.
4870 for (CondBrInst *BI : Q.DC->conditionsFor(V)) {
4871 Value *Cond = BI->getCondition();
4872
4873 BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
4874 if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
4875 computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/true, CxtI: Q.CxtI,
4876 KnownFromContext);
4877
4878 BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
4879 if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
4880 computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/false, CxtI: Q.CxtI,
4881 KnownFromContext);
4882 }
4883 }
4884
4885 if (!Q.AC)
4886 return KnownFromContext;
4887
4888 // Try to restrict the floating-point classes based on information from
4889 // assumptions.
4890 for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
4891 if (!AssumeVH)
4892 continue;
4893 CallInst *I = cast<CallInst>(Val&: AssumeVH);
4894
4895 assert(I->getFunction() == Q.CxtI->getParent()->getParent() &&
4896 "Got assumption for the wrong function!");
4897 assert(I->getIntrinsicID() == Intrinsic::assume &&
4898 "must be an assume intrinsic");
4899
4900 if (!isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
4901 continue;
4902
4903 computeKnownFPClassFromCond(V, Cond: I->getArgOperand(i: 0),
4904 /*CondIsTrue=*/true, CxtI: Q.CxtI, KnownFromContext);
4905 }
4906
4907 return KnownFromContext;
4908}
4909
/// Tighten \p Known for the select arm \p Arm of a select on \p Cond: when
/// this arm is chosen, the condition is known true (or known false if
/// \p Invert), which may imply additional class facts about \p Arm.
void llvm::adjustKnownFPClassForSelectArm(KnownFPClass &Known, Value *Cond,
                                          Value *Arm, bool Invert,
                                          const SimplifyQuery &SQ,
                                          unsigned Depth) {

  // Derive class facts about Arm from the condition that selects it.
  KnownFPClass KnownSrc;
  computeKnownFPClassFromCond(V: Arm, Cond,
                              /*CondIsTrue=*/!Invert, CxtI: SQ.CxtI, KnownFromContext&: KnownSrc,
                              Depth: Depth + 1);
  // Merge the condition-derived facts with what was already computed.
  KnownSrc = KnownSrc.unionWith(RHS: Known);
  if (KnownSrc.isUnknown())
    return;

  // Only adopt the refined facts if Arm cannot be undef/poison; otherwise
  // the condition need not actually constrain the arm's value.
  if (isGuaranteedNotToBeUndef(V: Arm, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT, Depth: Depth + 1))
    Known = KnownSrc;
}
4926
// Forward declaration of the main implementation, defined later in this file;
// needed by the helper overloads below.
void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
                         FPClassTest InterestedClasses, KnownFPClass &Known,
                         const SimplifyQuery &Q, unsigned Depth);
4930
4931static void computeKnownFPClass(const Value *V, KnownFPClass &Known,
4932 FPClassTest InterestedClasses,
4933 const SimplifyQuery &Q, unsigned Depth) {
4934 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
4935 APInt DemandedElts =
4936 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
4937 computeKnownFPClass(V, DemandedElts, InterestedClasses, Known, Q, Depth);
4938}
4939
4940static void computeKnownFPClassForFPTrunc(const Operator *Op,
4941 const APInt &DemandedElts,
4942 FPClassTest InterestedClasses,
4943 KnownFPClass &Known,
4944 const SimplifyQuery &Q,
4945 unsigned Depth) {
4946 if ((InterestedClasses &
4947 (KnownFPClass::OrderedLessThanZeroMask | fcNan)) == fcNone)
4948 return;
4949
4950 KnownFPClass KnownSrc;
4951 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
4952 Known&: KnownSrc, Q, Depth: Depth + 1);
4953 Known = KnownFPClass::fptrunc(KnownSrc);
4954}
4955
4956static constexpr KnownFPClass::MinMaxKind getMinMaxKind(Intrinsic::ID IID) {
4957 switch (IID) {
4958 case Intrinsic::minimum:
4959 return KnownFPClass::MinMaxKind::minimum;
4960 case Intrinsic::maximum:
4961 return KnownFPClass::MinMaxKind::maximum;
4962 case Intrinsic::minimumnum:
4963 return KnownFPClass::MinMaxKind::minimumnum;
4964 case Intrinsic::maximumnum:
4965 return KnownFPClass::MinMaxKind::maximumnum;
4966 case Intrinsic::minnum:
4967 return KnownFPClass::MinMaxKind::minnum;
4968 case Intrinsic::maxnum:
4969 return KnownFPClass::MinMaxKind::maxnum;
4970 default:
4971 llvm_unreachable("not a floating-point min-max intrinsic");
4972 }
4973}
4974
4975/// \return true if this is a floating point value that is known to have a
4976/// magnitude smaller than 1. i.e., fabs(X) <= 1.0 or is nan.
4977static bool isAbsoluteValueULEOne(const Value *V) {
4978 // TODO: Handle frexp
4979 // TODO: Other rounding intrinsics?
4980
4981 // fabs(x - floor(x)) <= 1
4982 const Value *SubFloorX;
4983 if (match(V, P: m_FSub(L: m_Value(V&: SubFloorX),
4984 R: m_Intrinsic<Intrinsic::floor>(Op0: m_Deferred(V: SubFloorX)))))
4985 return true;
4986
4987 return match(V, P: m_Intrinsic<Intrinsic::amdgcn_trig_preop>(Op0: m_Value())) ||
4988 match(V, P: m_Intrinsic<Intrinsic::amdgcn_fract>(Op0: m_Value()));
4989}
4990
4991void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
4992 FPClassTest InterestedClasses, KnownFPClass &Known,
4993 const SimplifyQuery &Q, unsigned Depth) {
4994 assert(Known.isUnknown() && "should not be called with known information");
4995
4996 if (!DemandedElts) {
4997 // No demanded elts, better to assume we don't know anything.
4998 Known.resetAll();
4999 return;
5000 }
5001
5002 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
5003
5004 if (auto *CFP = dyn_cast<ConstantFP>(Val: V)) {
5005 Known = KnownFPClass(CFP->getValueAPF());
5006 return;
5007 }
5008
5009 if (isa<ConstantAggregateZero>(Val: V)) {
5010 Known.KnownFPClasses = fcPosZero;
5011 Known.SignBit = false;
5012 return;
5013 }
5014
5015 if (isa<PoisonValue>(Val: V)) {
5016 Known.KnownFPClasses = fcNone;
5017 Known.SignBit = false;
5018 return;
5019 }
5020
5021 // Try to handle fixed width vector constants
5022 auto *VFVTy = dyn_cast<FixedVectorType>(Val: V->getType());
5023 const Constant *CV = dyn_cast<Constant>(Val: V);
5024 if (VFVTy && CV) {
5025 Known.KnownFPClasses = fcNone;
5026 bool SignBitAllZero = true;
5027 bool SignBitAllOne = true;
5028
5029 // For vectors, verify that each element is not NaN.
5030 unsigned NumElts = VFVTy->getNumElements();
5031 for (unsigned i = 0; i != NumElts; ++i) {
5032 if (!DemandedElts[i])
5033 continue;
5034
5035 Constant *Elt = CV->getAggregateElement(Elt: i);
5036 if (!Elt) {
5037 Known = KnownFPClass();
5038 return;
5039 }
5040 if (isa<PoisonValue>(Val: Elt))
5041 continue;
5042 auto *CElt = dyn_cast<ConstantFP>(Val: Elt);
5043 if (!CElt) {
5044 Known = KnownFPClass();
5045 return;
5046 }
5047
5048 const APFloat &C = CElt->getValueAPF();
5049 Known.KnownFPClasses |= C.classify();
5050 if (C.isNegative())
5051 SignBitAllZero = false;
5052 else
5053 SignBitAllOne = false;
5054 }
5055 if (SignBitAllOne != SignBitAllZero)
5056 Known.SignBit = SignBitAllOne;
5057 return;
5058 }
5059
5060 if (const auto *CDS = dyn_cast<ConstantDataSequential>(Val: V)) {
5061 Known.KnownFPClasses = fcNone;
5062 for (size_t I = 0, E = CDS->getNumElements(); I != E; ++I)
5063 Known |= CDS->getElementAsAPFloat(i: I).classify();
5064 return;
5065 }
5066
5067 if (const auto *CA = dyn_cast<ConstantAggregate>(Val: V)) {
5068 // TODO: Handle complex aggregates
5069 Known.KnownFPClasses = fcNone;
5070 for (const Use &Op : CA->operands()) {
5071 auto *CFP = dyn_cast<ConstantFP>(Val: Op.get());
5072 if (!CFP) {
5073 Known = KnownFPClass();
5074 return;
5075 }
5076
5077 Known |= CFP->getValueAPF().classify();
5078 }
5079
5080 return;
5081 }
5082
5083 FPClassTest KnownNotFromFlags = fcNone;
5084 if (const auto *CB = dyn_cast<CallBase>(Val: V))
5085 KnownNotFromFlags |= CB->getRetNoFPClass();
5086 else if (const auto *Arg = dyn_cast<Argument>(Val: V))
5087 KnownNotFromFlags |= Arg->getNoFPClass();
5088
5089 const Operator *Op = dyn_cast<Operator>(Val: V);
5090 if (const FPMathOperator *FPOp = dyn_cast_or_null<FPMathOperator>(Val: Op)) {
5091 if (FPOp->hasNoNaNs())
5092 KnownNotFromFlags |= fcNan;
5093 if (FPOp->hasNoInfs())
5094 KnownNotFromFlags |= fcInf;
5095 }
5096
5097 KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
5098 KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
5099
5100 // We no longer need to find out about these bits from inputs if we can
5101 // assume this from flags/attributes.
5102 InterestedClasses &= ~KnownNotFromFlags;
5103
5104 llvm::scope_exit ClearClassesFromFlags([=, &Known] {
5105 Known.knownNot(RuleOut: KnownNotFromFlags);
5106 if (!Known.SignBit && AssumedClasses.SignBit) {
5107 if (*AssumedClasses.SignBit)
5108 Known.signBitMustBeOne();
5109 else
5110 Known.signBitMustBeZero();
5111 }
5112 });
5113
5114 if (!Op)
5115 return;
5116
5117 // All recursive calls that increase depth must come after this.
5118 if (Depth == MaxAnalysisRecursionDepth)
5119 return;
5120
5121 const unsigned Opc = Op->getOpcode();
5122 switch (Opc) {
5123 case Instruction::FNeg: {
5124 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
5125 Known, Q, Depth: Depth + 1);
5126 Known.fneg();
5127 break;
5128 }
5129 case Instruction::Select: {
5130 auto ComputeForArm = [&](Value *Arm, bool Invert) {
5131 KnownFPClass Res;
5132 computeKnownFPClass(V: Arm, DemandedElts, InterestedClasses, Known&: Res, Q,
5133 Depth: Depth + 1);
5134 adjustKnownFPClassForSelectArm(Known&: Res, Cond: Op->getOperand(i: 0), Arm, Invert, SQ: Q,
5135 Depth);
5136 return Res;
5137 };
5138 // Only known if known in both the LHS and RHS.
5139 Known =
5140 ComputeForArm(Op->getOperand(i: 1), /*Invert=*/false)
5141 .intersectWith(RHS: ComputeForArm(Op->getOperand(i: 2), /*Invert=*/true));
5142 break;
5143 }
5144 case Instruction::Load: {
5145 const MDNode *NoFPClass =
5146 cast<LoadInst>(Val: Op)->getMetadata(KindID: LLVMContext::MD_nofpclass);
5147 if (!NoFPClass)
5148 break;
5149
5150 ConstantInt *MaskVal =
5151 mdconst::extract<ConstantInt>(MD: NoFPClass->getOperand(I: 0));
5152 Known.knownNot(RuleOut: static_cast<FPClassTest>(MaskVal->getZExtValue()));
5153 break;
5154 }
5155 case Instruction::Call: {
5156 const CallInst *II = cast<CallInst>(Val: Op);
5157 const Intrinsic::ID IID = II->getIntrinsicID();
5158 switch (IID) {
5159 case Intrinsic::fabs: {
5160 if ((InterestedClasses & (fcNan | fcPositive)) != fcNone) {
5161 // If we only care about the sign bit we don't need to inspect the
5162 // operand.
5163 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5164 InterestedClasses, Known, Q, Depth: Depth + 1);
5165 }
5166
5167 Known.fabs();
5168 break;
5169 }
5170 case Intrinsic::copysign: {
5171 KnownFPClass KnownSign;
5172
5173 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5174 Known, Q, Depth: Depth + 1);
5175 computeKnownFPClass(V: II->getArgOperand(i: 1), DemandedElts, InterestedClasses,
5176 Known&: KnownSign, Q, Depth: Depth + 1);
5177 Known.copysign(Sign: KnownSign);
5178 break;
5179 }
5180 case Intrinsic::fma:
5181 case Intrinsic::fmuladd: {
5182 if ((InterestedClasses & fcNegative) == fcNone)
5183 break;
5184
5185 // FIXME: This should check isGuaranteedNotToBeUndef
5186 if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1)) {
5187 KnownFPClass KnownSrc, KnownAddend;
5188 computeKnownFPClass(V: II->getArgOperand(i: 2), DemandedElts,
5189 InterestedClasses, Known&: KnownAddend, Q, Depth: Depth + 1);
5190 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5191 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5192
5193 const Function *F = II->getFunction();
5194 const fltSemantics &FltSem =
5195 II->getType()->getScalarType()->getFltSemantics();
5196 DenormalMode Mode =
5197 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5198
5199 if (KnownNotFromFlags & fcNan) {
5200 KnownSrc.knownNot(RuleOut: fcNan);
5201 KnownAddend.knownNot(RuleOut: fcNan);
5202 }
5203
5204 if (KnownNotFromFlags & fcInf) {
5205 KnownSrc.knownNot(RuleOut: fcInf);
5206 KnownAddend.knownNot(RuleOut: fcInf);
5207 }
5208
5209 Known = KnownFPClass::fma_square(Squared: KnownSrc, Addend: KnownAddend, Mode);
5210 break;
5211 }
5212
5213 KnownFPClass KnownSrc[3];
5214 for (int I = 0; I != 3; ++I) {
5215 computeKnownFPClass(V: II->getArgOperand(i: I), DemandedElts,
5216 InterestedClasses, Known&: KnownSrc[I], Q, Depth: Depth + 1);
5217 if (KnownSrc[I].isUnknown())
5218 return;
5219
5220 if (KnownNotFromFlags & fcNan)
5221 KnownSrc[I].knownNot(RuleOut: fcNan);
5222 if (KnownNotFromFlags & fcInf)
5223 KnownSrc[I].knownNot(RuleOut: fcInf);
5224 }
5225
5226 const Function *F = II->getFunction();
5227 const fltSemantics &FltSem =
5228 II->getType()->getScalarType()->getFltSemantics();
5229 DenormalMode Mode =
5230 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5231 Known = KnownFPClass::fma(LHS: KnownSrc[0], RHS: KnownSrc[1], Addend: KnownSrc[2], Mode);
5232 break;
5233 }
5234 case Intrinsic::sqrt:
5235 case Intrinsic::experimental_constrained_sqrt: {
5236 KnownFPClass KnownSrc;
5237 FPClassTest InterestedSrcs = InterestedClasses;
5238 if (InterestedClasses & fcNan)
5239 InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5240
5241 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5242 Known&: KnownSrc, Q, Depth: Depth + 1);
5243
5244 DenormalMode Mode = DenormalMode::getDynamic();
5245
5246 bool HasNSZ = Q.IIQ.hasNoSignedZeros(Op: II);
5247 if (!HasNSZ) {
5248 const Function *F = II->getFunction();
5249 const fltSemantics &FltSem =
5250 II->getType()->getScalarType()->getFltSemantics();
5251 Mode = F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5252 }
5253
5254 Known = KnownFPClass::sqrt(Src: KnownSrc, Mode);
5255 if (HasNSZ)
5256 Known.knownNot(RuleOut: fcNegZero);
5257
5258 break;
5259 }
5260 case Intrinsic::sin:
5261 case Intrinsic::cos: {
5262 // Return NaN on infinite inputs.
5263 KnownFPClass KnownSrc;
5264 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5265 Known&: KnownSrc, Q, Depth: Depth + 1);
5266 Known = IID == Intrinsic::sin ? KnownFPClass::sin(Src: KnownSrc)
5267 : KnownFPClass::cos(Src: KnownSrc);
5268 break;
5269 }
5270 case Intrinsic::maxnum:
5271 case Intrinsic::minnum:
5272 case Intrinsic::minimum:
5273 case Intrinsic::maximum:
5274 case Intrinsic::minimumnum:
5275 case Intrinsic::maximumnum: {
5276 KnownFPClass KnownLHS, KnownRHS;
5277 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5278 Known&: KnownLHS, Q, Depth: Depth + 1);
5279 computeKnownFPClass(V: II->getArgOperand(i: 1), DemandedElts, InterestedClasses,
5280 Known&: KnownRHS, Q, Depth: Depth + 1);
5281
5282 const Function *F = II->getFunction();
5283
5284 DenormalMode Mode =
5285 F ? F->getDenormalMode(
5286 FPType: II->getType()->getScalarType()->getFltSemantics())
5287 : DenormalMode::getDynamic();
5288
5289 Known = KnownFPClass::minMaxLike(LHS: KnownLHS, RHS: KnownRHS, Kind: getMinMaxKind(IID),
5290 DenormMode: Mode);
5291 break;
5292 }
5293 case Intrinsic::canonicalize: {
5294 KnownFPClass KnownSrc;
5295 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5296 Known&: KnownSrc, Q, Depth: Depth + 1);
5297
5298 const Function *F = II->getFunction();
5299 DenormalMode DenormMode =
5300 F ? F->getDenormalMode(
5301 FPType: II->getType()->getScalarType()->getFltSemantics())
5302 : DenormalMode::getDynamic();
5303 Known = KnownFPClass::canonicalize(Src: KnownSrc, DenormMode);
5304 break;
5305 }
5306 case Intrinsic::vector_reduce_fmax:
5307 case Intrinsic::vector_reduce_fmin:
5308 case Intrinsic::vector_reduce_fmaximum:
5309 case Intrinsic::vector_reduce_fminimum: {
    // reduce min/max will choose one of the vector elements, so we can infer
    // any class information that is common to all elements.
5312 Known = computeKnownFPClass(V: II->getArgOperand(i: 0), FMF: II->getFastMathFlags(),
5313 InterestedClasses, SQ: Q, Depth: Depth + 1);
5314 // Can only propagate sign if output is never NaN.
5315 if (!Known.isKnownNeverNaN())
5316 Known.SignBit.reset();
5317 break;
5318 }
5319 // reverse preserves all characteristics of the input vec's element.
5320 case Intrinsic::vector_reverse:
5321 Known = computeKnownFPClass(
5322 V: II->getArgOperand(i: 0), DemandedElts: DemandedElts.reverseBits(),
5323 FMF: II->getFastMathFlags(), InterestedClasses, SQ: Q, Depth: Depth + 1);
5324 break;
5325 case Intrinsic::trunc:
5326 case Intrinsic::floor:
5327 case Intrinsic::ceil:
5328 case Intrinsic::rint:
5329 case Intrinsic::nearbyint:
5330 case Intrinsic::round:
5331 case Intrinsic::roundeven: {
5332 KnownFPClass KnownSrc;
5333 FPClassTest InterestedSrcs = InterestedClasses;
5334 if (InterestedSrcs & fcPosFinite)
5335 InterestedSrcs |= fcPosFinite;
5336 if (InterestedSrcs & fcNegFinite)
5337 InterestedSrcs |= fcNegFinite;
5338 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5339 Known&: KnownSrc, Q, Depth: Depth + 1);
5340
5341 Known = KnownFPClass::roundToIntegral(
5342 Src: KnownSrc, IsTrunc: IID == Intrinsic::trunc,
5343 IsMultiUnitFPType: V->getType()->getScalarType()->isMultiUnitFPType());
5344 break;
5345 }
5346 case Intrinsic::exp:
5347 case Intrinsic::exp2:
5348 case Intrinsic::exp10:
5349 case Intrinsic::amdgcn_exp2: {
5350 KnownFPClass KnownSrc;
5351 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5352 Known&: KnownSrc, Q, Depth: Depth + 1);
5353
5354 Known = KnownFPClass::exp(Src: KnownSrc);
5355
5356 Type *EltTy = II->getType()->getScalarType();
5357 if (IID == Intrinsic::amdgcn_exp2 && EltTy->isFloatTy())
5358 Known.knownNot(RuleOut: fcSubnormal);
5359
5360 break;
5361 }
5362 case Intrinsic::fptrunc_round: {
5363 computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known,
5364 Q, Depth);
5365 break;
5366 }
5367 case Intrinsic::log:
5368 case Intrinsic::log10:
5369 case Intrinsic::log2:
5370 case Intrinsic::experimental_constrained_log:
5371 case Intrinsic::experimental_constrained_log10:
5372 case Intrinsic::experimental_constrained_log2:
5373 case Intrinsic::amdgcn_log: {
5374 Type *EltTy = II->getType()->getScalarType();
5375
5376 // log(+inf) -> +inf
5377 // log([+-]0.0) -> -inf
5378 // log(-inf) -> nan
5379 // log(-x) -> nan
5380 if ((InterestedClasses & (fcNan | fcInf)) != fcNone) {
5381 FPClassTest InterestedSrcs = InterestedClasses;
5382 if ((InterestedClasses & fcNegInf) != fcNone)
5383 InterestedSrcs |= fcZero | fcSubnormal;
5384 if ((InterestedClasses & fcNan) != fcNone)
5385 InterestedSrcs |= fcNan | fcNegative;
5386
5387 KnownFPClass KnownSrc;
5388 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5389 Known&: KnownSrc, Q, Depth: Depth + 1);
5390
5391 const Function *F = II->getFunction();
5392 DenormalMode Mode = F ? F->getDenormalMode(FPType: EltTy->getFltSemantics())
5393 : DenormalMode::getDynamic();
5394 Known = KnownFPClass::log(Src: KnownSrc, Mode);
5395 }
5396
5397 break;
5398 }
5399 case Intrinsic::powi: {
5400 if ((InterestedClasses & fcNegative) == fcNone)
5401 break;
5402
5403 const Value *Exp = II->getArgOperand(i: 1);
5404 Type *ExpTy = Exp->getType();
5405 unsigned BitWidth = ExpTy->getScalarType()->getIntegerBitWidth();
5406 KnownBits ExponentKnownBits(BitWidth);
5407 computeKnownBits(V: Exp, DemandedElts: isa<VectorType>(Val: ExpTy) ? DemandedElts : APInt(1, 1),
5408 Known&: ExponentKnownBits, Q, Depth: Depth + 1);
5409
5410 KnownFPClass KnownSrc;
5411 if (ExponentKnownBits.isZero() || !ExponentKnownBits.isEven()) {
5412 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: fcNegative,
5413 Known&: KnownSrc, Q, Depth: Depth + 1);
5414 }
5415
5416 Known = KnownFPClass::powi(Src: KnownSrc, N: ExponentKnownBits);
5417 break;
5418 }
5419 case Intrinsic::ldexp: {
5420 KnownFPClass KnownSrc;
5421 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5422 Known&: KnownSrc, Q, Depth: Depth + 1);
5423 // Can refine inf/zero handling based on the exponent operand.
5424 const FPClassTest ExpInfoMask = fcZero | fcSubnormal | fcInf;
5425
5426 KnownBits ExpBits;
5427 if ((KnownSrc.KnownFPClasses & ExpInfoMask) != fcNone) {
5428 const Value *ExpArg = II->getArgOperand(i: 1);
5429 ExpBits = computeKnownBits(V: ExpArg, DemandedElts, Q, Depth: Depth + 1);
5430 }
5431
5432 const fltSemantics &Flt =
5433 II->getType()->getScalarType()->getFltSemantics();
5434
5435 const Function *F = II->getFunction();
5436 DenormalMode Mode =
5437 F ? F->getDenormalMode(FPType: Flt) : DenormalMode::getDynamic();
5438
5439 Known = KnownFPClass::ldexp(Src: KnownSrc, N: ExpBits, Flt, Mode);
5440 break;
5441 }
5442 case Intrinsic::arithmetic_fence: {
5443 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5444 Known, Q, Depth: Depth + 1);
5445 break;
5446 }
5447 case Intrinsic::experimental_constrained_sitofp:
5448 case Intrinsic::experimental_constrained_uitofp:
5449 // Cannot produce nan
5450 Known.knownNot(RuleOut: fcNan);
5451
5452 // sitofp and uitofp turn into +0.0 for zero.
5453 Known.knownNot(RuleOut: fcNegZero);
5454
5455 // Integers cannot be subnormal
5456 Known.knownNot(RuleOut: fcSubnormal);
5457
5458 if (IID == Intrinsic::experimental_constrained_uitofp)
5459 Known.signBitMustBeZero();
5460
5461 // TODO: Copy inf handling from instructions
5462 break;
5463
5464 case Intrinsic::amdgcn_fract: {
5465 Known.knownNot(RuleOut: fcInf);
5466
5467 if (InterestedClasses & fcNan) {
5468 KnownFPClass KnownSrc;
5469 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5470 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5471
5472 if (KnownSrc.isKnownNeverInfOrNaN())
5473 Known.knownNot(RuleOut: fcNan);
5474 else if (KnownSrc.isKnownNever(Mask: fcSNan))
5475 Known.knownNot(RuleOut: fcSNan);
5476 }
5477
5478 break;
5479 }
5480 case Intrinsic::amdgcn_rcp: {
5481 KnownFPClass KnownSrc;
5482 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5483 Known&: KnownSrc, Q, Depth: Depth + 1);
5484
5485 Known.propagateNaN(Src: KnownSrc);
5486
5487 Type *EltTy = II->getType()->getScalarType();
5488
5489 // f32 denormal always flushed.
5490 if (EltTy->isFloatTy()) {
5491 Known.knownNot(RuleOut: fcSubnormal);
5492 KnownSrc.knownNot(RuleOut: fcSubnormal);
5493 }
5494
5495 if (KnownSrc.isKnownNever(Mask: fcNegative))
5496 Known.knownNot(RuleOut: fcNegative);
5497 if (KnownSrc.isKnownNever(Mask: fcPositive))
5498 Known.knownNot(RuleOut: fcPositive);
5499
5500 if (const Function *F = II->getFunction()) {
5501 DenormalMode Mode = F->getDenormalMode(FPType: EltTy->getFltSemantics());
5502 if (KnownSrc.isKnownNeverLogicalPosZero(Mode))
5503 Known.knownNot(RuleOut: fcPosInf);
5504 if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
5505 Known.knownNot(RuleOut: fcNegInf);
5506 }
5507
5508 break;
5509 }
5510 case Intrinsic::amdgcn_rsq: {
5511 KnownFPClass KnownSrc;
5512 // The only negative value that can be returned is -inf for -0 inputs.
5513 Known.knownNot(RuleOut: fcNegZero | fcNegSubnormal | fcNegNormal);
5514
5515 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5516 Known&: KnownSrc, Q, Depth: Depth + 1);
5517
5518 // Negative -> nan
5519 if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
5520 Known.knownNot(RuleOut: fcNan);
5521 else if (KnownSrc.isKnownNever(Mask: fcSNan))
5522 Known.knownNot(RuleOut: fcSNan);
5523
5524 // +inf -> +0
5525 if (KnownSrc.isKnownNeverPosInfinity())
5526 Known.knownNot(RuleOut: fcPosZero);
5527
5528 Type *EltTy = II->getType()->getScalarType();
5529
5530 // f32 denormal always flushed.
5531 if (EltTy->isFloatTy())
5532 Known.knownNot(RuleOut: fcPosSubnormal);
5533
5534 if (const Function *F = II->getFunction()) {
5535 DenormalMode Mode = F->getDenormalMode(FPType: EltTy->getFltSemantics());
5536
5537 // -0 -> -inf
5538 if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
5539 Known.knownNot(RuleOut: fcNegInf);
5540
5541 // +0 -> +inf
5542 if (KnownSrc.isKnownNeverLogicalPosZero(Mode))
5543 Known.knownNot(RuleOut: fcPosInf);
5544 }
5545
5546 break;
5547 }
5548 case Intrinsic::amdgcn_trig_preop: {
5549 // Always returns a value [0, 1)
5550 Known.knownNot(RuleOut: fcNan | fcInf | fcNegative);
5551 break;
5552 }
5553 default:
5554 break;
5555 }
5556
5557 break;
5558 }
5559 case Instruction::FAdd:
5560 case Instruction::FSub: {
5561 KnownFPClass KnownLHS, KnownRHS;
5562 bool WantNegative =
5563 Op->getOpcode() == Instruction::FAdd &&
5564 (InterestedClasses & KnownFPClass::OrderedLessThanZeroMask) != fcNone;
5565 bool WantNaN = (InterestedClasses & fcNan) != fcNone;
5566 bool WantNegZero = (InterestedClasses & fcNegZero) != fcNone;
5567
5568 if (!WantNaN && !WantNegative && !WantNegZero)
5569 break;
5570
5571 FPClassTest InterestedSrcs = InterestedClasses;
5572 if (WantNegative)
5573 InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5574 if (InterestedClasses & fcNan)
5575 InterestedSrcs |= fcInf;
5576 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts, InterestedClasses: InterestedSrcs,
5577 Known&: KnownRHS, Q, Depth: Depth + 1);
5578
5579 // Special case fadd x, x, which is the canonical form of fmul x, 2.
5580 bool Self = Op->getOperand(i: 0) == Op->getOperand(i: 1) &&
5581 isGuaranteedNotToBeUndef(V: Op->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
5582 Depth: Depth + 1);
5583 if (Self)
5584 KnownLHS = KnownRHS;
5585
5586 if ((WantNaN && KnownRHS.isKnownNeverNaN()) ||
5587 (WantNegative && KnownRHS.cannotBeOrderedLessThanZero()) ||
5588 WantNegZero || Opc == Instruction::FSub) {
5589
5590 // FIXME: Context function should always be passed in separately
5591 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5592 const fltSemantics &FltSem =
5593 Op->getType()->getScalarType()->getFltSemantics();
5594 DenormalMode Mode =
5595 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5596
5597 if (Self && Opc == Instruction::FAdd) {
5598 Known = KnownFPClass::fadd_self(Src: KnownLHS, Mode);
5599 } else {
5600 // RHS is canonically cheaper to compute. Skip inspecting the LHS if
5601 // there's no point.
5602
5603 if (!Self) {
5604 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5605 Known&: KnownLHS, Q, Depth: Depth + 1);
5606 }
5607
5608 Known = Opc == Instruction::FAdd
5609 ? KnownFPClass::fadd(LHS: KnownLHS, RHS: KnownRHS, Mode)
5610 : KnownFPClass::fsub(LHS: KnownLHS, RHS: KnownRHS, Mode);
5611 }
5612 }
5613
5614 break;
5615 }
5616 case Instruction::FMul: {
5617 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5618 DenormalMode Mode =
5619 F ? F->getDenormalMode(
5620 FPType: Op->getType()->getScalarType()->getFltSemantics())
5621 : DenormalMode::getDynamic();
5622
5623 Value *LHS = Op->getOperand(i: 0);
5624 Value *RHS = Op->getOperand(i: 1);
5625 // X * X is always non-negative or a NaN.
5626 // FIXME: Should check isGuaranteedNotToBeUndef
5627 if (LHS == RHS) {
5628 KnownFPClass KnownSrc;
5629 computeKnownFPClass(V: LHS, DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownSrc, Q,
5630 Depth: Depth + 1);
5631 Known = KnownFPClass::square(Src: KnownSrc, Mode);
5632 break;
5633 }
5634
5635 KnownFPClass KnownLHS, KnownRHS;
5636
5637 const APFloat *CRHS;
5638 if (match(V: RHS, P: m_APFloat(Res&: CRHS))) {
5639 computeKnownFPClass(V: LHS, DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownLHS, Q,
5640 Depth: Depth + 1);
5641 Known = KnownFPClass::fmul(LHS: KnownLHS, RHS: *CRHS, Mode);
5642 } else {
5643 computeKnownFPClass(V: RHS, DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownRHS, Q,
5644 Depth: Depth + 1);
5645 // TODO: Improve accuracy in unfused FMA pattern. We can prove an
5646 // additional not-nan if the addend is known-not negative infinity if the
5647 // multiply is known-not infinity.
5648
5649 computeKnownFPClass(V: LHS, DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownLHS, Q,
5650 Depth: Depth + 1);
5651 Known = KnownFPClass::fmul(LHS: KnownLHS, RHS: KnownRHS, Mode);
5652 }
5653
    /// Propagate no-infs if the other source is known to have magnitude no
    /// greater than one, such that the multiply cannot introduce overflow.
5656 if (KnownLHS.isKnownNever(Mask: fcInf) && isAbsoluteValueULEOne(V: RHS))
5657 Known.knownNot(RuleOut: fcInf);
5658 else if (KnownRHS.isKnownNever(Mask: fcInf) && isAbsoluteValueULEOne(V: LHS))
5659 Known.knownNot(RuleOut: fcInf);
5660
5661 break;
5662 }
5663 case Instruction::FDiv:
5664 case Instruction::FRem: {
5665 const bool WantNan = (InterestedClasses & fcNan) != fcNone;
5666
5667 if (Op->getOpcode() == Instruction::FRem)
5668 Known.knownNot(RuleOut: fcInf);
5669
5670 if (Op->getOperand(i: 0) == Op->getOperand(i: 1) &&
5671 isGuaranteedNotToBeUndef(V: Op->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT)) {
5672 if (Op->getOpcode() == Instruction::FDiv) {
5673 // X / X is always exactly 1.0 or a NaN.
5674 Known.KnownFPClasses = fcNan | fcPosNormal;
5675 } else {
5676 // X % X is always exactly [+-]0.0 or a NaN.
5677 Known.KnownFPClasses = fcNan | fcZero;
5678 }
5679
5680 if (!WantNan)
5681 break;
5682
5683 KnownFPClass KnownSrc;
5684 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts,
5685 InterestedClasses: fcNan | fcInf | fcZero | fcSubnormal, Known&: KnownSrc, Q,
5686 Depth: Depth + 1);
5687 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5688 const fltSemantics &FltSem =
5689 Op->getType()->getScalarType()->getFltSemantics();
5690
5691 DenormalMode Mode =
5692 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5693
5694 Known = Op->getOpcode() == Instruction::FDiv
5695 ? KnownFPClass::fdiv_self(Src: KnownSrc, Mode)
5696 : KnownFPClass::frem_self(Src: KnownSrc, Mode);
5697 break;
5698 }
5699
5700 const bool WantNegative = (InterestedClasses & fcNegative) != fcNone;
5701 const bool WantPositive =
5702 Opc == Instruction::FRem && (InterestedClasses & fcPositive) != fcNone;
5703 if (!WantNan && !WantNegative && !WantPositive)
5704 break;
5705
5706 KnownFPClass KnownLHS, KnownRHS;
5707
5708 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts,
5709 InterestedClasses: fcNan | fcInf | fcZero | fcNegative, Known&: KnownRHS, Q,
5710 Depth: Depth + 1);
5711
5712 bool KnowSomethingUseful = KnownRHS.isKnownNeverNaN() ||
5713 KnownRHS.isKnownNever(Mask: fcNegative) ||
5714 KnownRHS.isKnownNever(Mask: fcPositive);
5715
5716 if (KnowSomethingUseful || WantPositive) {
5717 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownLHS,
5718 Q, Depth: Depth + 1);
5719 }
5720
5721 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5722 const fltSemantics &FltSem =
5723 Op->getType()->getScalarType()->getFltSemantics();
5724
5725 if (Op->getOpcode() == Instruction::FDiv) {
5726 DenormalMode Mode =
5727 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5728 Known = KnownFPClass::fdiv(LHS: KnownLHS, RHS: KnownRHS, Mode);
5729 } else {
5730 // Inf REM x and x REM 0 produce NaN.
5731 if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5732 KnownLHS.isKnownNeverInfinity() && F &&
5733 KnownRHS.isKnownNeverLogicalZero(Mode: F->getDenormalMode(FPType: FltSem))) {
5734 Known.knownNot(RuleOut: fcNan);
5735 }
5736
5737 // The sign for frem is the same as the first operand.
5738 if (KnownLHS.cannotBeOrderedLessThanZero())
5739 Known.knownNot(RuleOut: KnownFPClass::OrderedLessThanZeroMask);
5740 if (KnownLHS.cannotBeOrderedGreaterThanZero())
5741 Known.knownNot(RuleOut: KnownFPClass::OrderedGreaterThanZeroMask);
5742
5743 // See if we can be more aggressive about the sign of 0.
5744 if (KnownLHS.isKnownNever(Mask: fcNegative))
5745 Known.knownNot(RuleOut: fcNegative);
5746 if (KnownLHS.isKnownNever(Mask: fcPositive))
5747 Known.knownNot(RuleOut: fcPositive);
5748 }
5749
5750 break;
5751 }
5752 case Instruction::FPExt: {
5753 KnownFPClass KnownSrc;
5754 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
5755 Known&: KnownSrc, Q, Depth: Depth + 1);
5756
5757 const fltSemantics &DstTy =
5758 Op->getType()->getScalarType()->getFltSemantics();
5759 const fltSemantics &SrcTy =
5760 Op->getOperand(i: 0)->getType()->getScalarType()->getFltSemantics();
5761
5762 Known = KnownFPClass::fpext(KnownSrc, DstTy, SrcTy);
5763 break;
5764 }
5765 case Instruction::FPTrunc: {
5766 computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known, Q,
5767 Depth);
5768 break;
5769 }
5770 case Instruction::SIToFP:
5771 case Instruction::UIToFP: {
5772 // Cannot produce nan
5773 Known.knownNot(RuleOut: fcNan);
5774
5775 // Integers cannot be subnormal
5776 Known.knownNot(RuleOut: fcSubnormal);
5777
5778 // sitofp and uitofp turn into +0.0 for zero.
5779 Known.knownNot(RuleOut: fcNegZero);
5780
5781 // UIToFP is always non-negative regardless of known bits.
5782 if (Op->getOpcode() == Instruction::UIToFP)
5783 Known.signBitMustBeZero();
5784
5785 // Only compute known bits if we can learn something useful from them.
5786 if (!(InterestedClasses & (fcPosZero | fcNormal | fcInf)))
5787 break;
5788
5789 KnownBits IntKnown =
5790 computeKnownBits(V: Op->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
5791
5792 // If the integer is non-zero, the result cannot be +0.0
5793 if (IntKnown.isNonZero())
5794 Known.knownNot(RuleOut: fcPosZero);
5795
5796 if (Op->getOpcode() == Instruction::SIToFP) {
5797 // If the signed integer is known non-negative, the result is
5798 // non-negative. If the signed integer is known negative, the result is
5799 // negative.
5800 if (IntKnown.isNonNegative()) {
5801 Known.signBitMustBeZero();
5802 } else if (IntKnown.isNegative()) {
5803 Known.signBitMustBeOne();
5804 }
5805 }
5806
5807 // Guard kept for ilogb()
5808 if (InterestedClasses & fcInf) {
5809 // Get width of largest magnitude integer known.
5810 // This still works for a signed minimum value because the largest FP
5811 // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx).
5812 int IntSize = IntKnown.getBitWidth();
5813 if (Op->getOpcode() == Instruction::UIToFP)
5814 IntSize -= IntKnown.countMinLeadingZeros();
5815 else if (Op->getOpcode() == Instruction::SIToFP)
5816 IntSize -= IntKnown.countMinSignBits();
5817
5818 // If the exponent of the largest finite FP value can hold the largest
5819 // integer, the result of the cast must be finite.
5820 Type *FPTy = Op->getType()->getScalarType();
5821 if (ilogb(Arg: APFloat::getLargest(Sem: FPTy->getFltSemantics())) >= IntSize)
5822 Known.knownNot(RuleOut: fcInf);
5823 }
5824
5825 break;
5826 }
5827 case Instruction::ExtractElement: {
5828 // Look through extract element. If the index is non-constant or
5829 // out-of-range demand all elements, otherwise just the extracted element.
5830 const Value *Vec = Op->getOperand(i: 0);
5831
5832 APInt DemandedVecElts;
5833 if (auto *VecTy = dyn_cast<FixedVectorType>(Val: Vec->getType())) {
5834 unsigned NumElts = VecTy->getNumElements();
5835 DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
5836 auto *CIdx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: 1));
5837 if (CIdx && CIdx->getValue().ult(RHS: NumElts))
5838 DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
5839 } else {
5840 DemandedVecElts = APInt(1, 1);
5841 }
5842
5843 return computeKnownFPClass(V: Vec, DemandedElts: DemandedVecElts, InterestedClasses, Known,
5844 Q, Depth: Depth + 1);
5845 }
5846 case Instruction::InsertElement: {
5847 if (isa<ScalableVectorType>(Val: Op->getType()))
5848 return;
5849
5850 const Value *Vec = Op->getOperand(i: 0);
5851 const Value *Elt = Op->getOperand(i: 1);
5852 auto *CIdx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: 2));
5853 unsigned NumElts = DemandedElts.getBitWidth();
5854 APInt DemandedVecElts = DemandedElts;
5855 bool NeedsElt = true;
5856 // If we know the index we are inserting to, clear it from Vec check.
5857 if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
5858 DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
5859 NeedsElt = DemandedElts[CIdx->getZExtValue()];
5860 }
5861
5862 // Do we demand the inserted element?
5863 if (NeedsElt) {
5864 computeKnownFPClass(V: Elt, Known, InterestedClasses, Q, Depth: Depth + 1);
5865 // If we don't know any bits, early out.
5866 if (Known.isUnknown())
5867 break;
5868 } else {
5869 Known.KnownFPClasses = fcNone;
5870 }
5871
5872 // Do we need anymore elements from Vec?
5873 if (!DemandedVecElts.isZero()) {
5874 KnownFPClass Known2;
5875 computeKnownFPClass(V: Vec, DemandedElts: DemandedVecElts, InterestedClasses, Known&: Known2, Q,
5876 Depth: Depth + 1);
5877 Known |= Known2;
5878 }
5879
5880 break;
5881 }
5882 case Instruction::ShuffleVector: {
5883 // Handle vector splat idiom
5884 if (Value *Splat = getSplatValue(V)) {
5885 computeKnownFPClass(V: Splat, Known, InterestedClasses, Q, Depth: Depth + 1);
5886 break;
5887 }
5888
5889 // For undef elements, we don't know anything about the common state of
5890 // the shuffle result.
5891 APInt DemandedLHS, DemandedRHS;
5892 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: Op);
5893 if (!Shuf || !getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
5894 return;
5895
5896 if (!!DemandedLHS) {
5897 const Value *LHS = Shuf->getOperand(i_nocapture: 0);
5898 computeKnownFPClass(V: LHS, DemandedElts: DemandedLHS, InterestedClasses, Known, Q,
5899 Depth: Depth + 1);
5900
5901 // If we don't know any bits, early out.
5902 if (Known.isUnknown())
5903 break;
5904 } else {
5905 Known.KnownFPClasses = fcNone;
5906 }
5907
5908 if (!!DemandedRHS) {
5909 KnownFPClass Known2;
5910 const Value *RHS = Shuf->getOperand(i_nocapture: 1);
5911 computeKnownFPClass(V: RHS, DemandedElts: DemandedRHS, InterestedClasses, Known&: Known2, Q,
5912 Depth: Depth + 1);
5913 Known |= Known2;
5914 }
5915
5916 break;
5917 }
5918 case Instruction::ExtractValue: {
5919 const ExtractValueInst *Extract = cast<ExtractValueInst>(Val: Op);
5920 ArrayRef<unsigned> Indices = Extract->getIndices();
5921 const Value *Src = Extract->getAggregateOperand();
5922 if (isa<StructType>(Val: Src->getType()) && Indices.size() == 1 &&
5923 Indices[0] == 0) {
5924 if (const auto *II = dyn_cast<IntrinsicInst>(Val: Src)) {
5925 switch (II->getIntrinsicID()) {
5926 case Intrinsic::frexp: {
5927 Known.knownNot(RuleOut: fcSubnormal);
5928
5929 KnownFPClass KnownSrc;
5930 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5931 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5932
5933 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5934 const fltSemantics &FltSem =
5935 Op->getType()->getScalarType()->getFltSemantics();
5936
5937 DenormalMode Mode =
5938 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5939 Known = KnownFPClass::frexp_mant(Src: KnownSrc, Mode);
5940 return;
5941 }
5942 default:
5943 break;
5944 }
5945 }
5946 }
5947
5948 computeKnownFPClass(V: Src, DemandedElts, InterestedClasses, Known, Q,
5949 Depth: Depth + 1);
5950 break;
5951 }
5952 case Instruction::PHI: {
5953 const PHINode *P = cast<PHINode>(Val: Op);
5954 // Unreachable blocks may have zero-operand PHI nodes.
5955 if (P->getNumIncomingValues() == 0)
5956 break;
5957
5958 // Otherwise take the unions of the known bit sets of the operands,
5959 // taking conservative care to avoid excessive recursion.
5960 const unsigned PhiRecursionLimit = MaxAnalysisRecursionDepth - 2;
5961
5962 if (Depth < PhiRecursionLimit) {
5963 // Skip if every incoming value references to ourself.
5964 if (isa_and_nonnull<UndefValue>(Val: P->hasConstantValue()))
5965 break;
5966
5967 bool First = true;
5968
5969 for (const Use &U : P->operands()) {
5970 Value *IncValue;
5971 Instruction *CxtI;
5972 breakSelfRecursivePHI(U: &U, PHI: P, ValOut&: IncValue, CtxIOut&: CxtI);
5973 // Skip direct self references.
5974 if (IncValue == P)
5975 continue;
5976
5977 KnownFPClass KnownSrc;
5978 // Recurse, but cap the recursion to two levels, because we don't want
5979 // to waste time spinning around in loops. We need at least depth 2 to
5980 // detect known sign bits.
5981 computeKnownFPClass(V: IncValue, DemandedElts, InterestedClasses, Known&: KnownSrc,
5982 Q: Q.getWithoutCondContext().getWithInstruction(I: CxtI),
5983 Depth: PhiRecursionLimit);
5984
5985 if (First) {
5986 Known = KnownSrc;
5987 First = false;
5988 } else {
5989 Known |= KnownSrc;
5990 }
5991
5992 if (Known.KnownFPClasses == fcAllFlags)
5993 break;
5994 }
5995 }
5996
5997 // Look for the case of a for loop which has a positive
5998 // initial value and is incremented by a squared value.
5999 // This will propagate sign information out of such loops.
6000 if (P->getNumIncomingValues() != 2 || Known.cannotBeOrderedLessThanZero())
6001 break;
6002 for (unsigned I = 0; I < 2; I++) {
6003 Value *RecurValue = P->getIncomingValue(i: 1 - I);
6004 IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: RecurValue);
6005 if (!II)
6006 continue;
6007 Value *R, *L, *Init;
6008 PHINode *PN;
6009 if (matchSimpleTernaryIntrinsicRecurrence(I: II, P&: PN, Init, OtherOp0&: L, OtherOp1&: R) &&
6010 PN == P) {
6011 switch (II->getIntrinsicID()) {
6012 case Intrinsic::fma:
6013 case Intrinsic::fmuladd: {
6014 KnownFPClass KnownStart;
6015 computeKnownFPClass(V: Init, DemandedElts, InterestedClasses, Known&: KnownStart,
6016 Q, Depth: Depth + 1);
6017 if (KnownStart.cannotBeOrderedLessThanZero() && L == R &&
6018 isGuaranteedNotToBeUndef(V: L, AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT, Depth: Depth + 1))
6019 Known.knownNot(RuleOut: KnownFPClass::OrderedLessThanZeroMask);
6020 break;
6021 }
6022 }
6023 }
6024 }
6025 break;
6026 }
6027 case Instruction::BitCast: {
6028 const Value *Src;
6029 if (!match(V: Op, P: m_ElementWiseBitCast(Op: m_Value(V&: Src))) ||
6030 !Src->getType()->isIntOrIntVectorTy())
6031 break;
6032
6033 const Type *Ty = Op->getType();
6034
6035 Value *CastLHS, *CastRHS;
6036
6037 // Match bitcast(umax(bitcast(a), bitcast(b)))
6038 if (match(V: Src, P: m_c_MaxOrMin(L: m_BitCast(Op: m_Value(V&: CastLHS)),
6039 R: m_BitCast(Op: m_Value(V&: CastRHS)))) &&
6040 CastLHS->getType() == Ty && CastRHS->getType() == Ty) {
6041 KnownFPClass KnownLHS, KnownRHS;
6042 computeKnownFPClass(V: CastRHS, DemandedElts, InterestedClasses, Known&: KnownRHS, Q,
6043 Depth: Depth + 1);
6044 if (!KnownRHS.isUnknown()) {
6045 computeKnownFPClass(V: CastLHS, DemandedElts, InterestedClasses, Known&: KnownLHS,
6046 Q, Depth: Depth + 1);
6047 Known = KnownLHS | KnownRHS;
6048 }
6049
6050 return;
6051 }
6052
6053 const Type *EltTy = Ty->getScalarType();
6054 KnownBits Bits(EltTy->getPrimitiveSizeInBits());
6055 computeKnownBits(V: Src, DemandedElts, Known&: Bits, Q, Depth: Depth + 1);
6056
6057 Known = KnownFPClass::bitcast(FltSemantics: EltTy->getFltSemantics(), Bits);
6058 break;
6059 }
6060 default:
6061 break;
6062 }
6063}
6064
6065KnownFPClass llvm::computeKnownFPClass(const Value *V,
6066 const APInt &DemandedElts,
6067 FPClassTest InterestedClasses,
6068 const SimplifyQuery &SQ,
6069 unsigned Depth) {
6070 KnownFPClass KnownClasses;
6071 ::computeKnownFPClass(V, DemandedElts, InterestedClasses, Known&: KnownClasses, Q: SQ,
6072 Depth);
6073 return KnownClasses;
6074}
6075
6076KnownFPClass llvm::computeKnownFPClass(const Value *V,
6077 FPClassTest InterestedClasses,
6078 const SimplifyQuery &SQ,
6079 unsigned Depth) {
6080 KnownFPClass Known;
6081 ::computeKnownFPClass(V, Known, InterestedClasses, Q: SQ, Depth);
6082 return Known;
6083}
6084
6085KnownFPClass llvm::computeKnownFPClass(
6086 const Value *V, const DataLayout &DL, FPClassTest InterestedClasses,
6087 const TargetLibraryInfo *TLI, AssumptionCache *AC, const Instruction *CxtI,
6088 const DominatorTree *DT, bool UseInstrInfo, unsigned Depth) {
6089 return computeKnownFPClass(V, InterestedClasses,
6090 SQ: SimplifyQuery(DL, TLI, DT, AC, CxtI, UseInstrInfo),
6091 Depth);
6092}
6093
6094KnownFPClass
6095llvm::computeKnownFPClass(const Value *V, const APInt &DemandedElts,
6096 FastMathFlags FMF, FPClassTest InterestedClasses,
6097 const SimplifyQuery &SQ, unsigned Depth) {
6098 if (FMF.noNaNs())
6099 InterestedClasses &= ~fcNan;
6100 if (FMF.noInfs())
6101 InterestedClasses &= ~fcInf;
6102
6103 KnownFPClass Result =
6104 computeKnownFPClass(V, DemandedElts, InterestedClasses, SQ, Depth);
6105
6106 if (FMF.noNaNs())
6107 Result.KnownFPClasses &= ~fcNan;
6108 if (FMF.noInfs())
6109 Result.KnownFPClasses &= ~fcInf;
6110 return Result;
6111}
6112
6113KnownFPClass llvm::computeKnownFPClass(const Value *V, FastMathFlags FMF,
6114 FPClassTest InterestedClasses,
6115 const SimplifyQuery &SQ,
6116 unsigned Depth) {
6117 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
6118 APInt DemandedElts =
6119 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
6120 return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, SQ,
6121 Depth);
6122}
6123
6124bool llvm::cannotBeNegativeZero(const Value *V, const SimplifyQuery &SQ,
6125 unsigned Depth) {
6126 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcNegZero, SQ, Depth);
6127 return Known.isKnownNeverNegZero();
6128}
6129
6130bool llvm::cannotBeOrderedLessThanZero(const Value *V, const SimplifyQuery &SQ,
6131 unsigned Depth) {
6132 KnownFPClass Known =
6133 computeKnownFPClass(V, InterestedClasses: KnownFPClass::OrderedLessThanZeroMask, SQ, Depth);
6134 return Known.cannotBeOrderedLessThanZero();
6135}
6136
6137bool llvm::isKnownNeverInfinity(const Value *V, const SimplifyQuery &SQ,
6138 unsigned Depth) {
6139 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcInf, SQ, Depth);
6140 return Known.isKnownNeverInfinity();
6141}
6142
6143/// Return true if the floating-point value can never contain a NaN or infinity.
6144bool llvm::isKnownNeverInfOrNaN(const Value *V, const SimplifyQuery &SQ,
6145 unsigned Depth) {
6146 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcInf | fcNan, SQ, Depth);
6147 return Known.isKnownNeverNaN() && Known.isKnownNeverInfinity();
6148}
6149
6150/// Return true if the floating-point scalar value is not a NaN or if the
6151/// floating-point vector value has no NaN elements. Return false if a value
6152/// could ever be NaN.
6153bool llvm::isKnownNeverNaN(const Value *V, const SimplifyQuery &SQ,
6154 unsigned Depth) {
6155 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcNan, SQ, Depth);
6156 return Known.isKnownNeverNaN();
6157}
6158
6159/// Return false if we can prove that the specified FP value's sign bit is 0.
6160/// Return true if we can prove that the specified FP value's sign bit is 1.
6161/// Otherwise return std::nullopt.
6162std::optional<bool> llvm::computeKnownFPSignBit(const Value *V,
6163 const SimplifyQuery &SQ,
6164 unsigned Depth) {
6165 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcAllFlags, SQ, Depth);
6166 return Known.SignBit;
6167}
6168
6169bool llvm::canIgnoreSignBitOfZero(const Use &U) {
6170 auto *User = cast<Instruction>(Val: U.getUser());
6171 if (auto *FPOp = dyn_cast<FPMathOperator>(Val: User)) {
6172 if (FPOp->hasNoSignedZeros())
6173 return true;
6174 }
6175
6176 switch (User->getOpcode()) {
6177 case Instruction::FPToSI:
6178 case Instruction::FPToUI:
6179 return true;
6180 case Instruction::FCmp:
6181 // fcmp treats both positive and negative zero as equal.
6182 return true;
6183 case Instruction::Call:
6184 if (auto *II = dyn_cast<IntrinsicInst>(Val: User)) {
6185 switch (II->getIntrinsicID()) {
6186 case Intrinsic::fabs:
6187 return true;
6188 case Intrinsic::copysign:
6189 return U.getOperandNo() == 0;
6190 case Intrinsic::is_fpclass:
6191 case Intrinsic::vp_is_fpclass: {
6192 auto Test =
6193 static_cast<FPClassTest>(
6194 cast<ConstantInt>(Val: II->getArgOperand(i: 1))->getZExtValue()) &
6195 FPClassTest::fcZero;
6196 return Test == FPClassTest::fcZero || Test == FPClassTest::fcNone;
6197 }
6198 default:
6199 return false;
6200 }
6201 }
6202 return false;
6203 default:
6204 return false;
6205 }
6206}
6207
6208bool llvm::canIgnoreSignBitOfNaN(const Use &U) {
6209 auto *User = cast<Instruction>(Val: U.getUser());
6210 if (auto *FPOp = dyn_cast<FPMathOperator>(Val: User)) {
6211 if (FPOp->hasNoNaNs())
6212 return true;
6213 }
6214
6215 switch (User->getOpcode()) {
6216 case Instruction::FPToSI:
6217 case Instruction::FPToUI:
6218 return true;
6219 // Proper FP math operations ignore the sign bit of NaN.
6220 case Instruction::FAdd:
6221 case Instruction::FSub:
6222 case Instruction::FMul:
6223 case Instruction::FDiv:
6224 case Instruction::FRem:
6225 case Instruction::FPTrunc:
6226 case Instruction::FPExt:
6227 case Instruction::FCmp:
6228 return true;
6229 // Bitwise FP operations should preserve the sign bit of NaN.
6230 case Instruction::FNeg:
6231 case Instruction::Select:
6232 case Instruction::PHI:
6233 return false;
6234 case Instruction::Ret:
6235 return User->getFunction()->getAttributes().getRetNoFPClass() &
6236 FPClassTest::fcNan;
6237 case Instruction::Call:
6238 case Instruction::Invoke: {
6239 if (auto *II = dyn_cast<IntrinsicInst>(Val: User)) {
6240 switch (II->getIntrinsicID()) {
6241 case Intrinsic::fabs:
6242 return true;
6243 case Intrinsic::copysign:
6244 return U.getOperandNo() == 0;
6245 // Other proper FP math intrinsics ignore the sign bit of NaN.
6246 case Intrinsic::maxnum:
6247 case Intrinsic::minnum:
6248 case Intrinsic::maximum:
6249 case Intrinsic::minimum:
6250 case Intrinsic::maximumnum:
6251 case Intrinsic::minimumnum:
6252 case Intrinsic::canonicalize:
6253 case Intrinsic::fma:
6254 case Intrinsic::fmuladd:
6255 case Intrinsic::sqrt:
6256 case Intrinsic::pow:
6257 case Intrinsic::powi:
6258 case Intrinsic::fptoui_sat:
6259 case Intrinsic::fptosi_sat:
6260 case Intrinsic::is_fpclass:
6261 case Intrinsic::vp_is_fpclass:
6262 return true;
6263 default:
6264 return false;
6265 }
6266 }
6267
6268 FPClassTest NoFPClass =
6269 cast<CallBase>(Val: User)->getParamNoFPClass(i: U.getOperandNo());
6270 return NoFPClass & FPClassTest::fcNan;
6271 }
6272 default:
6273 return false;
6274 }
6275}
6276
6277bool llvm::isKnownIntegral(const Value *V, const SimplifyQuery &SQ,
6278 FastMathFlags FMF) {
6279 if (isa<PoisonValue>(Val: V))
6280 return true;
6281 if (isa<UndefValue>(Val: V))
6282 return false;
6283
6284 if (match(V, P: m_CheckedFp(CheckFn: [](const APFloat &Val) { return Val.isInteger(); })))
6285 return true;
6286
6287 const Instruction *I = dyn_cast<Instruction>(Val: V);
6288 if (!I)
6289 return false;
6290
6291 switch (I->getOpcode()) {
6292 case Instruction::SIToFP:
6293 case Instruction::UIToFP:
6294 // TODO: Could check nofpclass(inf) on incoming argument
6295 if (FMF.noInfs())
6296 return true;
6297
6298 // Need to check int size cannot produce infinity, which computeKnownFPClass
6299 // knows how to do already.
6300 return isKnownNeverInfinity(V: I, SQ);
6301 case Instruction::Call: {
6302 const CallInst *CI = cast<CallInst>(Val: I);
6303 switch (CI->getIntrinsicID()) {
6304 case Intrinsic::trunc:
6305 case Intrinsic::floor:
6306 case Intrinsic::ceil:
6307 case Intrinsic::rint:
6308 case Intrinsic::nearbyint:
6309 case Intrinsic::round:
6310 case Intrinsic::roundeven:
6311 return (FMF.noInfs() && FMF.noNaNs()) || isKnownNeverInfOrNaN(V: I, SQ);
6312 default:
6313 break;
6314 }
6315
6316 break;
6317 }
6318 default:
6319 break;
6320 }
6321
6322 return false;
6323}
6324
// If V can be materialized by repeating a single byte throughout memory (as a
// memset would), return that byte as an i8 Value; otherwise return nullptr.
// An i8 UndefValue result means "any byte works". The checks below are
// order-sensitive: undef and zero-size must be handled before the generic
// constant cases.
Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {

  // All byte-wide stores are splatable, even of arbitrary variables.
  if (V->getType()->isIntegerTy(Bitwidth: 8))
    return V;

  LLVMContext &Ctx = V->getContext();

  // Undef don't care.
  auto *UndefInt8 = UndefValue::get(T: Type::getInt8Ty(C&: Ctx));
  if (isa<UndefValue>(Val: V))
    return UndefInt8;

  // Return poison for zero-sized type.
  if (DL.getTypeStoreSize(Ty: V->getType()).isZero())
    return PoisonValue::get(T: Type::getInt8Ty(C&: Ctx));

  // Beyond this point only constants can be analyzed.
  Constant *C = dyn_cast<Constant>(Val: V);
  if (!C) {
    // Conceptually, we could handle things like:
    // %a = zext i8 %X to i16
    // %b = shl i16 %a, 8
    // %c = or i16 %a, %b
    // but until there is an example that actually needs this, it doesn't seem
    // worth worrying about.
    return nullptr;
  }

  // Handle 'null' ConstantArrayZero etc.
  if (C->isNullValue())
    return Constant::getNullValue(Ty: Type::getInt8Ty(C&: Ctx));

  // Constant floating-point values can be handled as integer values if the
  // corresponding integer value is "byteable". An important case is 0.0.
  if (ConstantFP *CFP = dyn_cast<ConstantFP>(Val: C)) {
    Type *ScalarTy = CFP->getType()->getScalarType();
    if (ScalarTy->isHalfTy() || ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
      return isBytewiseValue(
          V: ConstantInt::get(Context&: Ctx, V: CFP->getValue().bitcastToAPInt()), DL);

    // Don't handle long double formats, which have strange constraints.
    return nullptr;
  }

  // We can handle constant integers that are multiple of 8 bits.
  if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: C)) {
    if (CI->getBitWidth() % 8 == 0) {
      // All bytes of the integer must be identical for it to be a splat.
      if (!CI->getValue().isSplat(SplatSizeInBits: 8))
        return nullptr;
      return ConstantInt::get(Context&: Ctx, V: CI->getValue().trunc(width: 8));
    }
  }

  // Look through a constant inttoptr by analyzing the source integer,
  // converted to pointer width.
  if (auto *CE = dyn_cast<ConstantExpr>(Val: C)) {
    if (CE->getOpcode() == Instruction::IntToPtr) {
      if (auto *PtrTy = dyn_cast<PointerType>(Val: CE->getType())) {
        unsigned BitWidth = DL.getPointerSizeInBits(AS: PtrTy->getAddressSpace());
        if (Constant *Op = ConstantFoldIntegerCast(
                C: CE->getOperand(i_nocapture: 0), DestTy: Type::getIntNTy(C&: Ctx, N: BitWidth), IsSigned: false, DL))
          return isBytewiseValue(V: Op, DL);
      }
    }
  }

  // Combine the byte patterns of two sub-values: undef defers to the other
  // side; otherwise both must agree on the exact same byte.
  auto Merge = [&](Value *LHS, Value *RHS) -> Value * {
    if (LHS == RHS)
      return LHS;
    if (!LHS || !RHS)
      return nullptr;
    if (LHS == UndefInt8)
      return RHS;
    if (RHS == UndefInt8)
      return LHS;
    return nullptr;
  };

  // For sequential constant data, every element must merge to one byte.
  if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(Val: C)) {
    Value *Val = UndefInt8;
    for (uint64_t I = 0, E = CA->getNumElements(); I != E; ++I)
      if (!(Val = Merge(Val, isBytewiseValue(V: CA->getElementAsConstant(i: I), DL))))
        return nullptr;
    return Val;
  }

  // Likewise for aggregates built from arbitrary constant operands.
  if (isa<ConstantAggregate>(Val: C)) {
    Value *Val = UndefInt8;
    for (Value *Op : C->operands())
      if (!(Val = Merge(Val, isBytewiseValue(V: Op, DL))))
        return nullptr;
    return Val;
  }

  // Don't try to handle the handful of other constants.
  return nullptr;
}
6420
// This is the recursive version of BuildSubAggregate. It takes a few different
// arguments. Idxs is the index within the nested struct From that we are
// looking at now (which is of type IndexedType). IdxSkip is the number of
// indices from Idxs that should be left out when inserting into the resulting
// struct. To is the result struct built so far, new insertvalue instructions
// build on that.
// Returns the (possibly extended) result aggregate, or nullptr if some
// required scalar could not be found; on failure, any insertvalue
// instructions created for the current element are erased again.
static Value *BuildSubAggregate(Value *From, Value *To, Type *IndexedType,
                                SmallVectorImpl<unsigned> &Idxs,
                                unsigned IdxSkip,
                                BasicBlock::iterator InsertBefore) {
  StructType *STy = dyn_cast<StructType>(Val: IndexedType);
  if (STy) {
    // Save the original To argument so we can modify it
    Value *OrigTo = To;
    // General case, the type indexed by Idxs is a struct
    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
      // Process each struct element recursively
      Idxs.push_back(Elt: i);
      Value *PrevTo = To;
      To = BuildSubAggregate(From, To, IndexedType: STy->getElementType(N: i), Idxs, IdxSkip,
                             InsertBefore);
      Idxs.pop_back();
      if (!To) {
        // Couldn't find any inserted value for this index? Cleanup
        // Walk the chain of insertvalues created for this element back to
        // OrigTo and erase them, so no dead instructions are left behind.
        while (PrevTo != OrigTo) {
          InsertValueInst* Del = cast<InsertValueInst>(Val: PrevTo);
          PrevTo = Del->getAggregateOperand();
          Del->eraseFromParent();
        }
        // Stop processing elements
        break;
      }
    }
    // If we successfully found a value for each of our subaggregates
    if (To)
      return To;
  }
  // Base case, the type indexed by SourceIdxs is not a struct, or not all of
  // the struct's elements had a value that was inserted directly. In the latter
  // case, perhaps we can't determine each of the subelements individually, but
  // we might be able to find the complete struct somewhere.

  // Find the value that is at that particular spot
  Value *V = FindInsertedValue(V: From, idx_range: Idxs);

  if (!V)
    return nullptr;

  // Insert the value in the new (sub) aggregate
  // The first IdxSkip indices address the subaggregate being built, so they
  // are dropped when inserting into To.
  return InsertValueInst::Create(Agg: To, Val: V, Idxs: ArrayRef(Idxs).slice(N: IdxSkip), NameStr: "tmp",
                                 InsertBefore);
}
6473
6474// This helper takes a nested struct and extracts a part of it (which is again a
6475// struct) into a new value. For example, given the struct:
6476// { a, { b, { c, d }, e } }
6477// and the indices "1, 1" this returns
6478// { c, d }.
6479//
6480// It does this by inserting an insertvalue for each element in the resulting
6481// struct, as opposed to just inserting a single struct. This will only work if
6482// each of the elements of the substruct are known (ie, inserted into From by an
6483// insertvalue instruction somewhere).
6484//
6485// All inserted insertvalue instructions are inserted before InsertBefore
6486static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range,
6487 BasicBlock::iterator InsertBefore) {
6488 Type *IndexedType = ExtractValueInst::getIndexedType(Agg: From->getType(),
6489 Idxs: idx_range);
6490 Value *To = PoisonValue::get(T: IndexedType);
6491 SmallVector<unsigned, 10> Idxs(idx_range);
6492 unsigned IdxSkip = Idxs.size();
6493
6494 return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore);
6495}
6496
/// Given an aggregate and a sequence of indices, see if the scalar value
/// indexed is already around as a register, for example if it was inserted
/// directly into the aggregate.
///
/// If InsertBefore is not null, this function will duplicate (modified)
/// insertvalues when a part of a nested struct is extracted.
///
/// Returns the scalar/subaggregate value, or nullptr if it cannot be found
/// (e.g. the aggregate comes from a load or a call).
Value *
llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range,
                        std::optional<BasicBlock::iterator> InsertBefore) {
  // Nothing to index? Just return V then (this is useful at the end of our
  // recursion).
  if (idx_range.empty())
    return V;
  // We have indices, so V should have an indexable type.
  assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) &&
         "Not looking at a struct or array?");
  assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) &&
         "Invalid indices for type?");

  // Constants: peel one index per recursive step via getAggregateElement.
  if (Constant *C = dyn_cast<Constant>(Val: V)) {
    C = C->getAggregateElement(Elt: idx_range[0]);
    if (!C) return nullptr;
    return FindInsertedValue(V: C, idx_range: idx_range.slice(N: 1), InsertBefore);
  }

  if (InsertValueInst *I = dyn_cast<InsertValueInst>(Val: V)) {
    // Loop the indices for the insertvalue instruction in parallel with the
    // requested indices
    const unsigned *req_idx = idx_range.begin();
    for (const unsigned *i = I->idx_begin(), *e = I->idx_end();
         i != e; ++i, ++req_idx) {
      if (req_idx == idx_range.end()) {
        // We can't handle this without inserting insertvalues
        if (!InsertBefore)
          return nullptr;

        // The requested index identifies a part of a nested aggregate. Handle
        // this specially. For example,
        // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0
        // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1
        // %C = extractvalue {i32, { i32, i32 } } %B, 1
        // This can be changed into
        // %A = insertvalue {i32, i32 } undef, i32 10, 0
        // %C = insertvalue {i32, i32 } %A, i32 11, 1
        // which allows the unused 0,0 element from the nested struct to be
        // removed.
        return BuildSubAggregate(From: V, idx_range: ArrayRef(idx_range.begin(), req_idx),
                                 InsertBefore: *InsertBefore);
      }

      // This insert value inserts something else than what we are looking for.
      // See if the (aggregate) value inserted into has the value we are
      // looking for, then.
      if (*req_idx != *i)
        return FindInsertedValue(V: I->getAggregateOperand(), idx_range,
                                 InsertBefore);
    }
    // If we end up here, the indices of the insertvalue match with those
    // requested (though possibly only partially). Now we recursively look at
    // the inserted value, passing any remaining indices.
    return FindInsertedValue(V: I->getInsertedValueOperand(),
                             idx_range: ArrayRef(req_idx, idx_range.end()), InsertBefore);
  }

  if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(Val: V)) {
    // If we're extracting a value from an aggregate that was extracted from
    // something else, we can extract from that something else directly instead.
    // However, we will need to chain I's indices with the requested indices.

    // Calculate the number of indices required
    unsigned size = I->getNumIndices() + idx_range.size();
    // Allocate some space to put the new indices in
    SmallVector<unsigned, 5> Idxs;
    Idxs.reserve(N: size);
    // Add indices from the extract value instruction
    Idxs.append(in_start: I->idx_begin(), in_end: I->idx_end());

    // Add requested indices
    Idxs.append(in_start: idx_range.begin(), in_end: idx_range.end());

    assert(Idxs.size() == size
           && "Number of indices added not correct?");

    return FindInsertedValue(V: I->getAggregateOperand(), idx_range: Idxs, InsertBefore);
  }
  // Otherwise, we don't know (such as, extracting from a function return value
  // or load instruction)
  return nullptr;
}
6586
6587// If V refers to an initialized global constant, set Slice either to
6588// its initializer if the size of its elements equals ElementSize, or,
6589// for ElementSize == 8, to its representation as an array of unsiged
6590// char. Return true on success.
6591// Offset is in the unit "nr of ElementSize sized elements".
bool llvm::getConstantDataArrayInfo(const Value *V,
                                    ConstantDataArraySlice &Slice,
                                    unsigned ElementSize, uint64_t Offset) {
  assert(V && "V should not be null.");
  assert((ElementSize % 8) == 0 &&
         "ElementSize expected to be a multiple of the size of a byte.");
  unsigned ElementSizeInBytes = ElementSize / 8;

  // Drill down into the pointer expression V, ignoring any intervening
  // casts, and determine the identity of the object it references along
  // with the cumulative byte offset into it.
  const GlobalVariable *GV =
      dyn_cast<GlobalVariable>(Val: getUnderlyingObject(V));
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    // Fail if V is not based on constant global object.
    return false;

  const DataLayout &DL = GV->getDataLayout();
  APInt Off(DL.getIndexTypeSizeInBits(Ty: V->getType()), 0);

  if (GV != V->stripAndAccumulateConstantOffsets(DL, Offset&: Off,
                                                 /*AllowNonInbounds*/ true))
    // Fail if a constant offset could not be determined.
    return false;

  // getLimitedValue() saturates to UINT64_MAX, so that value doubles as the
  // "offset does not fit in 64 bits" failure indicator here.
  uint64_t StartIdx = Off.getLimitedValue();
  if (StartIdx == UINT64_MAX)
    // Fail if the constant offset is excessive.
    return false;

  // Off/StartIdx is in the unit of bytes. So we need to convert to number of
  // elements. Simply bail out if that isn't possible.
  if ((StartIdx % ElementSizeInBytes) != 0)
    return false;

  Offset += StartIdx / ElementSizeInBytes;
  ConstantDataArray *Array = nullptr;
  ArrayType *ArrayTy = nullptr;

  // An all-zero initializer has no ConstantDataArray representation; encode
  // it as a null Slice.Array with just a length, computed from the global's
  // store size.
  if (GV->getInitializer()->isNullValue()) {
    Type *GVTy = GV->getValueType();
    uint64_t SizeInBytes = DL.getTypeStoreSize(Ty: GVTy).getFixedValue();
    uint64_t Length = SizeInBytes / ElementSizeInBytes;

    Slice.Array = nullptr;
    Slice.Offset = 0;
    // Return an empty Slice for undersized constants to let callers
    // transform even undefined library calls into simpler, well-defined
    // expressions. This is preferable to making the calls although it
    // prevents sanitizers from detecting such calls.
    Slice.Length = Length < Offset ? 0 : Length - Offset;
    return true;
  }

  auto *Init = const_cast<Constant *>(GV->getInitializer());
  if (auto *ArrayInit = dyn_cast<ConstantDataArray>(Val: Init)) {
    Type *InitElTy = ArrayInit->getElementType();
    if (InitElTy->isIntegerTy(Bitwidth: ElementSize)) {
      // If Init is an initializer for an array of the expected type
      // and size, use it as is.
      Array = ArrayInit;
      ArrayTy = ArrayInit->getType();
    }
  }

  if (!Array) {
    if (ElementSize != 8)
      // TODO: Handle conversions to larger integral types.
      return false;

    // Otherwise extract the portion of the initializer starting
    // at Offset as an array of bytes, and reset Offset.
    Init = ReadByteArrayFromGlobal(GV, Offset);
    if (!Init)
      return false;

    Offset = 0;
    Array = dyn_cast<ConstantDataArray>(Val: Init);
    ArrayTy = dyn_cast<ArrayType>(Val: Init->getType());
  }

  uint64_t NumElts = ArrayTy->getArrayNumElements();
  if (Offset > NumElts)
    return false;

  // Success: hand back the array plus the element offset/length window.
  Slice.Array = Array;
  Slice.Offset = Offset;
  Slice.Length = NumElts - Offset;
  return true;
}
6682
/// Extract bytes from the initializer of the constant array V, which need
/// not be a nul-terminated string. On success, store the bytes in Str and
/// return true. When TrimAtNul is set, Str will contain only the bytes up
/// to but not including the first nul. Return false on failure.
bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
                                 bool TrimAtNul) {
  ConstantDataArraySlice Slice;
  if (!getConstantDataArrayInfo(V, Slice, ElementSize: 8))
    return false;

  // A null Slice.Array means the global is zeroinitialized (see
  // getConstantDataArrayInfo); the slice is Slice.Length zero bytes.
  if (Slice.Array == nullptr) {
    if (TrimAtNul) {
      // Return a nul-terminated string even for an empty Slice. This is
      // safe because all existing SimplifyLibcalls callers require string
      // arguments and the behavior of the functions they fold is undefined
      // otherwise. Folding the calls this way is preferable to making
      // the undefined library calls, even though it prevents sanitizers
      // from reporting such calls.
      Str = StringRef();
      return true;
    }
    if (Slice.Length == 1) {
      // A single zero byte: materialize it from a string literal so Str
      // points at valid storage.
      Str = StringRef("", 1);
      return true;
    }
    // We cannot instantiate a StringRef as we do not have an appropriate string
    // of 0s at hand.
    return false;
  }

  // Start out with the entire array in the StringRef.
  Str = Slice.Array->getAsString();
  // Skip over 'offset' bytes.
  Str = Str.substr(Start: Slice.Offset);

  if (TrimAtNul) {
    // Trim off the \0 and anything after it. If the array is not nul
    // terminated, we just return the whole end of string. The client may know
    // some other way that the string is length-bound.
    Str = Str.substr(Start: 0, N: Str.find(C: '\0'));
  }
  return true;
}
6726
6727// These next two are very similar to the above, but also look through PHI
6728// nodes.
6729// TODO: See if we can integrate these two together.
6730
/// If we can compute the length of the string pointed to by
/// the specified pointer, return 'len+1'. If we can't, return 0.
/// A result of ~0ULL is an internal sentinel meaning "no constraint yet":
/// it is produced when recursion re-enters a phi already being visited.
static uint64_t GetStringLengthH(const Value *V,
                                 SmallPtrSetImpl<const PHINode*> &PHIs,
                                 unsigned CharSize) {
  // Look through noop bitcast instructions.
  V = V->stripPointerCasts();

  // If this is a PHI node, there are two cases: either we have already seen it
  // or we haven't.
  if (const PHINode *PN = dyn_cast<PHINode>(Val: V)) {
    if (!PHIs.insert(Ptr: PN).second)
      return ~0ULL; // already in the set.

    // If it was new, see if all the input strings are the same length.
    uint64_t LenSoFar = ~0ULL;
    for (Value *IncValue : PN->incoming_values()) {
      uint64_t Len = GetStringLengthH(V: IncValue, PHIs, CharSize);
      if (Len == 0) return 0; // Unknown length -> unknown.

      // ~0ULL (a self-referential edge) imposes no constraint.
      if (Len == ~0ULL) continue;

      if (Len != LenSoFar && LenSoFar != ~0ULL)
        return 0; // Disagree -> unknown.
      LenSoFar = Len;
    }

    // Success, all agree.
    return LenSoFar;
  }

  // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y)
  if (const SelectInst *SI = dyn_cast<SelectInst>(Val: V)) {
    uint64_t Len1 = GetStringLengthH(V: SI->getTrueValue(), PHIs, CharSize);
    if (Len1 == 0) return 0;
    uint64_t Len2 = GetStringLengthH(V: SI->getFalseValue(), PHIs, CharSize);
    if (Len2 == 0) return 0;
    // Either arm being unconstrained (~0ULL) defers to the other arm.
    if (Len1 == ~0ULL) return Len2;
    if (Len2 == ~0ULL) return Len1;
    if (Len1 != Len2) return 0;
    return Len1;
  }

  // Otherwise, see if we can read the string.
  ConstantDataArraySlice Slice;
  if (!getConstantDataArrayInfo(V, Slice, ElementSize: CharSize))
    return 0;

  if (Slice.Array == nullptr)
    // Zeroinitializer (including an empty one).
    return 1;

  // Search for the first nul character. Return a conservative result even
  // when there is no nul. This is safe since otherwise the string function
  // being folded such as strlen is undefined, and can be preferable to
  // making the undefined library call.
  unsigned NullIndex = 0;
  for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
    if (Slice.Array->getElementAsInteger(i: Slice.Offset + NullIndex) == 0)
      break;
  }

  return NullIndex + 1;
}
6795
6796/// If we can compute the length of the string pointed to by
6797/// the specified pointer, return 'len+1'. If we can't, return 0.
6798uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
6799 if (!V->getType()->isPointerTy())
6800 return 0;
6801
6802 SmallPtrSet<const PHINode*, 32> PHIs;
6803 uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
6804 // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return
6805 // an empty string as a length.
6806 return Len == ~0ULL ? 1 : Len;
6807}
6808
6809const Value *
6810llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call,
6811 bool MustPreserveNullness) {
6812 assert(Call &&
6813 "getArgumentAliasingToReturnedPointer only works on nonnull calls");
6814 if (const Value *RV = Call->getReturnedArgOperand())
6815 return RV;
6816 // This can be used only as a aliasing property.
6817 if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6818 Call, MustPreserveNullness))
6819 return Call->getArgOperand(i: 0);
6820 return nullptr;
6821}
6822
6823bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6824 const CallBase *Call, bool MustPreserveNullness) {
6825 switch (Call->getIntrinsicID()) {
6826 case Intrinsic::launder_invariant_group:
6827 case Intrinsic::strip_invariant_group:
6828 case Intrinsic::aarch64_irg:
6829 case Intrinsic::aarch64_tagp:
6830 // The amdgcn_make_buffer_rsrc function does not alter the address of the
6831 // input pointer (and thus preserve null-ness for the purposes of escape
6832 // analysis, which is where the MustPreserveNullness flag comes in to play).
6833 // However, it will not necessarily map ptr addrspace(N) null to ptr
6834 // addrspace(8) null, aka the "null descriptor", which has "all loads return
6835 // 0, all stores are dropped" semantics. Given the context of this intrinsic
6836 // list, no one should be relying on such a strict interpretation of
6837 // MustPreserveNullness (and, at time of writing, they are not), but we
6838 // document this fact out of an abundance of caution.
6839 case Intrinsic::amdgcn_make_buffer_rsrc:
6840 return true;
6841 case Intrinsic::ptrmask:
6842 return !MustPreserveNullness;
6843 case Intrinsic::threadlocal_address:
6844 // The underlying variable changes with thread ID. The Thread ID may change
6845 // at coroutine suspend points.
6846 return !Call->getParent()->getParent()->isPresplitCoroutine();
6847 default:
6848 return false;
6849 }
6850}
6851
6852/// \p PN defines a loop-variant pointer to an object. Check if the
6853/// previous iteration of the loop was referring to the same object as \p PN.
6854static bool isSameUnderlyingObjectInLoop(const PHINode *PN,
6855 const LoopInfo *LI) {
6856 // Find the loop-defined value.
6857 Loop *L = LI->getLoopFor(BB: PN->getParent());
6858 if (PN->getNumIncomingValues() != 2)
6859 return true;
6860
6861 // Find the value from previous iteration.
6862 auto *PrevValue = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0));
6863 if (!PrevValue || LI->getLoopFor(BB: PrevValue->getParent()) != L)
6864 PrevValue = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1));
6865 if (!PrevValue || LI->getLoopFor(BB: PrevValue->getParent()) != L)
6866 return true;
6867
6868 // If a new pointer is loaded in the loop, the pointer references a different
6869 // object in every iteration. E.g.:
6870 // for (i)
6871 // int *p = a[i];
6872 // ...
6873 if (auto *Load = dyn_cast<LoadInst>(Val: PrevValue))
6874 if (!L->isLoopInvariant(V: Load->getPointerOperand()))
6875 return false;
6876 return true;
6877}
6878
// Strip GEPs, pointer casts, non-interposable aliases, LCSSA phis, and
// returned-argument calls to find the base object V is derived from.
// A MaxLookup of 0 means no limit on the number of stripping steps.
const Value *llvm::getUnderlyingObject(const Value *V, unsigned MaxLookup) {
  for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
    if (auto *GEP = dyn_cast<GEPOperator>(Val: V)) {
      const Value *PtrOp = GEP->getPointerOperand();
      if (!PtrOp->getType()->isPointerTy()) // Only handle scalar pointer base.
        return V;
      V = PtrOp;
    } else if (Operator::getOpcode(V) == Instruction::BitCast ||
               Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
      Value *NewV = cast<Operator>(Val: V)->getOperand(i: 0);
      if (!NewV->getType()->isPointerTy())
        return V;
      V = NewV;
    } else if (auto *GA = dyn_cast<GlobalAlias>(Val: V)) {
      // Don't look through interposable aliases.
      if (GA->isInterposable())
        return V;
      V = GA->getAliasee();
    } else {
      if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
        // Look through single-arg phi nodes created by LCSSA.
        if (PHI->getNumIncomingValues() == 1) {
          V = PHI->getIncomingValue(i: 0);
          continue;
        }
      } else if (auto *Call = dyn_cast<CallBase>(Val: V)) {
        // CaptureTracking can know about special capturing properties of some
        // intrinsics like launder.invariant.group, that can't be expressed with
        // the attributes, but have properties like returning aliasing pointer.
        // Because some analysis may assume that nocaptured pointer is not
        // returned from some special intrinsic (because function would have to
        // be marked with returns attribute), it is crucial to use this function
        // because it should be in sync with CaptureTracking. Not using it may
        // cause weird miscompilations where 2 aliasing pointers are assumed to
        // noalias.
        if (auto *RP = getArgumentAliasingToReturnedPointer(Call, MustPreserveNullness: false)) {
          V = RP;
          continue;
        }
      }

      // Nothing left to strip.
      return V;
    }
    assert(V->getType()->isPointerTy() && "Unexpected operand type!");
  }
  return V;
}
6925
// Collect into Objects every underlying object V may be based on, fanning
// out through selects and (loop-safe) phis via a worklist.
void llvm::getUnderlyingObjects(const Value *V,
                                SmallVectorImpl<const Value *> &Objects,
                                const LoopInfo *LI, unsigned MaxLookup) {
  SmallPtrSet<const Value *, 4> Visited;
  SmallVector<const Value *, 4> Worklist;
  Worklist.push_back(Elt: V);
  do {
    const Value *P = Worklist.pop_back_val();
    // Strip casts/GEPs/etc. first; only selects and phis need the worklist.
    P = getUnderlyingObject(V: P, MaxLookup);

    if (!Visited.insert(Ptr: P).second)
      continue;

    if (auto *SI = dyn_cast<SelectInst>(Val: P)) {
      Worklist.push_back(Elt: SI->getTrueValue());
      Worklist.push_back(Elt: SI->getFalseValue());
      continue;
    }

    if (auto *PN = dyn_cast<PHINode>(Val: P)) {
      // If this PHI changes the underlying object in every iteration of the
      // loop, don't look through it. Consider:
      //   int **A;
      //   for (i) {
      //     Prev = Curr;     // Prev = PHI (Prev_0, Curr)
      //     Curr = A[i];
      //     *Prev, *Curr;
      //
      // Prev is tracking Curr one iteration behind so they refer to different
      // underlying objects.
      if (!LI || !LI->isLoopHeader(BB: PN->getParent()) ||
          isSameUnderlyingObjectInLoop(PN, LI))
        append_range(C&: Worklist, R: PN->incoming_values());
      else
        Objects.push_back(Elt: P);
      continue;
    }

    Objects.push_back(Elt: P);
  } while (!Worklist.empty());
}
6967
// Like getUnderlyingObject, but also looks through multi-input phis and
// selects when they all resolve to one common object. Falls back to the
// plain getUnderlyingObject(V) result whenever the search is inconclusive.
const Value *llvm::getUnderlyingObjectAggressive(const Value *V) {
  // Cap on distinct values explored before giving up on the search.
  const unsigned MaxVisited = 8;

  SmallPtrSet<const Value *, 8> Visited;
  SmallVector<const Value *, 8> Worklist;
  Worklist.push_back(Elt: V);
  const Value *Object = nullptr;
  // Used as fallback if we can't find a common underlying object through
  // recursion.
  bool First = true;
  const Value *FirstObject = getUnderlyingObject(V);
  do {
    const Value *P = Worklist.pop_back_val();
    // The first pop is V itself, whose stripped form was precomputed above.
    P = First ? FirstObject : getUnderlyingObject(V: P);
    First = false;

    if (!Visited.insert(Ptr: P).second)
      continue;

    if (Visited.size() == MaxVisited)
      return FirstObject;

    if (auto *SI = dyn_cast<SelectInst>(Val: P)) {
      Worklist.push_back(Elt: SI->getTrueValue());
      Worklist.push_back(Elt: SI->getFalseValue());
      continue;
    }

    if (auto *PN = dyn_cast<PHINode>(Val: P)) {
      append_range(C&: Worklist, R: PN->incoming_values());
      continue;
    }

    // Every non-phi/select leaf must agree on a single object.
    if (!Object)
      Object = P;
    else if (Object != P)
      return FirstObject;
  } while (!Worklist.empty());

  // Object is null only if every visited value was a phi or select.
  return Object ? Object : FirstObject;
}
7009
7010/// This is the function that does the work of looking through basic
7011/// ptrtoint+arithmetic+inttoptr sequences.
7012static const Value *getUnderlyingObjectFromInt(const Value *V) {
7013 do {
7014 if (const Operator *U = dyn_cast<Operator>(Val: V)) {
7015 // If we find a ptrtoint, we can transfer control back to the
7016 // regular getUnderlyingObjectFromInt.
7017 if (U->getOpcode() == Instruction::PtrToInt)
7018 return U->getOperand(i: 0);
7019 // If we find an add of a constant, a multiplied value, or a phi, it's
7020 // likely that the other operand will lead us to the base
7021 // object. We don't have to worry about the case where the
7022 // object address is somehow being computed by the multiply,
7023 // because our callers only care when the result is an
7024 // identifiable object.
7025 if (U->getOpcode() != Instruction::Add ||
7026 (!isa<ConstantInt>(Val: U->getOperand(i: 1)) &&
7027 Operator::getOpcode(V: U->getOperand(i: 1)) != Instruction::Mul &&
7028 !isa<PHINode>(Val: U->getOperand(i: 1))))
7029 return V;
7030 V = U->getOperand(i: 0);
7031 } else {
7032 return V;
7033 }
7034 assert(V->getType()->isIntegerTy() && "Unexpected operand type!");
7035 } while (true);
7036}
7037
/// This is a wrapper around getUnderlyingObjects and adds support for basic
/// ptrtoint+arithmetic+inttoptr sequences.
/// It returns false if unidentified object is found in getUnderlyingObjects.
bool llvm::getUnderlyingObjectsForCodeGen(const Value *V,
                                          SmallVectorImpl<Value *> &Objects) {
  SmallPtrSet<const Value *, 16> Visited;
  SmallVector<const Value *, 4> Working(1, V);
  do {
    V = Working.pop_back_val();

    SmallVector<const Value *, 4> Objs;
    getUnderlyingObjects(V, Objects&: Objs);

    for (const Value *V : Objs) {
      if (!Visited.insert(Ptr: V).second)
        continue;
      // Trace through inttoptr(ptrtoint+arith) back onto the pointer walk.
      if (Operator::getOpcode(V) == Instruction::IntToPtr) {
        const Value *O =
          getUnderlyingObjectFromInt(V: cast<User>(Val: V)->getOperand(i: 0));
        if (O->getType()->isPointerTy()) {
          Working.push_back(Elt: O);
          continue;
        }
      }
      // If getUnderlyingObjects fails to find an identifiable object,
      // getUnderlyingObjectsForCodeGen also fails for safety.
      if (!isIdentifiedObject(V)) {
        Objects.clear();
        return false;
      }
      Objects.push_back(Elt: const_cast<Value *>(V));
    }
  } while (!Working.empty());
  return true;
}
7073
// Walk V back through casts, phis, selects, GEPs (all-zero-index only when
// OffsetZero is set), and calls with a "returned" argument, looking for a
// unique AllocaInst. Returns nullptr if more than one alloca, or any
// unsupported value kind, is reachable.
AllocaInst *llvm::findAllocaForValue(Value *V, bool OffsetZero) {
  AllocaInst *Result = nullptr;
  SmallPtrSet<Value *, 4> Visited;
  SmallVector<Value *, 4> Worklist;

  // Enqueue a value only the first time it is seen.
  auto AddWork = [&](Value *V) {
    if (Visited.insert(Ptr: V).second)
      Worklist.push_back(Elt: V);
  };

  AddWork(V);
  do {
    V = Worklist.pop_back_val();
    assert(Visited.count(V));

    if (AllocaInst *AI = dyn_cast<AllocaInst>(Val: V)) {
      // Two distinct allocas reachable -> no unique answer.
      if (Result && Result != AI)
        return nullptr;
      Result = AI;
    } else if (CastInst *CI = dyn_cast<CastInst>(Val: V)) {
      AddWork(CI->getOperand(i_nocapture: 0));
    } else if (PHINode *PN = dyn_cast<PHINode>(Val: V)) {
      for (Value *IncValue : PN->incoming_values())
        AddWork(IncValue);
    } else if (auto *SI = dyn_cast<SelectInst>(Val: V)) {
      AddWork(SI->getTrueValue());
      AddWork(SI->getFalseValue());
    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: V)) {
      // When the caller requires a zero offset, only all-zero GEPs qualify.
      if (OffsetZero && !GEP->hasAllZeroIndices())
        return nullptr;
      AddWork(GEP->getPointerOperand());
    } else if (CallBase *CB = dyn_cast<CallBase>(Val: V)) {
      Value *Returned = CB->getReturnedArgOperand();
      if (Returned)
        AddWork(Returned);
      else
        return nullptr;
    } else {
      // Any other producer means the value may not come from an alloca.
      return nullptr;
    }
  } while (!Worklist.empty());

  return Result;
}
7118
7119static bool onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7120 const Value *V, bool AllowLifetime, bool AllowDroppable) {
7121 for (const User *U : V->users()) {
7122 const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U);
7123 if (!II)
7124 return false;
7125
7126 if (AllowLifetime && II->isLifetimeStartOrEnd())
7127 continue;
7128
7129 if (AllowDroppable && II->isDroppable())
7130 continue;
7131
7132 return false;
7133 }
7134 return true;
7135}
7136
7137bool llvm::onlyUsedByLifetimeMarkers(const Value *V) {
7138 return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7139 V, /* AllowLifetime */ true, /* AllowDroppable */ false);
7140}
7141bool llvm::onlyUsedByLifetimeMarkersOrDroppableInsts(const Value *V) {
7142 return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7143 V, /* AllowLifetime */ true, /* AllowDroppable */ true);
7144}
7145
7146bool llvm::isNotCrossLaneOperation(const Instruction *I) {
7147 if (auto *II = dyn_cast<IntrinsicInst>(Val: I))
7148 return isTriviallyVectorizable(ID: II->getIntrinsicID());
7149 auto *Shuffle = dyn_cast<ShuffleVectorInst>(Val: I);
7150 return (!Shuffle || Shuffle->isSelect()) &&
7151 !isa<CallBase, BitCastInst, ExtractElementInst>(Val: I);
7152}
7153
7154bool llvm::isSafeToSpeculativelyExecute(
7155 const Instruction *Inst, const Instruction *CtxI, AssumptionCache *AC,
7156 const DominatorTree *DT, const TargetLibraryInfo *TLI, bool UseVariableInfo,
7157 bool IgnoreUBImplyingAttrs) {
7158 return isSafeToSpeculativelyExecuteWithOpcode(Opcode: Inst->getOpcode(), Inst, CtxI,
7159 AC, DT, TLI, UseVariableInfo,
7160 IgnoreUBImplyingAttrs);
7161}
7162
// Decide whether executing Inst (considered as having the given Opcode, which
// may differ from its real one) before its normal position could trap or
// trigger UB.
bool llvm::isSafeToSpeculativelyExecuteWithOpcode(
    unsigned Opcode, const Instruction *Inst, const Instruction *CtxI,
    AssumptionCache *AC, const DominatorTree *DT, const TargetLibraryInfo *TLI,
    bool UseVariableInfo, bool IgnoreUBImplyingAttrs) {
#ifndef NDEBUG
  if (Inst->getOpcode() != Opcode) {
    // Check that the operands are actually compatible with the Opcode override.
    auto hasEqualReturnAndLeadingOperandTypes =
        [](const Instruction *Inst, unsigned NumLeadingOperands) {
          if (Inst->getNumOperands() < NumLeadingOperands)
            return false;
          const Type *ExpectedType = Inst->getType();
          for (unsigned ItOp = 0; ItOp < NumLeadingOperands; ++ItOp)
            if (Inst->getOperand(ItOp)->getType() != ExpectedType)
              return false;
          return true;
        };
    assert(!Instruction::isBinaryOp(Opcode) ||
           hasEqualReturnAndLeadingOperandTypes(Inst, 2));
    assert(!Instruction::isUnaryOp(Opcode) ||
           hasEqualReturnAndLeadingOperandTypes(Inst, 1));
  }
#endif

  switch (Opcode) {
  default:
    // Opcodes absent from this switch have no trapping or UB concerns and
    // are speculatable unconditionally.
    return true;
  case Instruction::UDiv:
  case Instruction::URem: {
    // x / y is undefined if y == 0.
    const APInt *V;
    if (match(V: Inst->getOperand(i: 1), P: m_APInt(Res&: V)))
      return *V != 0;
    return false;
  }
  case Instruction::SDiv:
  case Instruction::SRem: {
    // x / y is undefined if y == 0 or x == INT_MIN and y == -1
    const APInt *Numerator, *Denominator;
    if (!match(V: Inst->getOperand(i: 1), P: m_APInt(Res&: Denominator)))
      return false;
    // We cannot hoist this division if the denominator is 0.
    if (*Denominator == 0)
      return false;
    // It's safe to hoist if the denominator is not 0 or -1.
    if (!Denominator->isAllOnes())
      return true;
    // At this point we know that the denominator is -1. It is safe to hoist as
    // long we know that the numerator is not INT_MIN.
    if (match(V: Inst->getOperand(i: 0), P: m_APInt(Res&: Numerator)))
      return !Numerator->isMinSignedValue();
    // The numerator *might* be MinSignedValue.
    return false;
  }
  case Instruction::Load: {
    // Speculating a load requires proving the pointer is dereferenceable,
    // which relies on variable info.
    if (!UseVariableInfo)
      return false;

    const LoadInst *LI = dyn_cast<LoadInst>(Val: Inst);
    if (!LI)
      return false;
    if (mustSuppressSpeculation(LI: *LI))
      return false;
    const DataLayout &DL = LI->getDataLayout();
    return isDereferenceableAndAlignedPointer(V: LI->getPointerOperand(),
                                              Ty: LI->getType(), Alignment: LI->getAlign(), DL,
                                              CtxI, AC, DT, TLI);
  }
  case Instruction::Call: {
    auto *CI = dyn_cast<const CallInst>(Val: Inst);
    if (!CI)
      return false;
    const Function *Callee = CI->getCalledFunction();

    // The called function could have undefined behavior or side-effects, even
    // if marked readnone nounwind.
    if (!Callee || !Callee->isSpeculatable())
      return false;
    // Since the operands may be changed after hoisting, undefined behavior may
    // be triggered by some UB-implying attributes.
    return IgnoreUBImplyingAttrs || !CI->hasUBImplyingAttrs();
  }
  case Instruction::VAArg:
  case Instruction::Alloca:
  case Instruction::Invoke:
  case Instruction::CallBr:
  case Instruction::PHI:
  case Instruction::Store:
  case Instruction::Ret:
  case Instruction::UncondBr:
  case Instruction::CondBr:
  case Instruction::IndirectBr:
  case Instruction::Switch:
  case Instruction::Unreachable:
  case Instruction::Fence:
  case Instruction::AtomicRMW:
  case Instruction::AtomicCmpXchg:
  case Instruction::LandingPad:
  case Instruction::Resume:
  case Instruction::CatchSwitch:
  case Instruction::CatchPad:
  case Instruction::CatchRet:
  case Instruction::CleanupPad:
  case Instruction::CleanupRet:
    return false; // Misc instructions which have effects
  }
}
7270
7271bool llvm::mayHaveNonDefUseDependency(const Instruction &I) {
7272 if (I.mayReadOrWriteMemory())
7273 // Memory dependency possible
7274 return true;
7275 if (!isSafeToSpeculativelyExecute(Inst: &I))
7276 // Can't move above a maythrow call or infinite loop. Or if an
7277 // inalloca alloca, above a stacksave call.
7278 return true;
7279 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
7280 // 1) Can't reorder two inf-loop calls, even if readonly
7281 // 2) Also can't reorder an inf-loop call below a instruction which isn't
7282 // safe to speculative execute. (Inverse of above)
7283 return true;
7284 return false;
7285}
7286
7287/// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
7288static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
7289 switch (OR) {
7290 case ConstantRange::OverflowResult::MayOverflow:
7291 return OverflowResult::MayOverflow;
7292 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7293 return OverflowResult::AlwaysOverflowsLow;
7294 case ConstantRange::OverflowResult::AlwaysOverflowsHigh:
7295 return OverflowResult::AlwaysOverflowsHigh;
7296 case ConstantRange::OverflowResult::NeverOverflows:
7297 return OverflowResult::NeverOverflows;
7298 }
7299 llvm_unreachable("Unknown OverflowResult");
7300}
7301
7302/// Combine constant ranges from computeConstantRange() and computeKnownBits().
7303ConstantRange
7304llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
7305 bool ForSigned,
7306 const SimplifyQuery &SQ) {
7307 ConstantRange CR1 =
7308 ConstantRange::fromKnownBits(Known: V.getKnownBits(Q: SQ), IsSigned: ForSigned);
7309 ConstantRange CR2 = computeConstantRange(V, ForSigned, UseInstrInfo: SQ.IIQ.UseInstrInfo);
7310 ConstantRange::PreferredRangeType RangeType =
7311 ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
7312 return CR1.intersectWith(CR: CR2, Type: RangeType);
7313}
7314
7315OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
7316 const Value *RHS,
7317 const SimplifyQuery &SQ,
7318 bool IsNSW) {
7319 ConstantRange LHSRange =
7320 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
7321 ConstantRange RHSRange =
7322 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
7323
7324 // mul nsw of two non-negative numbers is also nuw.
7325 if (IsNSW && LHSRange.isAllNonNegative() && RHSRange.isAllNonNegative())
7326 return OverflowResult::NeverOverflows;
7327
7328 return mapOverflowResult(OR: LHSRange.unsignedMulMayOverflow(Other: RHSRange));
7329}
7330
OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
                                                 const Value *RHS,
                                                 const SimplifyQuery &SQ) {
  // Multiplying n * m significant bits yields a result of n + m significant
  // bits. If the total number of significant bits does not exceed the
  // result bit width (minus 1), there is no overflow.
  // This means if we have enough leading sign bits in the operands
  // we can guarantee that the result does not overflow.
  // Ref: "Hacker's Delight" by Henry Warren
  unsigned BitWidth = LHS->getType()->getScalarSizeInBits();

  // Note that underestimating the number of sign bits gives a more
  // conservative answer.
  unsigned SignBits =
      ::ComputeNumSignBits(V: LHS, Q: SQ) + ::ComputeNumSignBits(V: RHS, Q: SQ);

  // First handle the easy case: if we have enough sign bits there's
  // definitely no overflow.
  if (SignBits > BitWidth + 1)
    return OverflowResult::NeverOverflows;

  // There are two ambiguous cases where there can be no overflow:
  //   SignBits == BitWidth + 1    and
  //   SignBits == BitWidth
  // The second case is difficult to check, therefore we only handle the
  // first case.
  if (SignBits == BitWidth + 1) {
    // It overflows only when both arguments are negative and the true
    // product is exactly the minimum negative number.
    // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
    // For simplicity we just check if at least one side is not negative.
    KnownBits LHSKnown = computeKnownBits(V: LHS, Q: SQ);
    KnownBits RHSKnown = computeKnownBits(V: RHS, Q: SQ);
    if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
      return OverflowResult::NeverOverflows;
  }
  // Not enough sign bits to rule overflow out.
  return OverflowResult::MayOverflow;
}
7369
7370OverflowResult
7371llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
7372 const WithCache<const Value *> &RHS,
7373 const SimplifyQuery &SQ) {
7374 ConstantRange LHSRange =
7375 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
7376 ConstantRange RHSRange =
7377 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
7378 return mapOverflowResult(OR: LHSRange.unsignedAddMayOverflow(Other: RHSRange));
7379}
7380
// Determine whether LHS + RHS can overflow signed. Add, when non-null, is the
// actual add instruction; its nsw flag and context-derived known bits are
// used to refine the answer.
static OverflowResult
computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                            const WithCache<const Value *> &RHS,
                            const AddOperator *Add, const SimplifyQuery &SQ) {
  // An nsw add asserts no signed wrap by construction.
  if (Add && Add->hasNoSignedWrap()) {
    return OverflowResult::NeverOverflows;
  }

  // If LHS and RHS each have at least two sign bits, the addition will look
  // like
  //
  //   XX..... +
  //   YY.....
  //
  // If the carry into the most significant position is 0, X and Y can't both
  // be 1 and therefore the carry out of the addition is also 0.
  //
  // If the carry into the most significant position is 1, X and Y can't both
  // be 0 and therefore the carry out of the addition is also 1.
  //
  // Since the carry into the most significant position is always equal to
  // the carry out of the addition, there is no signed overflow.
  if (::ComputeNumSignBits(V: LHS, Q: SQ) > 1 && ::ComputeNumSignBits(V: RHS, Q: SQ) > 1)
    return OverflowResult::NeverOverflows;

  ConstantRange LHSRange =
      computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/true, SQ);
  ConstantRange RHSRange =
      computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/true, SQ);
  OverflowResult OR =
      mapOverflowResult(OR: LHSRange.signedAddMayOverflow(Other: RHSRange));
  if (OR != OverflowResult::MayOverflow)
    return OR;

  // The remaining code needs Add to be available. Early returns if not so.
  if (!Add)
    return OverflowResult::MayOverflow;

  // If the sign of Add is the same as at least one of the operands, this add
  // CANNOT overflow. If this can be determined from the known bits of the
  // operands the above signedAddMayOverflow() check will have already done so.
  // The only other way to improve on the known bits is from an assumption, so
  // call computeKnownBitsFromContext() directly.
  bool LHSOrRHSKnownNonNegative =
      (LHSRange.isAllNonNegative() || RHSRange.isAllNonNegative());
  bool LHSOrRHSKnownNegative =
      (LHSRange.isAllNegative() || RHSRange.isAllNegative());
  if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
    KnownBits AddKnown(LHSRange.getBitWidth());
    computeKnownBitsFromContext(V: Add, Known&: AddKnown, Q: SQ);
    if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
        (AddKnown.isNegative() && LHSOrRHSKnownNegative))
      return OverflowResult::NeverOverflows;
  }

  return OverflowResult::MayOverflow;
}
7438
OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                   const Value *RHS,
                                                   const SimplifyQuery &SQ) {
  // X - (X % ?)
  // The remainder of a value can't have greater magnitude than itself,
  // so the subtraction can't overflow.

  // X - (X -nuw ?)
  // In the minimal case, this would simplify to "?", so there's no subtract
  // at all. But if this analysis is used to peek through casts, for example,
  // then determining no-overflow may allow other transforms.

  // TODO: There are other patterns like this.
  //       See simplifyICmpWithBinOpOnLHS() for candidates.
  if (match(V: RHS, P: m_URem(L: m_Specific(V: LHS), R: m_Value())) ||
      match(V: RHS, P: m_NUWSub(L: m_Specific(V: LHS), R: m_Value())))
    if (isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
      return OverflowResult::NeverOverflows;

  // A dominating condition settles the comparison outright: LHS >= RHS means
  // no wrap; otherwise LHS < RHS is implied and the subtraction always wraps.
  if (auto C = isImpliedByDomCondition(Pred: CmpInst::ICMP_UGE, LHS, RHS, ContextI: SQ.CxtI,
                                       DL: SQ.DL)) {
    if (*C)
      return OverflowResult::NeverOverflows;
    return OverflowResult::AlwaysOverflowsLow;
  }

  ConstantRange LHSRange =
      computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
  ConstantRange RHSRange =
      computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
  return mapOverflowResult(OR: LHSRange.unsignedSubMayOverflow(Other: RHSRange));
}
7471
/// Determine whether the signed subtraction `LHS - RHS` can overflow the
/// signed range of the type.
OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                 const Value *RHS,
                                                 const SimplifyQuery &SQ) {
  // X - (X % ?)
  // The remainder of a value can't have greater magnitude than itself,
  // so the subtraction can't overflow.

  // X - (X -nsw ?)
  // In the minimal case, this would simplify to "?", so there's no subtract
  // at all. But if this analysis is used to peek through casts, for example,
  // then determining no-overflow may allow other transforms.
  if (match(V: RHS, P: m_SRem(L: m_Specific(V: LHS), R: m_Value())) ||
      match(V: RHS, P: m_NSWSub(L: m_Specific(V: LHS), R: m_Value())))
    // LHS is used twice in these patterns; undef could act as a different
    // value at each use, so require a well-defined LHS.
    if (isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
      return OverflowResult::NeverOverflows;

  // If LHS and RHS each have at least two sign bits, the subtraction
  // cannot overflow.
  if (::ComputeNumSignBits(V: LHS, Q: SQ) > 1 && ::ComputeNumSignBits(V: RHS, Q: SQ) > 1)
    return OverflowResult::NeverOverflows;

  // Fall back to comparing the signed ranges of both operands.
  ConstantRange LHSRange =
      computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/true, SQ);
  ConstantRange RHSRange =
      computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/true, SQ);
  return mapOverflowResult(OR: LHSRange.signedSubMayOverflow(Other: RHSRange));
}
7499
/// Return true if every use of the arithmetic result of \p WO is dominated by
/// the no-overflow edge of a conditional branch on WO's overflow bit — i.e.
/// the result is only observed when the operation provably did not wrap.
bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
                                     const DominatorTree &DT) {
  // Branches that test the overflow bit (extractvalue index 1).
  SmallVector<const CondBrInst *, 2> GuardingBranches;
  // Extracts of the arithmetic result (extractvalue index 0).
  SmallVector<const ExtractValueInst *, 2> Results;

  for (const User *U : WO->users()) {
    if (const auto *EVI = dyn_cast<ExtractValueInst>(Val: U)) {
      assert(EVI->getNumIndices() == 1 && "Obvious from CI's type");

      if (EVI->getIndices()[0] == 0)
        Results.push_back(Elt: EVI);
      else {
        assert(EVI->getIndices()[0] == 1 && "Obvious from CI's type");

        for (const auto *U : EVI->users())
          if (const auto *B = dyn_cast<CondBrInst>(Val: U))
            GuardingBranches.push_back(Elt: B);
      }
    } else {
      // We are using the aggregate directly in a way we don't want to analyze
      // here (storing it to a global, say).
      return false;
    }
  }

  auto AllUsesGuardedByBranch = [&](const CondBrInst *BI) {
    // Successor 1 is taken when the overflow bit is false.
    BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(i: 1));

    // Check if all users of the add are provably no-wrap.
    for (const auto *Result : Results) {
      // If the extractvalue itself is not executed on overflow, then we don't
      // need to check each use separately, since domination is transitive.
      if (DT.dominates(BBE: NoWrapEdge, BB: Result->getParent()))
        continue;

      for (const auto &RU : Result->uses())
        if (!DT.dominates(BBE: NoWrapEdge, U: RU))
          return false;
    }

    return true;
  };

  // One fully-guarding branch is enough.
  return llvm::any_of(Range&: GuardingBranches, P: AllUsesGuardedByBranch);
}
7545
/// Shifts return poison if shiftwidth is larger than the bitwidth.
/// Return true if \p ShiftAmount is a constant whose every lane is known to be
/// strictly less than the element bit width (so the shift cannot be poison for
/// that reason). Non-constants and scalable vectors conservatively yield
/// false.
static bool shiftAmountKnownInRange(const Value *ShiftAmount) {
  auto *C = dyn_cast<Constant>(Val: ShiftAmount);
  if (!C)
    return false;

  // Collect the per-lane shift amounts so fixed vectors can be checked lane
  // by lane; a scalar is treated as a single "lane".
  SmallVector<const Constant *, 4> ShiftAmounts;
  if (auto *FVTy = dyn_cast<FixedVectorType>(Val: C->getType())) {
    unsigned NumElts = FVTy->getNumElements();
    for (unsigned i = 0; i < NumElts; ++i)
      ShiftAmounts.push_back(Elt: C->getAggregateElement(Elt: i));
  } else if (isa<ScalableVectorType>(Val: C->getType()))
    return false; // Can't tell, just return false to be safe
  else
    ShiftAmounts.push_back(Elt: C);

  // getAggregateElement may return null (hence dyn_cast_or_null); any
  // non-ConstantInt lane (e.g. undef) counts as unsafe.
  bool Safe = llvm::all_of(Range&: ShiftAmounts, P: [](const Constant *C) {
    auto *CI = dyn_cast_or_null<ConstantInt>(Val: C);
    return CI && CI->getValue().ult(RHS: C->getType()->getIntegerBitWidth());
  });

  return Safe;
}
7570
/// Selects which flavor(s) of ill-defined value a query cares about. The
/// enumerators form a bitmask, so UndefOrPoison covers both flavors.
enum class UndefPoisonKind {
  PoisonOnly = (1 << 0),
  UndefOnly = (1 << 1),
  UndefOrPoison = PoisonOnly | UndefOnly,
};

/// True if \p Kind asks about poison values.
static bool includesPoison(UndefPoisonKind Kind) {
  unsigned Bits = static_cast<unsigned>(Kind);
  return (Bits & static_cast<unsigned>(UndefPoisonKind::PoisonOnly)) != 0;
}

/// True if \p Kind asks about undef values.
static bool includesUndef(UndefPoisonKind Kind) {
  unsigned Bits = static_cast<unsigned>(Kind);
  return (Bits & static_cast<unsigned>(UndefPoisonKind::UndefOnly)) != 0;
}
7584
/// Return true if \p Op may produce an undef/poison value (per \p Kind) even
/// when all of its operands are well-defined. If \p ConsiderFlagsAndMetadata
/// is set, poison-generating flags/metadata (e.g. nsw, nuw, !nonnull) on the
/// operation also count.
static bool canCreateUndefOrPoison(const Operator *Op, UndefPoisonKind Kind,
                                   bool ConsiderFlagsAndMetadata) {

  if (ConsiderFlagsAndMetadata && includesPoison(Kind) &&
      Op->hasPoisonGeneratingAnnotations())
    return true;

  unsigned Opcode = Op->getOpcode();

  // Check whether opcode is a poison/undef-generating operation
  switch (Opcode) {
  case Instruction::Shl:
  case Instruction::AShr:
  case Instruction::LShr:
    // Shifts are poison only when the shift amount can reach the bit width.
    return includesPoison(Kind) && !shiftAmountKnownInRange(ShiftAmount: Op->getOperand(i: 1));
  case Instruction::FPToSI:
  case Instruction::FPToUI:
    // fptosi/ui yields poison if the resulting value does not fit in the
    // destination type.
    return true;
  case Instruction::Call:
    if (auto *II = dyn_cast<IntrinsicInst>(Val: Op)) {
      switch (II->getIntrinsicID()) {
      // NOTE: Use IntrNoCreateUndefOrPoison when possible.
      case Intrinsic::ctlz:
      case Intrinsic::cttz:
      case Intrinsic::abs:
        // We're not considering flags so it is safe to just return false.
        return false;
      case Intrinsic::sshl_sat:
      case Intrinsic::ushl_sat:
        // Like shifts: only poison when the shift amount may be out of range.
        // Otherwise fall through to the generic call handling below.
        if (!includesPoison(Kind) ||
            shiftAmountKnownInRange(ShiftAmount: II->getArgOperand(i: 1)))
          return false;
        break;
      }
    }
    [[fallthrough]];
  case Instruction::CallBr:
  case Instruction::Invoke: {
    // An arbitrary call may return anything unless attributes say otherwise.
    const auto *CB = cast<CallBase>(Val: Op);
    return !CB->hasRetAttr(Kind: Attribute::NoUndef) &&
           !CB->hasFnAttr(Kind: Attribute::NoCreateUndefOrPoison);
  }
  case Instruction::InsertElement:
  case Instruction::ExtractElement: {
    // If index exceeds the length of the vector, it returns poison
    auto *VTy = cast<VectorType>(Val: Op->getOperand(i: 0)->getType());
    unsigned IdxOp = Op->getOpcode() == Instruction::InsertElement ? 2 : 1;
    auto *Idx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: IdxOp));
    if (includesPoison(Kind))
      // A non-constant index, or one that may reach the (known-minimum)
      // element count, can be out of range.
      return !Idx ||
             Idx->getValue().uge(RHS: VTy->getElementCount().getKnownMinValue());
    return false;
  }
  case Instruction::ShuffleVector: {
    // A poison mask element yields a poison lane.
    ArrayRef<int> Mask = isa<ConstantExpr>(Val: Op)
                             ? cast<ConstantExpr>(Val: Op)->getShuffleMask()
                             : cast<ShuffleVectorInst>(Val: Op)->getShuffleMask();
    return includesPoison(Kind) && is_contained(Range&: Mask, Element: PoisonMaskElem);
  }
  case Instruction::FNeg:
  case Instruction::PHI:
  case Instruction::Select:
  case Instruction::ExtractValue:
  case Instruction::InsertValue:
  case Instruction::Freeze:
  case Instruction::ICmp:
  case Instruction::FCmp:
  case Instruction::GetElementPtr:
    return false;
  case Instruction::AddrSpaceCast:
    return true;
  default: {
    // Casts and binary operators (flags were handled up front) are safe;
    // anything else is unknown, so assume the worst.
    const auto *CE = dyn_cast<ConstantExpr>(Val: Op);
    if (isa<CastInst>(Val: Op) || (CE && CE->isCast()))
      return false;
    else if (Instruction::isBinaryOp(Opcode))
      return false;
    // Be conservative and return true.
    return true;
  }
  }
}
7669
/// Public entry point: may \p Op produce either undef or poison from
/// well-defined operands?
bool llvm::canCreateUndefOrPoison(const Operator *Op,
                                  bool ConsiderFlagsAndMetadata) {
  return ::canCreateUndefOrPoison(Op, Kind: UndefPoisonKind::UndefOrPoison,
                                  ConsiderFlagsAndMetadata);
}
7675
/// Public entry point: may \p Op produce poison (undef is not considered)?
bool llvm::canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata) {
  return ::canCreateUndefOrPoison(Op, Kind: UndefPoisonKind::PoisonOnly,
                                  ConsiderFlagsAndMetadata);
}
7680
/// Return true if "\p ValAssumedPoison is poison" directly implies that \p V
/// is poison: either they are the same value, or poison provably propagates
/// from an operand of V to V itself (bounded to a shallow walk).
static bool directlyImpliesPoison(const Value *ValAssumedPoison, const Value *V,
                                  unsigned Depth) {
  if (ValAssumedPoison == V)
    return true;

  const unsigned MaxDepth = 2;
  if (Depth >= MaxDepth)
    return false;

  if (const auto *I = dyn_cast<Instruction>(Val: V)) {
    // If some operand of I both propagates poison and is itself implied
    // poison, then I is poison.
    if (any_of(Range: I->operands(), P: [=](const Use &Op) {
          return propagatesPoison(PoisonOp: Op) &&
                 directlyImpliesPoison(ValAssumedPoison, V: Op, Depth: Depth + 1);
        }))
      return true;

    // V = extractvalue V0, idx
    // V2 = extractvalue V0, idx2
    // V0's elements are all poison or not. (e.g., add_with_overflow)
    const WithOverflowInst *II;
    if (match(V: I, P: m_ExtractValue(V: m_WithOverflowInst(I&: II))) &&
        (match(V: ValAssumedPoison, P: m_ExtractValue(V: m_Specific(V: II))) ||
         llvm::is_contained(Range: II->args(), Element: ValAssumedPoison)))
      return true;
  }
  return false;
}
7708
/// Return true if \p V is poison whenever \p ValAssumedPoison is poison.
static bool impliesPoison(const Value *ValAssumedPoison, const Value *V,
                          unsigned Depth) {
  // If ValAssumedPoison can never be poison, the implication is vacuous.
  if (isGuaranteedNotToBePoison(V: ValAssumedPoison))
    return true;

  if (directlyImpliesPoison(ValAssumedPoison, V, /* Depth */ 0))
    return true;

  const unsigned MaxDepth = 2;
  if (Depth >= MaxDepth)
    return false;

  // If ValAssumedPoison cannot itself create poison, it is poison only when
  // one of its operands is. We don't know which one, so every operand must
  // imply V is poison.
  const auto *I = dyn_cast<Instruction>(Val: ValAssumedPoison);
  if (I && !canCreatePoison(Op: cast<Operator>(Val: I))) {
    return all_of(Range: I->operands(), P: [=](const Value *Op) {
      return impliesPoison(ValAssumedPoison: Op, V, Depth: Depth + 1);
    });
  }
  return false;
}
7729
/// Public entry point: is \p V poison whenever \p ValAssumedPoison is poison?
bool llvm::impliesPoison(const Value *ValAssumedPoison, const Value *V) {
  return ::impliesPoison(ValAssumedPoison, V, /* Depth */ 0);
}
7733
// Forward declaration; the definition appears later in this file.
static bool programUndefinedIfUndefOrPoison(const Value *V, bool PoisonOnly);
7735
/// Return true if \p V is guaranteed not to be undef/poison (per \p Kind).
/// Tries, in order: argument/constant properties, pointer stripping, operand
/// recursion through non-poison-generating operations, metadata, "program is
/// UB if V is poison" reasoning, dominating branch conditions, and finally
/// assumptions. A false result conveys no information.
static bool isGuaranteedNotToBeUndefOrPoison(
    const Value *V, AssumptionCache *AC, const Instruction *CtxI,
    const DominatorTree *DT, unsigned Depth, UndefPoisonKind Kind) {
  if (Depth >= MaxAnalysisRecursionDepth)
    return false;

  if (isa<MetadataAsValue>(Val: V))
    return false;

  // noundef / dereferenceable(_or_null) arguments are well-defined.
  if (const auto *A = dyn_cast<Argument>(Val: V)) {
    if (A->hasAttribute(Kind: Attribute::NoUndef) ||
        A->hasAttribute(Kind: Attribute::Dereferenceable) ||
        A->hasAttribute(Kind: Attribute::DereferenceableOrNull))
      return true;
  }

  if (auto *C = dyn_cast<Constant>(Val: V)) {
    // A literal poison/undef is acceptable only if the query isn't asking
    // about that flavor.
    if (isa<PoisonValue>(Val: C))
      return !includesPoison(Kind);

    if (isa<UndefValue>(Val: C))
      return !includesUndef(Kind);

    if (isa<ConstantInt>(Val: C) || isa<GlobalVariable>(Val: C) || isa<ConstantFP>(Val: C) ||
        isa<ConstantPointerNull>(Val: C) || isa<Function>(Val: C))
      return true;

    if (C->getType()->isVectorTy()) {
      if (isa<ConstantExpr>(Val: C)) {
        // Scalable vectors can use a ConstantExpr to build a splat.
        if (Constant *SplatC = C->getSplatValue())
          if (isa<ConstantInt>(Val: SplatC) || isa<ConstantFP>(Val: SplatC))
            return true;
      } else {
        // Check every lane of a fixed vector constant.
        if (includesUndef(Kind) && C->containsUndefElement())
          return false;
        if (includesPoison(Kind) && C->containsPoisonElement())
          return false;
        return !C->containsConstantExpression();
      }
    }
  }

  // Strip cast operations from a pointer value.
  // Note that stripPointerCastsSameRepresentation can strip off getelementptr
  // inbounds with zero offset. To guarantee that the result isn't poison, the
  // stripped pointer is checked as it has to be pointing into an allocated
  // object or be null to ensure `inbounds` getelement pointers with a
  // zero offset could not produce poison.
  // It can strip off addrspacecast that do not change bit representation as
  // well. We believe that such addrspacecast is equivalent to no-op.
  auto *StrippedV = V->stripPointerCastsSameRepresentation();
  if (isa<AllocaInst>(Val: StrippedV) || isa<GlobalVariable>(Val: StrippedV) ||
      isa<Function>(Val: StrippedV) || isa<ConstantPointerNull>(Val: StrippedV))
    return true;

  auto OpCheck = [&](const Value *V) {
    return isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth: Depth + 1, Kind);
  };

  if (auto *Opr = dyn_cast<Operator>(Val: V)) {
    // If the value is a freeze instruction, then it can never
    // be undef or poison.
    if (isa<FreezeInst>(Val: V))
      return true;

    if (const auto *CB = dyn_cast<CallBase>(Val: V)) {
      if (CB->hasRetAttr(Kind: Attribute::NoUndef) ||
          CB->hasRetAttr(Kind: Attribute::Dereferenceable) ||
          CB->hasRetAttr(Kind: Attribute::DereferenceableOrNull))
        return true;
    }

    // If the operation cannot introduce undef/poison itself, it suffices that
    // all of its (relevant) operands are well-defined.
    if (!::canCreateUndefOrPoison(Op: Opr, Kind,
                                  /*ConsiderFlagsAndMetadata=*/true)) {
      if (const auto *PN = dyn_cast<PHINode>(Val: V)) {
        unsigned Num = PN->getNumIncomingValues();
        bool IsWellDefined = true;
        for (unsigned i = 0; i < Num; ++i) {
          // A self-reference contributes nothing new.
          if (PN == PN->getIncomingValue(i))
            continue;
          // Check each incoming value at its own edge's terminator, not at
          // CtxI, since that is where it must be well-defined.
          auto *TI = PN->getIncomingBlock(i)->getTerminator();
          if (!isGuaranteedNotToBeUndefOrPoison(V: PN->getIncomingValue(i), AC, CtxI: TI,
                                                DT, Depth: Depth + 1, Kind)) {
            IsWellDefined = false;
            break;
          }
        }
        if (IsWellDefined)
          return true;
      } else if (auto *Splat = isa<ShuffleVectorInst>(Val: Opr) ? getSplatValue(V: Opr)
                                                             : nullptr) {
        // For splats we only need to check the value being splatted.
        if (OpCheck(Splat))
          return true;
      } else if (all_of(Range: Opr->operands(), P: OpCheck))
        return true;
    }
  }

  // noundef / dereferenceable load metadata guarantees a well-defined result.
  if (auto *I = dyn_cast<LoadInst>(Val: V))
    if (I->hasMetadata(KindID: LLVMContext::MD_noundef) ||
        I->hasMetadata(KindID: LLVMContext::MD_dereferenceable) ||
        I->hasMetadata(KindID: LLVMContext::MD_dereferenceable_or_null))
      return true;

  // If V being ill-defined would make the program UB, it cannot be
  // ill-defined on any executed path.
  if (programUndefinedIfUndefOrPoison(V, PoisonOnly: !includesUndef(Kind)))
    return true;

  // CxtI may be null or a cloned instruction.
  if (!CtxI || !CtxI->getParent() || !DT)
    return false;

  auto *DNode = DT->getNode(BB: CtxI->getParent());
  if (!DNode)
    // Unreachable block
    return false;

  // If V is used as a branch condition before reaching CtxI, V cannot be
  // undef or poison.
  //   br V, BB1, BB2
  // BB1:
  //   CtxI ; V cannot be undef or poison here
  auto *Dominator = DNode->getIDom();
  // This check is purely for compile time reasons: we can skip the IDom walk
  // if what we are checking for includes undef and the value is not an integer.
  if (!includesUndef(Kind) || V->getType()->isIntegerTy())
    while (Dominator) {
      auto *TI = Dominator->getBlock()->getTerminatorOrNull();

      Value *Cond = nullptr;
      if (auto BI = dyn_cast_or_null<CondBrInst>(Val: TI)) {
        Cond = BI->getCondition();
      } else if (auto SI = dyn_cast_or_null<SwitchInst>(Val: TI)) {
        Cond = SI->getCondition();
      }

      if (Cond) {
        if (Cond == V)
          return true;
        else if (!includesUndef(Kind) && isa<Operator>(Val: Cond)) {
          // For poison, we can analyze further
          auto *Opr = cast<Operator>(Val: Cond);
          if (any_of(Range: Opr->operands(), P: [V](const Use &U) {
                return V == U && propagatesPoison(PoisonOp: U);
              }))
            return true;
        }
      }

      Dominator = Dominator->getIDom();
    }

  // Finally, consult llvm.assume-style knowledge.
  if (AC && getKnowledgeValidInContext(V, AttrKinds: {Attribute::NoUndef}, AC&: *AC, CtxI, DT))
    return true;

  return false;
}
7894
/// Public entry point: guaranteed to be neither undef nor poison.
bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V, AssumptionCache *AC,
                                            const Instruction *CtxI,
                                            const DominatorTree *DT,
                                            unsigned Depth) {
  return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
                                            Kind: UndefPoisonKind::UndefOrPoison);
}
7902
/// Public entry point: guaranteed not to be poison (undef is permitted).
bool llvm::isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC,
                                     const Instruction *CtxI,
                                     const DominatorTree *DT, unsigned Depth) {
  return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
                                            Kind: UndefPoisonKind::PoisonOnly);
}
7909
/// Public entry point: guaranteed not to be undef (poison is permitted).
bool llvm::isGuaranteedNotToBeUndef(const Value *V, AssumptionCache *AC,
                                    const Instruction *CtxI,
                                    const DominatorTree *DT, unsigned Depth) {
  return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
                                            Kind: UndefPoisonKind::UndefOnly);
}
7916
/// Return true if undefined behavior would provably be executed on the path to
/// OnPathTo if Root produced a poison result. Note that this doesn't say
/// anything about whether OnPathTo is actually executed or whether Root is
/// actually poison. This can be used to assess whether a new use of Root can
/// be added at a location which is control equivalent with OnPathTo (such as
/// immediately before it) without introducing UB which didn't previously
/// exist. Note that a false result conveys no information.
bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
                                         Instruction *OnPathTo,
                                         DominatorTree *DT) {
  // Basic approach is to assume Root is poison, propagate poison forward
  // through all users we can easily track, and then check whether any of those
  // users are provable UB and must execute before our exiting block might
  // exit.

  // The set of all recursive users we've visited (which are assumed to all be
  // poison because of said visit)
  SmallPtrSet<const Value *, 16> KnownPoison;
  SmallVector<const Instruction*, 16> Worklist;
  Worklist.push_back(Elt: Root);
  while (!Worklist.empty()) {
    const Instruction *I = Worklist.pop_back_val();

    // If we know this must trigger UB on a path leading our target.
    if (mustTriggerUB(I, KnownPoison) && DT->dominates(Def: I, User: OnPathTo))
      return true;

    // If we can't analyze propagation through this instruction, just skip it
    // and transitive users. Safe as false is a conservative result.
    if (I != Root && !any_of(Range: I->operands(), P: [&KnownPoison](const Use &U) {
          return KnownPoison.contains(Ptr: U) && propagatesPoison(PoisonOp: U);
        }))
      continue;

    // First visit of I: record it as poison and queue its users.
    if (KnownPoison.insert(Ptr: I).second)
      for (const User *User : I->users())
        Worklist.push_back(Elt: cast<Instruction>(Val: User));
  }

  // Might be non-UB, or might have a path we couldn't prove must execute on
  // way to exiting bb.
  return false;
}
7960
/// Overload taking the add instruction itself, which lets the analysis use
/// context (e.g. assumptions) about the add's result.
OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
                                                 const SimplifyQuery &SQ) {
  return ::computeOverflowForSignedAdd(LHS: Add->getOperand(i_nocapture: 0), RHS: Add->getOperand(i_nocapture: 1),
                                       Add, SQ);
}
7966
/// Overload taking only the two operands; no add instruction is available, so
/// nullptr is passed for it.
OverflowResult
llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                                  const WithCache<const Value *> &RHS,
                                  const SimplifyQuery &SQ) {
  return ::computeOverflowForSignedAdd(LHS, RHS, Add: nullptr, SQ);
}
7973
/// Return true if executing \p I is guaranteed to transfer control to the
/// instruction after it (i.e. I cannot exit, hang, or throw).
bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
  // Note: An atomic operation isn't guaranteed to return in a reasonable amount
  // of time because it's possible for another thread to interfere with it for an
  // arbitrary length of time, but programs aren't allowed to rely on that.

  // If there is no successor, then execution can't transfer to it.
  if (isa<ReturnInst>(Val: I))
    return false;
  if (isa<UnreachableInst>(Val: I))
    return false;

  // Note: Do not add new checks here; instead, change Instruction::mayThrow or
  // Instruction::willReturn.
  //
  // FIXME: Move this check into Instruction::willReturn.
  if (isa<CatchPadInst>(Val: I)) {
    switch (classifyEHPersonality(Pers: I->getFunction()->getPersonalityFn())) {
    default:
      // A catchpad may invoke exception object constructors and such, which
      // in some languages can be arbitrary code, so be conservative by default.
      return false;
    case EHPersonality::CoreCLR:
      // For CoreCLR, it just involves a type test.
      return true;
    }
  }

  // An instruction that returns without throwing must transfer control flow
  // to a successor.
  return !I->mayThrow() && I->willReturn();
}
8005
8006bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) {
8007 // TODO: This is slightly conservative for invoke instruction since exiting
8008 // via an exception *is* normal control for them.
8009 for (const Instruction &I : *BB)
8010 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
8011 return false;
8012 return true;
8013}
8014
/// Convenience overload: forwards an iterator pair as a range.
bool llvm::isGuaranteedToTransferExecutionToSuccessor(
    BasicBlock::const_iterator Begin, BasicBlock::const_iterator End,
    unsigned ScanLimit) {
  return isGuaranteedToTransferExecutionToSuccessor(Range: make_range(x: Begin, y: End),
                                                    ScanLimit);
}
8021
8022bool llvm::isGuaranteedToTransferExecutionToSuccessor(
8023 iterator_range<BasicBlock::const_iterator> Range, unsigned ScanLimit) {
8024 assert(ScanLimit && "scan limit must be non-zero");
8025 for (const Instruction &I : Range) {
8026 if (--ScanLimit == 0)
8027 return false;
8028 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
8029 return false;
8030 }
8031 return true;
8032}
8033
/// Return true if \p I is executed on every iteration of loop \p L: it must
/// sit in the loop header and be reachable from the header's entry without
/// passing an instruction that might not transfer execution.
bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I,
                                                  const Loop *L) {
  // The loop header is guaranteed to be executed for every iteration.
  //
  // FIXME: Relax this constraint to cover all basic blocks that are
  // guaranteed to be executed at every iteration.
  if (I->getParent() != L->getHeader()) return false;

  // Walk the header top-down until we reach I.
  for (const Instruction &LI : *L->getHeader()) {
    if (&LI == I) return true;
    if (!isGuaranteedToTransferExecutionToSuccessor(I: &LI)) return false;
  }
  llvm_unreachable("Instruction not contained in its own parent basic block.");
}
8048
/// Return true if a poison operand of an intrinsic with ID \p IID makes the
/// result poison (i.e. the intrinsic propagates poison rather than masking
/// it). Unknown intrinsics conservatively return false.
bool llvm::intrinsicPropagatesPoison(Intrinsic::ID IID) {
  switch (IID) {
  // TODO: Add more intrinsics.
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::umul_with_overflow:
    // If an input is a vector containing a poison element, the
    // two output vectors (calculated results, overflow bits)'
    // corresponding lanes are poison.
    return true;
  // Integer bit manipulation, min/max, comparisons, and saturating/fixed-point
  // arithmetic.
  case Intrinsic::ctpop:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::abs:
  case Intrinsic::smax:
  case Intrinsic::smin:
  case Intrinsic::umax:
  case Intrinsic::umin:
  case Intrinsic::scmp:
  case Intrinsic::is_fpclass:
  case Intrinsic::ptrmask:
  case Intrinsic::ucmp:
  case Intrinsic::bitreverse:
  case Intrinsic::bswap:
  case Intrinsic::sadd_sat:
  case Intrinsic::ssub_sat:
  case Intrinsic::sshl_sat:
  case Intrinsic::uadd_sat:
  case Intrinsic::usub_sat:
  case Intrinsic::ushl_sat:
  case Intrinsic::smul_fix:
  case Intrinsic::smul_fix_sat:
  case Intrinsic::umul_fix:
  case Intrinsic::umul_fix_sat:
  // Floating-point math and rounding intrinsics.
  case Intrinsic::pow:
  case Intrinsic::powi:
  case Intrinsic::sin:
  case Intrinsic::sinh:
  case Intrinsic::cos:
  case Intrinsic::cosh:
  case Intrinsic::sincos:
  case Intrinsic::sincospi:
  case Intrinsic::tan:
  case Intrinsic::tanh:
  case Intrinsic::asin:
  case Intrinsic::acos:
  case Intrinsic::atan:
  case Intrinsic::atan2:
  case Intrinsic::canonicalize:
  case Intrinsic::sqrt:
  case Intrinsic::exp:
  case Intrinsic::exp2:
  case Intrinsic::exp10:
  case Intrinsic::log:
  case Intrinsic::log2:
  case Intrinsic::log10:
  case Intrinsic::modf:
  case Intrinsic::floor:
  case Intrinsic::ceil:
  case Intrinsic::trunc:
  case Intrinsic::rint:
  case Intrinsic::nearbyint:
  case Intrinsic::round:
  case Intrinsic::roundeven:
  case Intrinsic::lrint:
  case Intrinsic::llrint:
  // Funnel shifts.
  case Intrinsic::fshl:
  case Intrinsic::fshr:
    return true;
  default:
    return false;
  }
}
8125
/// Return true if the user of \p PoisonOp produces poison whenever that
/// specific operand is poison. Note this is per-operand: e.g. a select only
/// propagates poison from its condition, not its arms.
bool llvm::propagatesPoison(const Use &PoisonOp) {
  const Operator *I = cast<Operator>(Val: PoisonOp.getUser());
  switch (I->getOpcode()) {
  // freeze and phi can launder poison; invoke's result is not a pure function
  // of its operands.
  case Instruction::Freeze:
  case Instruction::PHI:
  case Instruction::Invoke:
    return false;
  case Instruction::Select:
    // Only a poison condition forces a poison result.
    return PoisonOp.getOperandNo() == 0;
  case Instruction::Call:
    if (auto *II = dyn_cast<IntrinsicInst>(Val: I))
      return intrinsicPropagatesPoison(IID: II->getIntrinsicID());
    return false;
  case Instruction::ICmp:
  case Instruction::FCmp:
  case Instruction::GetElementPtr:
    return true;
  default:
    if (isa<BinaryOperator>(Val: I) || isa<UnaryOperator>(Val: I) || isa<CastInst>(Val: I))
      return true;

    // Be conservative and return false.
    return false;
  }
}
8151
/// Enumerates all operands of \p I that are guaranteed to not be undef or
/// poison (otherwise executing \p I would be immediate UB or violate a
/// noundef-style contract). If the callback \p Handle returns true, stop
/// processing and return true. Otherwise, return false.
template <typename CallableT>
static bool handleGuaranteedWellDefinedOps(const Instruction *I,
                                           const CallableT &Handle) {
  switch (I->getOpcode()) {
  // Memory accesses require a well-defined pointer.
  case Instruction::Store:
    if (Handle(cast<StoreInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  case Instruction::Load:
    if (Handle(cast<LoadInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  // Since dereferenceable attribute imply noundef, atomic operations
  // also implicitly have noundef pointers too
  case Instruction::AtomicCmpXchg:
    if (Handle(cast<AtomicCmpXchgInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  case Instruction::AtomicRMW:
    if (Handle(cast<AtomicRMWInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  case Instruction::Call:
  case Instruction::Invoke: {
    // The callee of an indirect call, and any argument whose attributes
    // require a well-defined value.
    const CallBase *CB = cast<CallBase>(Val: I);
    if (CB->isIndirectCall() && Handle(CB->getCalledOperand()))
      return true;
    for (unsigned i = 0; i < CB->arg_size(); ++i)
      if ((CB->paramHasAttr(ArgNo: i, Kind: Attribute::NoUndef) ||
           CB->paramHasAttr(ArgNo: i, Kind: Attribute::Dereferenceable) ||
           CB->paramHasAttr(ArgNo: i, Kind: Attribute::DereferenceableOrNull)) &&
          Handle(CB->getArgOperand(i)))
        return true;
    break;
  }
  case Instruction::Ret:
    // The returned value, when the function promises a noundef return.
    if (I->getFunction()->hasRetAttribute(Kind: Attribute::NoUndef) &&
        Handle(I->getOperand(i: 0)))
      return true;
    break;
  // Branching on undef/poison is immediate UB.
  case Instruction::Switch:
    if (Handle(cast<SwitchInst>(Val: I)->getCondition()))
      return true;
    break;
  case Instruction::CondBr:
    if (Handle(cast<CondBrInst>(Val: I)->getCondition()))
      return true;
    break;
  default:
    break;
  }

  return false;
}
8213
/// Enumerates all operands of \p I that are guaranteed to not be poison.
/// This is a superset of handleGuaranteedWellDefinedOps: it also covers
/// operands that must not be poison but may still be (partially) undef.
template <typename CallableT>
static bool handleGuaranteedNonPoisonOps(const Instruction *I,
                                         const CallableT &Handle) {
  if (handleGuaranteedWellDefinedOps(I, Handle))
    return true;
  switch (I->getOpcode()) {
  // Divisors of these operations are allowed to be partially undef.
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
    return Handle(I->getOperand(i: 1));
  default:
    return false;
  }
}
8231
/// Return true if executing \p I is guaranteed to be UB assuming every value
/// in \p KnownPoison is poison, i.e. I uses one of those values in an operand
/// position that must not be poison.
bool llvm::mustTriggerUB(const Instruction *I,
                         const SmallPtrSetImpl<const Value *> &KnownPoison) {
  return handleGuaranteedNonPoisonOps(
      I, Handle: [&](const Value *V) { return KnownPoison.count(Ptr: V); });
}
8237
/// Return true if the program is guaranteed to be UB whenever \p V is
/// undef-or-poison (or, with \p PoisonOnly, just poison): some instruction on
/// the straight-line path after V's definition necessarily consumes it in a
/// position that forbids such a value.
static bool programUndefinedIfUndefOrPoison(const Value *V,
                                            bool PoisonOnly) {
  // We currently only look for uses of values within the same basic
  // block, as that makes it easier to guarantee that the uses will be
  // executed given that Inst is executed.
  //
  // FIXME: Expand this to consider uses beyond the same basic block. To do
  // this, look out for the distinction between post-dominance and strong
  // post-dominance.
  const BasicBlock *BB = nullptr;
  BasicBlock::const_iterator Begin;
  if (const auto *Inst = dyn_cast<Instruction>(Val: V)) {
    // Scan starts right after the defining instruction.
    BB = Inst->getParent();
    Begin = Inst->getIterator();
    Begin++;
  } else if (const auto *Arg = dyn_cast<Argument>(Val: V)) {
    // Arguments are "defined" at function entry.
    if (Arg->getParent()->isDeclaration())
      return false;
    BB = &Arg->getParent()->getEntryBlock();
    Begin = BB->begin();
  } else {
    return false;
  }

  // Limit number of instructions we look at, to avoid scanning through large
  // blocks. The current limit is chosen arbitrarily.
  unsigned ScanLimit = 32;
  BasicBlock::const_iterator End = BB->end();

  if (!PoisonOnly) {
    // Since undef does not propagate eagerly, be conservative & just check
    // whether a value is directly passed to an instruction that must take
    // well-defined operands.

    for (const auto &I : make_range(x: Begin, y: End)) {
      if (--ScanLimit == 0)
        break;

      if (handleGuaranteedWellDefinedOps(I: &I, Handle: [V](const Value *WellDefinedOp) {
            return WellDefinedOp == V;
          }))
        return true;

      // Stop at the first instruction that might not execute its successor:
      // later instructions are no longer guaranteed to run.
      if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
        break;
    }
    return false;
  }

  // Poison-only mode: poison propagates through many operations, so track the
  // growing set of values that are poison if V is.
  // Set of instructions that we have proved will yield poison if Inst
  // does.
  SmallPtrSet<const Value *, 16> YieldsPoison;
  SmallPtrSet<const BasicBlock *, 4> Visited;

  YieldsPoison.insert(Ptr: V);
  Visited.insert(Ptr: BB);

  while (true) {
    for (const auto &I : make_range(x: Begin, y: End)) {
      if (--ScanLimit == 0)
        return false;
      if (mustTriggerUB(I: &I, KnownPoison: YieldsPoison))
        return true;
      if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
        return false;

      // If an operand is poison and propagates it, mark I as yielding poison.
      for (const Use &Op : I.operands()) {
        if (YieldsPoison.count(Ptr: Op) && propagatesPoison(PoisonOp: Op)) {
          YieldsPoison.insert(Ptr: &I);
          break;
        }
      }

      // Special handling for select, which returns poison if its operand 0 is
      // poison (handled in the loop above) *or* if both its true/false operands
      // are poison (handled here).
      if (I.getOpcode() == Instruction::Select &&
          YieldsPoison.count(Ptr: I.getOperand(i: 1)) &&
          YieldsPoison.count(Ptr: I.getOperand(i: 2))) {
        YieldsPoison.insert(Ptr: &I);
      }
    }

    // Continue into an unconditional successor chain, but never revisit a
    // block (loops would rescan forever).
    BB = BB->getSingleSuccessor();
    if (!BB || !Visited.insert(Ptr: BB).second)
      break;

    Begin = BB->getFirstNonPHIIt();
    End = BB->end();
  }
  return false;
}
8331
/// Public entry point: UB is guaranteed if \p Inst is undef or poison.
bool llvm::programUndefinedIfUndefOrPoison(const Instruction *Inst) {
  return ::programUndefinedIfUndefOrPoison(V: Inst, PoisonOnly: false);
}
8335
/// Public entry point: UB is guaranteed if \p Inst is poison.
bool llvm::programUndefinedIfPoison(const Instruction *Inst) {
  return ::programUndefinedIfUndefOrPoison(V: Inst, PoisonOnly: true);
}
8339
8340static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) {
8341 if (FMF.noNaNs())
8342 return true;
8343
8344 if (auto *C = dyn_cast<ConstantFP>(Val: V))
8345 return !C->isNaN();
8346
8347 if (auto *C = dyn_cast<ConstantDataVector>(Val: V)) {
8348 if (!C->getElementType()->isFloatingPointTy())
8349 return false;
8350 for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8351 if (C->getElementAsAPFloat(i: I).isNaN())
8352 return false;
8353 }
8354 return true;
8355 }
8356
8357 if (isa<ConstantAggregateZero>(Val: V))
8358 return true;
8359
8360 return false;
8361}
8362
8363static bool isKnownNonZero(const Value *V) {
8364 if (auto *C = dyn_cast<ConstantFP>(Val: V))
8365 return !C->isZero();
8366
8367 if (auto *C = dyn_cast<ConstantDataVector>(Val: V)) {
8368 if (!C->getElementType()->isFloatingPointTy())
8369 return false;
8370 for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8371 if (C->getElementAsAPFloat(i: I).isZero())
8372 return false;
8373 }
8374 return true;
8375 }
8376
8377 return false;
8378}
8379
/// Match clamp pattern for float types without care about NaNs or signed zeros.
/// Given non-min/max outer cmp/select from the clamp pattern this
/// function recognizes if it can be substitued by a "canonical" min/max
/// pattern.
static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred,
                                               Value *CmpLHS, Value *CmpRHS,
                                               Value *TrueVal, Value *FalseVal,
                                               Value *&LHS, Value *&RHS) {
  // Try to match
  // X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2))
  // X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2))
  // and return description of the outer Max/Min.

  // First, check if select has inverse order:
  if (CmpRHS == FalseVal) {
    std::swap(a&: TrueVal, b&: FalseVal);
    Pred = CmpInst::getInversePredicate(pred: Pred);
  }

  // Assume success now. If there's no match, callers should not use these.
  LHS = TrueVal;
  RHS = FalseVal;

  // The clamp constant C1 must be both the compare RHS and the (possibly
  // swapped) select true value, and must be finite: an infinite or NaN bound
  // would invalidate the strict ordering argument against C2 below.
  const APFloat *FC1;
  if (CmpRHS != TrueVal || !match(V: CmpRHS, P: m_APFloat(Res&: FC1)) || !FC1->isFinite())
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  const APFloat *FC2;
  switch (Pred) {
  case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_OLE:
  case CmpInst::FCMP_ULT:
  case CmpInst::FCMP_ULE:
    // Lower bound of the clamp: the inner op must be a min of X and C2 with
    // C1 < C2, so the outer select acts as max(C1, min(X, C2)).
    if (match(V: FalseVal, P: m_OrdOrUnordFMin(L: m_Specific(V: CmpLHS), R: m_APFloat(Res&: FC2))) &&
        *FC1 < *FC2)
      return {.Flavor: SPF_FMAXNUM, .NaNBehavior: SPNB_RETURNS_ANY, .Ordered: false};
    break;
  case CmpInst::FCMP_OGT:
  case CmpInst::FCMP_OGE:
  case CmpInst::FCMP_UGT:
  case CmpInst::FCMP_UGE:
    // Upper bound of the clamp: the inner op must be a max of X and C2 with
    // C1 > C2, so the outer select acts as min(C1, max(X, C2)).
    if (match(V: FalseVal, P: m_OrdOrUnordFMax(L: m_Specific(V: CmpLHS), R: m_APFloat(Res&: FC2))) &&
        *FC1 > *FC2)
      return {.Flavor: SPF_FMINNUM, .NaNBehavior: SPNB_RETURNS_ANY, .Ordered: false};
    break;
  default:
    break;
  }

  return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
}
8431
8432/// Recognize variations of:
8433/// CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v)))
8434static SelectPatternResult matchClamp(CmpInst::Predicate Pred,
8435 Value *CmpLHS, Value *CmpRHS,
8436 Value *TrueVal, Value *FalseVal) {
8437 // Swap the select operands and predicate to match the patterns below.
8438 if (CmpRHS != TrueVal) {
8439 Pred = ICmpInst::getSwappedPredicate(pred: Pred);
8440 std::swap(a&: TrueVal, b&: FalseVal);
8441 }
8442 const APInt *C1;
8443 if (CmpRHS == TrueVal && match(V: CmpRHS, P: m_APInt(Res&: C1))) {
8444 const APInt *C2;
8445 // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1)
8446 if (match(V: FalseVal, P: m_SMin(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8447 C1->slt(RHS: *C2) && Pred == CmpInst::ICMP_SLT)
8448 return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8449
8450 // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1)
8451 if (match(V: FalseVal, P: m_SMax(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8452 C1->sgt(RHS: *C2) && Pred == CmpInst::ICMP_SGT)
8453 return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8454
8455 // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1)
8456 if (match(V: FalseVal, P: m_UMin(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8457 C1->ult(RHS: *C2) && Pred == CmpInst::ICMP_ULT)
8458 return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8459
8460 // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1)
8461 if (match(V: FalseVal, P: m_UMax(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8462 C1->ugt(RHS: *C2) && Pred == CmpInst::ICMP_UGT)
8463 return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8464 }
8465 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8466}
8467
/// Recognize variations of:
///   a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c))
static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred,
                                               Value *CmpLHS, Value *CmpRHS,
                                               Value *TVal, Value *FVal,
                                               unsigned Depth) {
  // TODO: Allow FP min/max with nnan/nsz.
  assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison");

  // Both select arms must themselves be min/max patterns, of the same flavor.
  Value *A = nullptr, *B = nullptr;
  SelectPatternResult L = matchSelectPattern(V: TVal, LHS&: A, RHS&: B, CastOp: nullptr, Depth: Depth + 1);
  if (!SelectPatternResult::isMinOrMax(SPF: L.Flavor))
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  Value *C = nullptr, *D = nullptr;
  SelectPatternResult R = matchSelectPattern(V: FVal, LHS&: C, RHS&: D, CastOp: nullptr, Depth: Depth + 1);
  if (L.Flavor != R.Flavor)
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  // We have something like: x Pred y ? min(a, b) : min(c, d).
  // Try to match the compare to the min/max operations of the select operands.
  // First, make sure we have the right compare predicate: if it is reversed
  // relative to the flavor, swap it together with the compare operands; if it
  // still does not fit the flavor afterwards, give up.
  switch (L.Flavor) {
  case SPF_SMIN:
    if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  case SPF_SMAX:
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  case SPF_UMIN:
    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  case SPF_UMAX:
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  default:
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  // If there is a common operand in the already matched min/max and the other
  // min/max operands match the compare operands (either directly or inverted),
  // then this is min/max of the same flavor.

  // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
  // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
  if (D == B) {
    if ((CmpLHS == A && CmpRHS == C) || (match(V: C, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: A, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }
  // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
  // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
  if (C == B) {
    if ((CmpLHS == A && CmpRHS == D) || (match(V: D, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: A, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }
  // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
  // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
  if (D == A) {
    if ((CmpLHS == B && CmpRHS == C) || (match(V: C, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: B, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }
  // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
  // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
  if (C == A) {
    if ((CmpLHS == B && CmpRHS == D) || (match(V: D, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: B, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
}
8562
8563/// If the input value is the result of a 'not' op, constant integer, or vector
8564/// splat of a constant integer, return the bitwise-not source value.
8565/// TODO: This could be extended to handle non-splat vector integer constants.
8566static Value *getNotValue(Value *V) {
8567 Value *NotV;
8568 if (match(V, P: m_Not(V: m_Value(V&: NotV))))
8569 return NotV;
8570
8571 const APInt *C;
8572 if (match(V, P: m_APInt(Res&: C)))
8573 return ConstantInt::get(Ty: V->getType(), V: ~(*C));
8574
8575 return nullptr;
8576}
8577
/// Match non-obvious integer minimum and maximum sequences.
static SelectPatternResult matchMinMax(CmpInst::Predicate Pred,
                                       Value *CmpLHS, Value *CmpRHS,
                                       Value *TrueVal, Value *FalseVal,
                                       Value *&LHS, Value *&RHS,
                                       unsigned Depth) {
  // Assume success. If there's no match, callers should not use these anyway.
  LHS = TrueVal;
  RHS = FalseVal;

  // First try the clamp pattern, then min/max-of-min/max.
  SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal);
  if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
    return SPR;

  SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TVal: TrueVal, FVal: FalseVal, Depth);
  if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
    return SPR;

  // Look through 'not' ops to find disguised min/max.
  // (X > Y) ? ~X : ~Y ==> (~X < ~Y) ? ~X : ~Y ==> MIN(~X, ~Y)
  // (X < Y) ? ~X : ~Y ==> (~X > ~Y) ? ~X : ~Y ==> MAX(~X, ~Y)
  if (CmpLHS == getNotValue(V: TrueVal) && CmpRHS == getNotValue(V: FalseVal)) {
    switch (Pred) {
    case CmpInst::ICMP_SGT: return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_SLT: return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_UGT: return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_ULT: return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    default: break;
    }
  }

  // (X > Y) ? ~Y : ~X ==> (~X < ~Y) ? ~Y : ~X ==> MAX(~Y, ~X)
  // (X < Y) ? ~Y : ~X ==> (~X > ~Y) ? ~Y : ~X ==> MIN(~Y, ~X)
  if (CmpLHS == getNotValue(V: FalseVal) && CmpRHS == getNotValue(V: TrueVal)) {
    switch (Pred) {
    case CmpInst::ICMP_SGT: return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_SLT: return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_UGT: return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_ULT: return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    default: break;
    }
  }

  // Everything below handles signed compares against a constant only.
  if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT)
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  const APInt *C1;
  if (!match(V: CmpRHS, P: m_APInt(Res&: C1)))
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  // An unsigned min/max can be written with a signed compare.
  const APInt *C2;
  if ((CmpLHS == TrueVal && match(V: FalseVal, P: m_APInt(Res&: C2))) ||
      (CmpLHS == FalseVal && match(V: TrueVal, P: m_APInt(Res&: C2)))) {
    // Is the sign bit set?
    // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX
    // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN
    if (Pred == CmpInst::ICMP_SLT && C1->isZero() && C2->isMaxSignedValue())
      return {.Flavor: CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};

    // Is the sign bit clear?
    // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX
    // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN
    if (Pred == CmpInst::ICMP_SGT && C1->isAllOnes() && C2->isMinSignedValue())
      return {.Flavor: CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
}
8647
8648bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW,
8649 bool AllowPoison) {
8650 assert(X && Y && "Invalid operand");
8651
8652 auto IsNegationOf = [&](const Value *X, const Value *Y) {
8653 if (!match(V: X, P: m_Neg(V: m_Specific(V: Y))))
8654 return false;
8655
8656 auto *BO = cast<BinaryOperator>(Val: X);
8657 if (NeedNSW && !BO->hasNoSignedWrap())
8658 return false;
8659
8660 auto *Zero = cast<Constant>(Val: BO->getOperand(i_nocapture: 0));
8661 if (!AllowPoison && !Zero->isNullValue())
8662 return false;
8663
8664 return true;
8665 };
8666
8667 // X = -Y or Y = -X
8668 if (IsNegationOf(X, Y) || IsNegationOf(Y, X))
8669 return true;
8670
8671 // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A)
8672 Value *A, *B;
8673 return (!NeedNSW && (match(V: X, P: m_Sub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
8674 match(V: Y, P: m_Sub(L: m_Specific(V: B), R: m_Specific(V: A))))) ||
8675 (NeedNSW && (match(V: X, P: m_NSWSub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
8676 match(V: Y, P: m_NSWSub(L: m_Specific(V: B), R: m_Specific(V: A)))));
8677}
8678
bool llvm::isKnownInversion(const Value *X, const Value *Y) {
  // Handle X = icmp pred A, B, Y = icmp pred A, C.
  Value *A, *B, *C;
  CmpPredicate Pred1, Pred2;
  if (!match(V: X, P: m_ICmp(Pred&: Pred1, L: m_Value(V&: A), R: m_Value(V&: B))) ||
      !match(V: Y, P: m_c_ICmp(Pred&: Pred2, L: m_Specific(V: A), R: m_Value(V&: C))))
    return false;

  // They must both have samesign flag or not.
  if (Pred1.hasSameSign() != Pred2.hasSameSign())
    return false;

  // Same RHS: inversion holds exactly when the predicates are inverses.
  if (B == C)
    return Pred1 == ICmpInst::getInversePredicate(pred: Pred2);

  // Try to infer the relationship from constant ranges.
  const APInt *RHSC1, *RHSC2;
  if (!match(V: B, P: m_APInt(Res&: RHSC1)) || !match(V: C, P: m_APInt(Res&: RHSC2)))
    return false;

  // Sign bits of two RHSCs should match.
  if (Pred1.hasSameSign() && RHSC1->isNonNegative() != RHSC2->isNonNegative())
    return false;

  // X and Y are inversions iff the sets of LHS values satisfying them are
  // exact complements of each other.
  const auto CR1 = ConstantRange::makeExactICmpRegion(Pred: Pred1, Other: *RHSC1);
  const auto CR2 = ConstantRange::makeExactICmpRegion(Pred: Pred2, Other: *RHSC2);

  return CR1.inverse() == CR2;
}
8708
8709SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
8710 SelectPatternNaNBehavior NaNBehavior,
8711 bool Ordered) {
8712 switch (Pred) {
8713 default:
8714 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false}; // Equality.
8715 case ICmpInst::ICMP_UGT:
8716 case ICmpInst::ICMP_UGE:
8717 return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8718 case ICmpInst::ICMP_SGT:
8719 case ICmpInst::ICMP_SGE:
8720 return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8721 case ICmpInst::ICMP_ULT:
8722 case ICmpInst::ICMP_ULE:
8723 return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8724 case ICmpInst::ICMP_SLT:
8725 case ICmpInst::ICMP_SLE:
8726 return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8727 case FCmpInst::FCMP_UGT:
8728 case FCmpInst::FCMP_UGE:
8729 case FCmpInst::FCMP_OGT:
8730 case FCmpInst::FCMP_OGE:
8731 return {.Flavor: SPF_FMAXNUM, .NaNBehavior: NaNBehavior, .Ordered: Ordered};
8732 case FCmpInst::FCMP_ULT:
8733 case FCmpInst::FCMP_ULE:
8734 case FCmpInst::FCMP_OLT:
8735 case FCmpInst::FCMP_OLE:
8736 return {.Flavor: SPF_FMINNUM, .NaNBehavior: NaNBehavior, .Ordered: Ordered};
8737 }
8738}
8739
/// For a relational integer compare against constant \p C, return the
/// strictness-flipped predicate together with the correspondingly
/// incremented/decremented constant (e.g. "x <= 5" -> "x < 6"), or
/// std::nullopt when the adjustment would wrap or C is not a usable constant.
std::optional<std::pair<CmpPredicate, Constant *>>
llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
  assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
         "Only for relational integer predicates.");
  if (isa<UndefValue>(Val: C))
    return std::nullopt;

  Type *Type = C->getType();
  bool IsSigned = ICmpInst::isSigned(Pred);

  // ULE -> ULT and UGT -> UGE both add one; ULT -> ULE and UGE -> UGT both
  // subtract one (same for the signed forms, mapped to unsigned here).
  CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
  bool WillIncrement =
      UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;

  // Check if the constant operand can be safely incremented/decremented
  // without overflowing/underflowing.
  auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
    return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
  };

  Constant *SafeReplacementConstant = nullptr;
  if (auto *CI = dyn_cast<ConstantInt>(Val: C)) {
    // Bail out if the constant can't be safely incremented/decremented.
    if (!ConstantIsOk(CI))
      return std::nullopt;
  } else if (auto *FVTy = dyn_cast<FixedVectorType>(Val: Type)) {
    // Check every element of a fixed-width vector constant.
    unsigned NumElts = FVTy->getNumElements();
    for (unsigned i = 0; i != NumElts; ++i) {
      Constant *Elt = C->getAggregateElement(Elt: i);
      if (!Elt)
        return std::nullopt;

      if (isa<UndefValue>(Val: Elt))
        continue;

      // Bail out if we can't determine if this constant is min/max or if we
      // know that this constant is min/max.
      auto *CI = dyn_cast<ConstantInt>(Val: Elt);
      if (!CI || !ConstantIsOk(CI))
        return std::nullopt;

      if (!SafeReplacementConstant)
        SafeReplacementConstant = CI;
    }
  } else if (isa<VectorType>(Val: C->getType())) {
    // Handle scalable splat
    Value *SplatC = C->getSplatValue();
    auto *CI = dyn_cast_or_null<ConstantInt>(Val: SplatC);
    // Bail out if the constant can't be safely incremented/decremented.
    if (!CI || !ConstantIsOk(CI))
      return std::nullopt;
  } else {
    // ConstantExpr?
    return std::nullopt;
  }

  // It may not be safe to change a compare predicate in the presence of
  // undefined elements, so replace those elements with the first safe constant
  // that we found.
  // TODO: in case of poison, it is safe; let's replace undefs only.
  if (C->containsUndefOrPoisonElement()) {
    assert(SafeReplacementConstant && "Replacement constant not set");
    C = Constant::replaceUndefsWith(C, Replacement: SafeReplacementConstant);
  }

  CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(pred: Pred);

  // Increment or decrement the constant.
  Constant *OneOrNegOne = ConstantInt::get(Ty: Type, V: WillIncrement ? 1 : -1, IsSigned: true);
  Constant *NewC = ConstantExpr::getAdd(C1: C, C2: OneOrNegOne);

  return std::make_pair(x&: NewPred, y&: NewC);
}
8813
/// Core select-pattern matching on an already decomposed cmp+select.
/// Identifies min/max/abs/nabs/clamp flavors and reports the canonical
/// operands in \p LHS / \p RHS.
static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
                                              FastMathFlags FMF,
                                              Value *CmpLHS, Value *CmpRHS,
                                              Value *TrueVal, Value *FalseVal,
                                              Value *&LHS, Value *&RHS,
                                              unsigned Depth) {
  bool HasMismatchedZeros = false;
  if (CmpInst::isFPPredicate(P: Pred)) {
    // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one
    // 0.0 operand, set the compare's 0.0 operands to that same value for the
    // purpose of identifying min/max. Disregard vector constants with undefined
    // elements because those can not be back-propagated for analysis.
    Value *OutputZeroVal = nullptr;
    if (match(V: TrueVal, P: m_AnyZeroFP()) && !match(V: FalseVal, P: m_AnyZeroFP()) &&
        !cast<Constant>(Val: TrueVal)->containsUndefOrPoisonElement())
      OutputZeroVal = TrueVal;
    else if (match(V: FalseVal, P: m_AnyZeroFP()) && !match(V: TrueVal, P: m_AnyZeroFP()) &&
             !cast<Constant>(Val: FalseVal)->containsUndefOrPoisonElement())
      OutputZeroVal = FalseVal;

    if (OutputZeroVal) {
      if (match(V: CmpLHS, P: m_AnyZeroFP()) && CmpLHS != OutputZeroVal) {
        HasMismatchedZeros = true;
        CmpLHS = OutputZeroVal;
      }
      if (match(V: CmpRHS, P: m_AnyZeroFP()) && CmpRHS != OutputZeroVal) {
        HasMismatchedZeros = true;
        CmpRHS = OutputZeroVal;
      }
    }
  }

  LHS = CmpLHS;
  RHS = CmpRHS;

  // Signed zero may return inconsistent results between implementations.
  //  (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0
  //  minNum(0.0, -0.0)          // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1)
  // Therefore, we behave conservatively and only proceed if at least one of the
  // operands is known to not be zero or if we don't care about signed zero.
  switch (Pred) {
  default: break;
  case CmpInst::FCMP_OGT: case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_UGT: case CmpInst::FCMP_ULT:
    if (!HasMismatchedZeros)
      break;
    [[fallthrough]];
  case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE:
  case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE:
    if (!FMF.noSignedZeros() && !isKnownNonZero(V: CmpLHS) &&
        !isKnownNonZero(V: CmpRHS))
      return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  SelectPatternNaNBehavior NaNBehavior = SPNB_NA;
  bool Ordered = false;

  // When given one NaN and one non-NaN input:
  //   - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input.
  //   - A simple C99 (a < b ? a : b) construction will return 'b' (as the
  //     ordered comparison fails), which could be NaN or non-NaN.
  // so here we discover exactly what NaN behavior is required/accepted.
  if (CmpInst::isFPPredicate(P: Pred)) {
    bool LHSSafe = isKnownNonNaN(V: CmpLHS, FMF);
    bool RHSSafe = isKnownNonNaN(V: CmpRHS, FMF);

    if (LHSSafe && RHSSafe) {
      // Both operands are known non-NaN.
      NaNBehavior = SPNB_RETURNS_ANY;
      Ordered = CmpInst::isOrdered(predicate: Pred);
    } else if (CmpInst::isOrdered(predicate: Pred)) {
      // An ordered comparison will return false when given a NaN, so it
      // returns the RHS.
      Ordered = true;
      if (LHSSafe)
        // LHS is non-NaN, so if RHS is NaN then NaN will be returned.
        NaNBehavior = SPNB_RETURNS_NAN;
      else if (RHSSafe)
        NaNBehavior = SPNB_RETURNS_OTHER;
      else
        // Completely unsafe.
        return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
    } else {
      Ordered = false;
      // An unordered comparison will return true when given a NaN, so it
      // returns the LHS.
      if (LHSSafe)
        // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned.
        NaNBehavior = SPNB_RETURNS_OTHER;
      else if (RHSSafe)
        NaNBehavior = SPNB_RETURNS_NAN;
      else
        // Completely unsafe.
        return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
    }
  }

  // Canonicalize so that the select arms appear in (CmpLHS, CmpRHS) order;
  // swapping the compare flips its NaN-return role and ordering.
  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
    std::swap(a&: CmpLHS, b&: CmpRHS);
    Pred = CmpInst::getSwappedPredicate(pred: Pred);
    if (NaNBehavior == SPNB_RETURNS_NAN)
      NaNBehavior = SPNB_RETURNS_OTHER;
    else if (NaNBehavior == SPNB_RETURNS_OTHER)
      NaNBehavior = SPNB_RETURNS_NAN;
    Ordered = !Ordered;
  }

  // ([if]cmp X, Y) ? X : Y
  if (TrueVal == CmpLHS && FalseVal == CmpRHS)
    return getSelectPattern(Pred, NaNBehavior, Ordered);

  // select(cmp X, C) with arms that negate each other: ABS/NABS patterns.
  if (isKnownNegation(X: TrueVal, Y: FalseVal)) {
    // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can
    // match against either LHS or sext(LHS).
    auto MaybeSExtCmpLHS =
        m_CombineOr(L: m_Specific(V: CmpLHS), R: m_SExt(Op: m_Specific(V: CmpLHS)));
    auto ZeroOrAllOnes = m_CombineOr(L: m_ZeroInt(), R: m_AllOnes());
    auto ZeroOrOne = m_CombineOr(L: m_ZeroInt(), R: m_One());
    if (match(V: TrueVal, P: MaybeSExtCmpLHS)) {
      // Set the return values. If the compare uses the negated value (-X >s 0),
      // swap the return values because the negated value is always 'RHS'.
      LHS = TrueVal;
      RHS = FalseVal;
      if (match(V: CmpLHS, P: m_Neg(V: m_Specific(V: FalseVal))))
        std::swap(a&: LHS, b&: RHS);

      // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X)
      // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X)
      if (Pred == ICmpInst::ICMP_SGT && match(V: CmpRHS, P: ZeroOrAllOnes))
        return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};

      // (X >=s 0) ? X : -X or (X >=s 1) ? X : -X --> ABS(X)
      if (Pred == ICmpInst::ICMP_SGE && match(V: CmpRHS, P: ZeroOrOne))
        return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};

      // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X)
      // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X)
      if (Pred == ICmpInst::ICMP_SLT && match(V: CmpRHS, P: ZeroOrOne))
        return {.Flavor: SPF_NABS, .NaNBehavior: SPNB_NA, .Ordered: false};
    }
    else if (match(V: FalseVal, P: MaybeSExtCmpLHS)) {
      // Set the return values. If the compare uses the negated value (-X >s 0),
      // swap the return values because the negated value is always 'RHS'.
      LHS = FalseVal;
      RHS = TrueVal;
      if (match(V: CmpLHS, P: m_Neg(V: m_Specific(V: TrueVal))))
        std::swap(a&: LHS, b&: RHS);

      // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X)
      // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X)
      if (Pred == ICmpInst::ICMP_SGT && match(V: CmpRHS, P: ZeroOrAllOnes))
        return {.Flavor: SPF_NABS, .NaNBehavior: SPNB_NA, .Ordered: false};

      // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X)
      // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
      if (Pred == ICmpInst::ICMP_SLT && match(V: CmpRHS, P: ZeroOrOne))
        return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};
    }
  }

  // Integer compares: fall through to the non-obvious min/max matcher.
  if (CmpInst::isIntPredicate(P: Pred))
    return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth);

  // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar
  // may return either -0.0 or 0.0, so fcmp/select pair has stricter
  // semantics than minNum. Be conservative in such case.
  if (NaNBehavior != SPNB_RETURNS_ANY ||
      (!FMF.noSignedZeros() && !isKnownNonZero(V: CmpLHS) &&
       !isKnownNonZero(V: CmpRHS)))
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
}
8987
/// Try to reverse-apply the cast \p CastOp to the constant select operand
/// \p C so it can be expressed in the cast's source type \p SrcTy. Returns
/// nullptr when no information-preserving reverse cast exists.
static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
                                   Instruction::CastOps *CastOp) {
  const DataLayout &DL = CmpI->getDataLayout();

  Constant *CastedTo = nullptr;
  switch (*CastOp) {
  case Instruction::ZExt:
    // Only an unsigned compare is insensitive to dropping zero-extended bits.
    if (CmpI->isUnsigned())
      CastedTo = ConstantExpr::getTrunc(C, Ty: SrcTy);
    break;
  case Instruction::SExt:
    // Likewise, a signed compare pairs with sign extension; OnlyIfReduced
    // keeps the fold only when the truncation is exact.
    if (CmpI->isSigned())
      CastedTo = ConstantExpr::getTrunc(C, Ty: SrcTy, OnlyIfReduced: true);
    break;
  case Instruction::Trunc:
    Constant *CmpConst;
    if (match(V: CmpI->getOperand(i_nocapture: 1), P: m_Constant(C&: CmpConst)) &&
        CmpConst->getType() == SrcTy) {
      // Here we have the following case:
      //
      //   %cond = cmp iN %x, CmpConst
      //   %tr = trunc iN %x to iK
      //   %narrowsel = select i1 %cond, iK %t, iK C
      //
      // We can always move trunc after select operation:
      //
      //   %cond = cmp iN %x, CmpConst
      //   %widesel = select i1 %cond, iN %x, iN CmpConst
      //   %tr = trunc iN %widesel to iK
      //
      // Note that C could be extended in any way because we don't care about
      // upper bits after truncation. It can't be abs pattern, because it would
      // look like:
      //
      //   select i1 %cond, x, -x.
      //
      // So only min/max pattern could be matched. Such match requires widened C
      // == CmpConst. That is why set widened C = CmpConst, condition trunc
      // CmpConst == C is checked below.
      CastedTo = CmpConst;
    } else {
      // Otherwise widen C with the extension matching the compare's
      // signedness and verify the round trip below.
      unsigned ExtOp = CmpI->isSigned() ? Instruction::SExt : Instruction::ZExt;
      CastedTo = ConstantFoldCastOperand(Opcode: ExtOp, C, DestTy: SrcTy, DL);
    }
    break;
  case Instruction::FPTrunc:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPExt, C, DestTy: SrcTy, DL);
    break;
  case Instruction::FPExt:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPTrunc, C, DestTy: SrcTy, DL);
    break;
  case Instruction::FPToUI:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::UIToFP, C, DestTy: SrcTy, DL);
    break;
  case Instruction::FPToSI:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::SIToFP, C, DestTy: SrcTy, DL);
    break;
  case Instruction::UIToFP:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPToUI, C, DestTy: SrcTy, DL);
    break;
  case Instruction::SIToFP:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPToSI, C, DestTy: SrcTy, DL);
    break;
  default:
    break;
  }

  if (!CastedTo)
    return nullptr;

  // Make sure the cast doesn't lose any information: re-applying the original
  // cast to the candidate must reproduce C exactly.
  Constant *CastedBack =
      ConstantFoldCastOperand(Opcode: *CastOp, C: CastedTo, DestTy: C->getType(), DL);
  if (CastedBack && CastedBack != C)
    return nullptr;

  return CastedTo;
}
9066
/// Helps to match a select pattern in case of a type mismatch.
///
/// The function processes the case when type of true and false values of a
/// select instruction differs from type of the cmp instruction operands because
/// of a cast instruction. The function checks if it is legal to move the cast
/// operation after "select". If yes, it returns the new second value of
/// "select" (with the assumption that cast is moved):
/// 1. As operand of cast instruction when both values of "select" are same cast
/// instructions.
/// 2. As restored constant (by applying reverse cast operation) when the first
/// value of the "select" is a cast operation and the second value is a
/// constant. It is implemented in lookThroughCastConst().
/// 3. As one operand is cast instruction and the other is not. The operands in
/// sel(cmp) are in different type integer.
/// NOTE: We return only the new second value because the first value could be
/// accessed as operand of cast instruction.
static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
                              Instruction::CastOps *CastOp) {
  // The first select operand must itself be a cast.
  auto *Cast1 = dyn_cast<CastInst>(Val: V1);
  if (!Cast1)
    return nullptr;

  *CastOp = Cast1->getOpcode();
  Type *SrcTy = Cast1->getSrcTy();
  if (auto *Cast2 = dyn_cast<CastInst>(Val: V2)) {
    // If V1 and V2 are both the same cast from the same type, look through V1.
    if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
      return Cast2->getOperand(i_nocapture: 0);
    return nullptr;
  }

  // Case 2: second operand is a constant - reverse-apply the cast to it.
  auto *C = dyn_cast<Constant>(Val: V2);
  if (C)
    return lookThroughCastConst(CmpI, SrcTy, C, CastOp);

  // Case 3: trunc on one arm paired with an extended compare operand.
  Value *CastedTo = nullptr;
  if (*CastOp == Instruction::Trunc) {
    if (match(V: CmpI->getOperand(i_nocapture: 1), P: m_ZExtOrSExt(Op: m_Specific(V: V2)))) {
      // Here we have the following case:
      //   %y_ext = sext iK %y to iN
      //   %cond = cmp iN %x, %y_ext
      //   %tr = trunc iN %x to iK
      //   %narrowsel = select i1 %cond, iK %tr, iK %y
      //
      // We can always move trunc after select operation:
      //   %y_ext = sext iK %y to iN
      //   %cond = cmp iN %x, %y_ext
      //   %widesel = select i1 %cond, iN %x, iN %y_ext
      //   %tr = trunc iN %widesel to iK
      assert(V2->getType() == Cast1->getType() &&
             "V2 and Cast1 should be the same type.");
      CastedTo = CmpI->getOperand(i_nocapture: 1);
    }
  }

  return CastedTo;
}
/// Match \p V against the known select patterns (min/max and friends),
/// setting \p LHS and \p RHS to the compare operands on success. If \p CastOp
/// is non-null, a cast between the compare and the select operand types may
/// be looked through (the opcode is reported through \p CastOp).
SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
                                             Instruction::CastOps *CastOp,
                                             unsigned Depth) {
  // Recursion guard shared with the rest of the value-tracking analyses.
  if (Depth >= MaxAnalysisRecursionDepth)
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  // Only a select fed by a compare can form a recognizable pattern.
  SelectInst *SI = dyn_cast<SelectInst>(Val: V);
  if (!SI) return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  CmpInst *CmpI = dyn_cast<CmpInst>(Val: SI->getCondition());
  if (!CmpI) return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  Value *TrueVal = SI->getTrueValue();
  Value *FalseVal = SI->getFalseValue();

  // Fast-math flags are only present on FP selects; pass an empty set
  // otherwise.
  return llvm::matchDecomposedSelectPattern(
      CmpI, TrueVal, FalseVal, LHS, RHS,
      FMF: isa<FPMathOperator>(Val: SI) ? SI->getFastMathFlags() : FastMathFlags(),
      CastOp, Depth);
}
9144
/// Match a select that has already been decomposed into its compare \p CmpI
/// and arms \p TrueVal / \p FalseVal. \p FMF carries the select's fast-math
/// flags; nnan from the compare is merged in. If \p CastOp is non-null, a
/// type mismatch between compare and select operands may be bridged by
/// looking through a cast (see lookThroughCast).
SelectPatternResult llvm::matchDecomposedSelectPattern(
    CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
    FastMathFlags FMF, Instruction::CastOps *CastOp, unsigned Depth) {
  CmpInst::Predicate Pred = CmpI->getPredicate();
  Value *CmpLHS = CmpI->getOperand(i_nocapture: 0);
  Value *CmpRHS = CmpI->getOperand(i_nocapture: 1);
  // An nnan compare lets the rest of the matching reason as if the whole
  // pattern were nnan.
  if (isa<FPMathOperator>(Val: CmpI) && CmpI->hasNoNaNs())
    FMF.setNoNaNs();

  // Bail out early.
  if (CmpI->isEquality())
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  // Deal with type mismatches.
  if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
    if (Value *C = lookThroughCast(CmpI, V1: TrueVal, V2: FalseVal, CastOp)) {
      // If this is a potential fmin/fmax with a cast to integer, then ignore
      // -0.0 because there is no corresponding integer value.
      if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
        FMF.setNoSignedZeros();
      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
                                  TrueVal: cast<CastInst>(Val: TrueVal)->getOperand(i_nocapture: 0), FalseVal: C,
                                  LHS, RHS, Depth);
    }
    if (Value *C = lookThroughCast(CmpI, V1: FalseVal, V2: TrueVal, CastOp)) {
      // If this is a potential fmin/fmax with a cast to integer, then ignore
      // -0.0 because there is no corresponding integer value.
      if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
        FMF.setNoSignedZeros();
      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
                                  TrueVal: C, FalseVal: cast<CastInst>(Val: FalseVal)->getOperand(i_nocapture: 0),
                                  LHS, RHS, Depth);
    }
  }
  return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal,
                              LHS, RHS, Depth);
}
9182
9183CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) {
9184 if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT;
9185 if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT;
9186 if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT;
9187 if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT;
9188 if (SPF == SPF_FMINNUM)
9189 return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
9190 if (SPF == SPF_FMAXNUM)
9191 return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
9192 llvm_unreachable("unhandled!");
9193}
9194
9195Intrinsic::ID llvm::getMinMaxIntrinsic(SelectPatternFlavor SPF) {
9196 switch (SPF) {
9197 case SelectPatternFlavor::SPF_UMIN:
9198 return Intrinsic::umin;
9199 case SelectPatternFlavor::SPF_UMAX:
9200 return Intrinsic::umax;
9201 case SelectPatternFlavor::SPF_SMIN:
9202 return Intrinsic::smin;
9203 case SelectPatternFlavor::SPF_SMAX:
9204 return Intrinsic::smax;
9205 default:
9206 llvm_unreachable("Unexpected SPF");
9207 }
9208}
9209
9210SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) {
9211 if (SPF == SPF_SMIN) return SPF_SMAX;
9212 if (SPF == SPF_UMIN) return SPF_UMAX;
9213 if (SPF == SPF_SMAX) return SPF_SMIN;
9214 if (SPF == SPF_UMAX) return SPF_UMIN;
9215 llvm_unreachable("unhandled!");
9216}
9217
/// Return the intrinsic computing the opposite extremum within the same
/// family: smax<->smin, umax<->umin, maximum<->minimum, maxnum<->minnum,
/// maximumnum<->minimumnum.
Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) {
  switch (MinMaxID) {
  case Intrinsic::smax: return Intrinsic::smin;
  case Intrinsic::smin: return Intrinsic::smax;
  case Intrinsic::umax: return Intrinsic::umin;
  case Intrinsic::umin: return Intrinsic::umax;
  // Please note that next four intrinsics may produce the same result for
  // original and inverted case even if X != Y due to NaN is handled specially.
  case Intrinsic::maximum: return Intrinsic::minimum;
  case Intrinsic::minimum: return Intrinsic::maximum;
  case Intrinsic::maxnum: return Intrinsic::minnum;
  case Intrinsic::minnum: return Intrinsic::maxnum;
  case Intrinsic::maximumnum:
    return Intrinsic::minimumnum;
  case Intrinsic::minimumnum:
    return Intrinsic::maximumnum;
  default: llvm_unreachable("Unexpected intrinsic");
  }
}
9237
9238APInt llvm::getMinMaxLimit(SelectPatternFlavor SPF, unsigned BitWidth) {
9239 switch (SPF) {
9240 case SPF_SMAX: return APInt::getSignedMaxValue(numBits: BitWidth);
9241 case SPF_SMIN: return APInt::getSignedMinValue(numBits: BitWidth);
9242 case SPF_UMAX: return APInt::getMaxValue(numBits: BitWidth);
9243 case SPF_UMIN: return APInt::getMinValue(numBits: BitWidth);
9244 default: llvm_unreachable("Unexpected flavor");
9245 }
9246}
9247
/// Check whether every value in \p VL is a select forming the same min/max
/// pattern. Returns the matching min/max intrinsic (or not_intrinsic) plus a
/// flag that is true when every select's condition has a single use.
std::pair<Intrinsic::ID, bool>
llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
  // Check if VL contains select instructions that can be folded into a min/max
  // vector intrinsic and return the intrinsic if it is possible.
  // TODO: Support floating point min/max.
  bool AllCmpSingleUse = true;
  SelectPatternResult SelectPattern;
  SelectPattern.Flavor = SPF_UNKNOWN;
  if (all_of(Range&: VL, P: [&SelectPattern, &AllCmpSingleUse](Value *I) {
        Value *LHS, *RHS;
        auto CurrentPattern = matchSelectPattern(V: I, LHS, RHS);
        // Every element must be a min/max, and all elements must agree on
        // the flavor.
        if (!SelectPatternResult::isMinOrMax(SPF: CurrentPattern.Flavor))
          return false;
        if (SelectPattern.Flavor != SPF_UNKNOWN &&
            SelectPattern.Flavor != CurrentPattern.Flavor)
          return false;
        SelectPattern = CurrentPattern;
        // Track whether every select condition is single-use.
        AllCmpSingleUse &=
            match(V: I, P: m_Select(C: m_OneUse(SubPattern: m_Value()), L: m_Value(), R: m_Value()));
        return true;
      })) {
    switch (SelectPattern.Flavor) {
    case SPF_SMIN:
      return {Intrinsic::smin, AllCmpSingleUse};
    case SPF_UMIN:
      return {Intrinsic::umin, AllCmpSingleUse};
    case SPF_SMAX:
      return {Intrinsic::smax, AllCmpSingleUse};
    case SPF_UMAX:
      return {Intrinsic::umax, AllCmpSingleUse};
    case SPF_FMAXNUM:
      return {Intrinsic::maxnum, AllCmpSingleUse};
    case SPF_FMINNUM:
      return {Intrinsic::minnum, AllCmpSingleUse};
    default:
      llvm_unreachable("unexpected select pattern flavor");
    }
  }
  return {Intrinsic::not_intrinsic, false};
}
9288
/// Match \p PN as a two-predecessor recurrence PHI whose backedge value is an
/// instruction of type InstTy with at least two operands, one of which is the
/// PHI itself. On success sets \p Inst to that instruction, \p Init to the
/// incoming value from the other predecessor, and \p OtherOp to the non-PHI
/// operand of \p Inst (operand order within \p Inst is not fixed).
template <typename InstTy>
static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst,
                                    Value *&Init, Value *&OtherOp) {
  // Handle the case of a simple two-predecessor recurrence PHI.
  // There's a lot more that could theoretically be done here, but
  // this is sufficient to catch some interesting cases.
  // TODO: Expand list -- gep, uadd.sat etc.
  if (PN->getNumIncomingValues() != 2)
    return false;

  // Try each incoming value as the candidate step instruction; the other one
  // then becomes the init value.
  for (unsigned I = 0; I != 2; ++I) {
    if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(i: I));
        Operation && Operation->getNumOperands() >= 2) {
      Value *LHS = Operation->getOperand(0);
      Value *RHS = Operation->getOperand(1);
      // The PHI must feed back into the operation.
      if (LHS != PN && RHS != PN)
        continue;

      Inst = Operation;
      Init = PN->getIncomingValue(i: !I);
      OtherOp = (LHS == PN) ? RHS : LHS;
      return true;
    }
  }
  return false;
}
9315
/// Three-operand analogue of matchTwoInputRecurrence: match \p PN as a
/// recurrence whose backedge value is an InstTy with at least three operands,
/// one of which is the PHI. The two non-PHI operands are returned in
/// \p OtherOp0 / \p OtherOp1 preserving their relative order.
template <typename InstTy>
static bool matchThreeInputRecurrence(const PHINode *PN, InstTy *&Inst,
                                      Value *&Init, Value *&OtherOp0,
                                      Value *&OtherOp1) {
  if (PN->getNumIncomingValues() != 2)
    return false;

  // Try each incoming value as the candidate step instruction.
  for (unsigned I = 0; I != 2; ++I) {
    if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(i: I));
        Operation && Operation->getNumOperands() >= 3) {
      Value *Op0 = Operation->getOperand(0);
      Value *Op1 = Operation->getOperand(1);
      Value *Op2 = Operation->getOperand(2);

      // The PHI must appear among the first three operands.
      if (Op0 != PN && Op1 != PN && Op2 != PN)
        continue;

      Inst = Operation;
      Init = PN->getIncomingValue(i: !I);
      // Return the two operands that are not the PHI, in order.
      if (Op0 == PN) {
        OtherOp0 = Op1;
        OtherOp1 = Op2;
      } else if (Op1 == PN) {
        OtherOp0 = Op0;
        OtherOp1 = Op2;
      } else {
        OtherOp0 = Op0;
        OtherOp1 = Op1;
      }
      return true;
    }
  }
  return false;
}
/// Match a simple first-order recurrence PHI, returning its binary step
/// instruction \p BO, start value \p Start, and step operand \p Step.
bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
                                 Value *&Start, Value *&Step) {
  // We try to match a recurrence of the form:
  //   %iv = [Start, %entry], [%iv.next, %backedge]
  //   %iv.next = binop %iv, Step
  // Or:
  //   %iv = [Start, %entry], [%iv.next, %backedge]
  //   %iv.next = binop Step, %iv
  return matchTwoInputRecurrence(PN: P, Inst&: BO, Init&: Start, OtherOp&: Step);
}
9360
9361bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
9362 Value *&Start, Value *&Step) {
9363 BinaryOperator *BO = nullptr;
9364 P = dyn_cast<PHINode>(Val: I->getOperand(i_nocapture: 0));
9365 if (!P)
9366 P = dyn_cast<PHINode>(Val: I->getOperand(i_nocapture: 1));
9367 return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I;
9368}
9369
9370bool llvm::matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
9371 PHINode *&P, Value *&Init,
9372 Value *&OtherOp) {
9373 // Binary intrinsics only supported for now.
9374 if (I->arg_size() != 2 || I->getType() != I->getArgOperand(i: 0)->getType() ||
9375 I->getType() != I->getArgOperand(i: 1)->getType())
9376 return false;
9377
9378 IntrinsicInst *II = nullptr;
9379 P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 0));
9380 if (!P)
9381 P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 1));
9382
9383 return P && matchTwoInputRecurrence(PN: P, Inst&: II, Init, OtherOp) && II == I;
9384}
9385
9386bool llvm::matchSimpleTernaryIntrinsicRecurrence(const IntrinsicInst *I,
9387 PHINode *&P, Value *&Init,
9388 Value *&OtherOp0,
9389 Value *&OtherOp1) {
9390 if (I->arg_size() != 3 || I->getType() != I->getArgOperand(i: 0)->getType() ||
9391 I->getType() != I->getArgOperand(i: 1)->getType() ||
9392 I->getType() != I->getArgOperand(i: 2)->getType())
9393 return false;
9394 IntrinsicInst *II = nullptr;
9395 P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 0));
9396 if (!P) {
9397 P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 1));
9398 if (!P)
9399 P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 2));
9400 }
9401 return P && matchThreeInputRecurrence(PN: P, Inst&: II, Init, OtherOp0, OtherOp1) &&
9402 II == I;
9403}
9404
/// Return true if "icmp Pred LHS RHS" is always true.
/// Only trivially-true equalities plus the sle/ule predicates are handled;
/// anything else conservatively returns false.
static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
                            const Value *RHS) {
  // X ==/<=/>= X is trivially true.
  if (ICmpInst::isTrueWhenEqual(predicate: Pred) && LHS == RHS)
    return true;

  switch (Pred) {
  default:
    return false;

  case CmpInst::ICMP_SLE: {
    const APInt *C;

    // LHS s<= LHS +_{nsw} C if C >= 0
    // LHS s<= LHS | C if C >= 0
    if (match(V: RHS, P: m_NSWAdd(L: m_Specific(V: LHS), R: m_APInt(Res&: C))) ||
        match(V: RHS, P: m_Or(L: m_Specific(V: LHS), R: m_APInt(Res&: C))))
      return !C->isNegative();

    // LHS s<= smax(LHS, V) for any V
    if (match(V: RHS, P: m_c_SMax(L: m_Specific(V: LHS), R: m_Value())))
      return true;

    // smin(RHS, V) s<= RHS for any V
    if (match(V: LHS, P: m_c_SMin(L: m_Specific(V: RHS), R: m_Value())))
      return true;

    // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB)
    const Value *X;
    const APInt *CLHS, *CRHS;
    if (match(V: LHS, P: m_NSWAddLike(L: m_Value(V&: X), R: m_APInt(Res&: CLHS))) &&
        match(V: RHS, P: m_NSWAddLike(L: m_Specific(V: X), R: m_APInt(Res&: CRHS))))
      return CLHS->sle(RHS: *CRHS);

    return false;
  }

  case CmpInst::ICMP_ULE: {
    // LHS u<= LHS +_{nuw} V for any V
    if (match(V: RHS, P: m_c_Add(L: m_Specific(V: LHS), R: m_Value())) &&
        cast<OverflowingBinaryOperator>(Val: RHS)->hasNoUnsignedWrap())
      return true;

    // LHS u<= LHS | V for any V
    if (match(V: RHS, P: m_c_Or(L: m_Specific(V: LHS), R: m_Value())))
      return true;

    // LHS u<= umax(LHS, V) for any V
    if (match(V: RHS, P: m_c_UMax(L: m_Specific(V: LHS), R: m_Value())))
      return true;

    // RHS >> V u<= RHS for any V
    if (match(V: LHS, P: m_LShr(L: m_Specific(V: RHS), R: m_Value())))
      return true;

    // RHS u/ C_ugt_1 u<= RHS
    const APInt *C;
    if (match(V: LHS, P: m_UDiv(L: m_Specific(V: RHS), R: m_APInt(Res&: C))) && C->ugt(RHS: 1))
      return true;

    // RHS & V u<= RHS for any V
    if (match(V: LHS, P: m_c_And(L: m_Specific(V: RHS), R: m_Value())))
      return true;

    // umin(RHS, V) u<= RHS for any V
    if (match(V: LHS, P: m_c_UMin(L: m_Specific(V: RHS), R: m_Value())))
      return true;

    // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
    const Value *X;
    const APInt *CLHS, *CRHS;
    if (match(V: LHS, P: m_NUWAddLike(L: m_Value(V&: X), R: m_APInt(Res&: CLHS))) &&
        match(V: RHS, P: m_NUWAddLike(L: m_Specific(V: X), R: m_APInt(Res&: CRHS))))
      return CLHS->ule(RHS: *CRHS);

    return false;
  }
  }
}
9484
/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
/// ALHS ARHS" is true. Otherwise, return std::nullopt.
static std::optional<bool>
isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
                      const Value *ARHS, const Value *BLHS, const Value *BRHS) {
  // For each ordered predicate family the implication holds when the B
  // interval encloses the A interval: e.g. for slt/sle we need
  // BLHS s<= ALHS and ARHS s<= BRHS.
  switch (Pred) {
  default:
    return std::nullopt;

  case CmpInst::ICMP_SLT:
  case CmpInst::ICMP_SLE:
    if (isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: BLHS, RHS: ALHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: ARHS, RHS: BRHS))
      return true;
    return std::nullopt;

  case CmpInst::ICMP_SGT:
  case CmpInst::ICMP_SGE:
    // Mirror image of the slt/sle case.
    if (isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: ALHS, RHS: BLHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: BRHS, RHS: ARHS))
      return true;
    return std::nullopt;

  case CmpInst::ICMP_ULT:
  case CmpInst::ICMP_ULE:
    if (isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: BLHS, RHS: ALHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: ARHS, RHS: BRHS))
      return true;
    return std::nullopt;

  case CmpInst::ICMP_UGT:
  case CmpInst::ICMP_UGE:
    // Mirror image of the ult/ule case.
    if (isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: ALHS, RHS: BLHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: BRHS, RHS: ARHS))
      return true;
    return std::nullopt;
  }
}
9523
/// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
/// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
/// Otherwise, return std::nullopt if we can't infer anything.
static std::optional<bool>
isImpliedCondCommonOperandWithCR(CmpPredicate LPred, const ConstantRange &LCR,
                                 CmpPredicate RPred, const ConstantRange &RCR) {
  auto CRImpliesPred = [&](ConstantRange CR,
                           CmpInst::Predicate Pred) -> std::optional<bool> {
    // If all true values for lhs and true for rhs, lhs implies rhs
    if (CR.icmp(Pred, Other: RCR))
      return true;

    // If there is no overlap, lhs implies not rhs
    if (CR.icmp(Pred: CmpInst::getInversePredicate(pred: Pred), Other: RCR))
      return false;

    return std::nullopt;
  };
  // First try with the predicates as given.
  if (auto Res = CRImpliesPred(ConstantRange::makeAllowedICmpRegion(Pred: LPred, Other: LCR),
                               RPred))
    return Res;
  // If exactly one side carries the samesign flag, retry after normalizing:
  // the samesign side is rewritten with flipped signedness and the other side
  // drops its (absent) samesign flag.
  if (LPred.hasSameSign() ^ RPred.hasSameSign()) {
    LPred = LPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(Pred: LPred)
                                : LPred.dropSameSign();
    RPred = RPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(Pred: RPred)
                                : RPred.dropSameSign();
    return CRImpliesPred(ConstantRange::makeAllowedICmpRegion(Pred: LPred, Other: LCR),
                         RPred);
  }
  return std::nullopt;
}
9555
/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
/// is true. Return false if LHS implies RHS is false. Otherwise, return
/// std::nullopt if we can't infer anything.
static std::optional<bool>
isImpliedCondICmps(CmpPredicate LPred, const Value *L0, const Value *L1,
                   CmpPredicate RPred, const Value *R0, const Value *R1,
                   const DataLayout &DL, bool LHSIsTrue) {
  // The rest of the logic assumes the LHS condition is true. If that's not the
  // case, invert the predicate to make it so.
  if (!LHSIsTrue)
    LPred = ICmpInst::getInverseCmpPredicate(Pred: LPred);

  // We can have non-canonical operands, so try to normalize any common operand
  // to L0/R0.
  if (L0 == R1) {
    std::swap(a&: R0, b&: R1);
    RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
  }
  if (R0 == L1) {
    std::swap(a&: L0, b&: L1);
    LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
  }
  if (L1 == R1) {
    // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
    if (L0 != R0 || match(V: L0, P: m_ImmConstant())) {
      std::swap(a&: L0, b&: L1);
      LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
      std::swap(a&: R0, b&: R1);
      RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
    }
  }

  // See if we can infer anything if operand-0 matches and we have at least one
  // constant.
  const APInt *Unused;
  if (L0 == R0 && (match(V: L1, P: m_APInt(Res&: Unused)) || match(V: R1, P: m_APInt(Res&: Unused)))) {
    // Potential TODO: We could also further use the constant range of L0/R0 to
    // further constraint the constant ranges. At the moment this leads to
    // several regressions related to not transforming `multi_use(A + C0) eq/ne
    // C1` (see discussion: D58633).
    ConstantRange LCR = computeConstantRange(
        V: L1, ForSigned: ICmpInst::isSigned(Pred: LPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
        /*CxtI=*/CtxI: nullptr, /*DT=*/nullptr, Depth: MaxAnalysisRecursionDepth - 1);
    ConstantRange RCR = computeConstantRange(
        V: R1, ForSigned: ICmpInst::isSigned(Pred: RPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
        /*CxtI=*/CtxI: nullptr, /*DT=*/nullptr, Depth: MaxAnalysisRecursionDepth - 1);
    // Even if L1/R1 are not both constant, we can still sometimes deduce
    // relationship from a single constant. For example X u> Y implies X != 0.
    if (auto R = isImpliedCondCommonOperandWithCR(LPred, LCR, RPred, RCR))
      return R;
    // If both L1/R1 were exact constant ranges and we didn't get anything
    // here, we won't be able to deduce this.
    if (match(V: L1, P: m_APInt(Res&: Unused)) && match(V: R1, P: m_APInt(Res&: Unused)))
      return std::nullopt;
  }

  // Can we infer anything when the two compares have matching operands?
  if (L0 == R0 && L1 == R1)
    return ICmpInst::isImpliedByMatchingCmp(Pred1: LPred, Pred2: RPred);

  // It only really makes sense in the context of signed comparison for "X - Y
  // must be positive if X >= Y and no overflow".
  // Take SGT as an example: L0:x > L1:y and C >= 0
  //   ==> R0:(x -nsw y) < R1:(-C) is false
  CmpInst::Predicate SignedLPred = LPred.getPreferredSignedPredicate();
  if ((SignedLPred == ICmpInst::ICMP_SGT ||
       SignedLPred == ICmpInst::ICMP_SGE) &&
      match(V: R0, P: m_NSWSub(L: m_Specific(V: L0), R: m_Specific(V: L1)))) {
    if (match(V: R1, P: m_NonPositive()) &&
        ICmpInst::isImpliedByMatchingCmp(Pred1: SignedLPred, Pred2: RPred) == false)
      return false;
  }

  // Take SLT as an example: L0:x < L1:y and C <= 0
  //   ==> R0:(x -nsw y) < R1:(-C) is true
  if ((SignedLPred == ICmpInst::ICMP_SLT ||
       SignedLPred == ICmpInst::ICMP_SLE) &&
      match(V: R0, P: m_NSWSub(L: m_Specific(V: L0), R: m_Specific(V: L1)))) {
    if (match(V: R1, P: m_NonNegative()) &&
        ICmpInst::isImpliedByMatchingCmp(Pred1: SignedLPred, Pred2: RPred) == true)
      return true;
  }

  // a - b == NonZero -> a != b
  // ptrtoint(a) - ptrtoint(b) == NonZero -> a != b
  const APInt *L1C;
  Value *A, *B;
  if (LPred == ICmpInst::ICMP_EQ && ICmpInst::isEquality(P: RPred) &&
      match(V: L1, P: m_APInt(Res&: L1C)) && !L1C->isZero() &&
      match(V: L0, P: m_Sub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
      ((A == R0 && B == R1) || (A == R1 && B == R0) ||
       (match(V: A, P: m_PtrToIntOrAddr(Op: m_Specific(V: R0))) &&
        match(V: B, P: m_PtrToIntOrAddr(Op: m_Specific(V: R1)))) ||
       (match(V: A, P: m_PtrToIntOrAddr(Op: m_Specific(V: R1))) &&
        match(V: B, P: m_PtrToIntOrAddr(Op: m_Specific(V: R0)))))) {
    // The implication answers "a != b": true for RPred == ne, false for eq.
    return RPred.dropSameSign() == ICmpInst::ICMP_NE;
  }

  // L0 = R0 = L1 + R1, L0 >=u L1 implies R0 >=u R1, L0 <u L1 implies R0 <u R1
  if (L0 == R0 &&
      (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
      (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
      match(V: L0, P: m_c_Add(L: m_Specific(V: L1), R: m_Specific(V: R1))))
    return CmpPredicate::getMatching(A: LPred, B: RPred).has_value();

  // Fall back to interval reasoning when both compares use the same predicate.
  if (auto P = CmpPredicate::getMatching(A: LPred, B: RPred))
    return isImpliedCondOperands(Pred: *P, ALHS: L0, ARHS: L1, BLHS: R0, BRHS: R1);

  return std::nullopt;
}
9666
9667/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
9668/// is true. Return false if LHS implies RHS is false. Otherwise, return
9669/// std::nullopt if we can't infer anything.
9670static std::optional<bool>
9671isImpliedCondFCmps(FCmpInst::Predicate LPred, const Value *L0, const Value *L1,
9672 FCmpInst::Predicate RPred, const Value *R0, const Value *R1,
9673 const DataLayout &DL, bool LHSIsTrue) {
9674 // The rest of the logic assumes the LHS condition is true. If that's not the
9675 // case, invert the predicate to make it so.
9676 if (!LHSIsTrue)
9677 LPred = FCmpInst::getInversePredicate(pred: LPred);
9678
9679 // We can have non-canonical operands, so try to normalize any common operand
9680 // to L0/R0.
9681 if (L0 == R1) {
9682 std::swap(a&: R0, b&: R1);
9683 RPred = FCmpInst::getSwappedPredicate(pred: RPred);
9684 }
9685 if (R0 == L1) {
9686 std::swap(a&: L0, b&: L1);
9687 LPred = FCmpInst::getSwappedPredicate(pred: LPred);
9688 }
9689 if (L1 == R1) {
9690 // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
9691 if (L0 != R0 || match(V: L0, P: m_ImmConstant())) {
9692 std::swap(a&: L0, b&: L1);
9693 LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
9694 std::swap(a&: R0, b&: R1);
9695 RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
9696 }
9697 }
9698
9699 // Can we infer anything when the two compares have matching operands?
9700 if (L0 == R0 && L1 == R1) {
9701 if ((LPred & RPred) == LPred)
9702 return true;
9703 if ((LPred & ~RPred) == LPred)
9704 return false;
9705 }
9706
9707 // See if we can infer anything if operand-0 matches and we have at least one
9708 // constant.
9709 const APFloat *L1C, *R1C;
9710 if (L0 == R0 && match(V: L1, P: m_APFloat(Res&: L1C)) && match(V: R1, P: m_APFloat(Res&: R1C))) {
9711 if (std::optional<ConstantFPRange> DomCR =
9712 ConstantFPRange::makeExactFCmpRegion(Pred: LPred, Other: *L1C)) {
9713 if (std::optional<ConstantFPRange> ImpliedCR =
9714 ConstantFPRange::makeExactFCmpRegion(Pred: RPred, Other: *R1C)) {
9715 if (ImpliedCR->contains(CR: *DomCR))
9716 return true;
9717 }
9718 if (std::optional<ConstantFPRange> ImpliedCR =
9719 ConstantFPRange::makeExactFCmpRegion(
9720 Pred: FCmpInst::getInversePredicate(pred: RPred), Other: *R1C)) {
9721 if (ImpliedCR->contains(CR: *DomCR))
9722 return false;
9723 }
9724 }
9725 }
9726
9727 return std::nullopt;
9728}
9729
/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is
/// false. Otherwise, return std::nullopt if we can't infer anything. We
/// expect the RHS to be an icmp and the LHS to be an 'and', 'or', or a 'select'
/// instruction.
static std::optional<bool>
isImpliedCondAndOr(const Instruction *LHS, CmpPredicate RHSPred,
                   const Value *RHSOp0, const Value *RHSOp1,
                   const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
  // The LHS must be an 'or', 'and', or a 'select' instruction.
  assert((LHS->getOpcode() == Instruction::And ||
          LHS->getOpcode() == Instruction::Or ||
          LHS->getOpcode() == Instruction::Select) &&
         "Expected LHS to be 'and', 'or', or 'select'.");

  assert(Depth <= MaxAnalysisRecursionDepth && "Hit recursion limit");

  // If the result of an 'or' is false, then we know both legs of the 'or' are
  // false. Similarly, if the result of an 'and' is true, then we know both
  // legs of the 'and' are true.
  const Value *ALHS, *ARHS;
  if ((!LHSIsTrue && match(V: LHS, P: m_LogicalOr(L: m_Value(V&: ALHS), R: m_Value(V&: ARHS)))) ||
      (LHSIsTrue && match(V: LHS, P: m_LogicalAnd(L: m_Value(V&: ALHS), R: m_Value(V&: ARHS))))) {
    // Either leg implying RHS is enough.
    // FIXME: Make this non-recursion.
    if (std::optional<bool> Implication = isImpliedCondition(
            LHS: ALHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth: Depth + 1))
      return Implication;
    if (std::optional<bool> Implication = isImpliedCondition(
            LHS: ARHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth: Depth + 1))
      return Implication;
    return std::nullopt;
  }
  return std::nullopt;
}
9763
/// Return true/false when the condition \p LHS being true (or false, per
/// \p LHSIsTrue) implies that the compare "RHSOp0 RHSPred RHSOp1" is
/// true/false; std::nullopt when nothing can be inferred.
std::optional<bool>
llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
                         const Value *RHSOp0, const Value *RHSOp1,
                         const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
  // Bail out when we hit the limit.
  if (Depth == MaxAnalysisRecursionDepth)
    return std::nullopt;

  // A mismatch occurs when we compare a scalar cmp to a vector cmp, for
  // example.
  if (RHSOp0->getType()->isVectorTy() != LHS->getType()->isVectorTy())
    return std::nullopt;

  assert(LHS->getType()->isIntOrIntVectorTy(1) &&
         "Expected integer type only!");

  // Match not
  if (match(V: LHS, P: m_Not(V: m_Value(V&: LHS))))
    LHSIsTrue = !LHSIsTrue;

  // Both LHS and RHS are icmps.
  if (RHSOp0->getType()->getScalarType()->isIntOrPtrTy()) {
    if (const auto *LHSCmp = dyn_cast<ICmpInst>(Val: LHS))
      return isImpliedCondICmps(LPred: LHSCmp->getCmpPredicate(),
                                L0: LHSCmp->getOperand(i_nocapture: 0), L1: LHSCmp->getOperand(i_nocapture: 1),
                                RPred: RHSPred, R0: RHSOp0, R1: RHSOp1, DL, LHSIsTrue);
    // A nuw trunc to i1 being true is equivalent to its operand being nonzero.
    const Value *V;
    if (match(V: LHS, P: m_NUWTrunc(Op: m_Value(V))))
      return isImpliedCondICmps(LPred: CmpInst::ICMP_NE, L0: V,
                                L1: ConstantInt::get(Ty: V->getType(), V: 0), RPred: RHSPred,
                                R0: RHSOp0, R1: RHSOp1, DL, LHSIsTrue);
  } else {
    assert(RHSOp0->getType()->isFPOrFPVectorTy() &&
           "Expected floating point type only!");
    if (const auto *LHSCmp = dyn_cast<FCmpInst>(Val: LHS))
      return isImpliedCondFCmps(LPred: LHSCmp->getPredicate(), L0: LHSCmp->getOperand(i_nocapture: 0),
                                L1: LHSCmp->getOperand(i_nocapture: 1), RPred: RHSPred, R0: RHSOp0, R1: RHSOp1,
                                DL, LHSIsTrue);
  }

  /// The LHS should be an 'or', 'and', or a 'select' instruction. We expect
  /// the RHS to be an icmp.
  /// FIXME: Add support for and/or/select on the RHS.
  if (const Instruction *LHSI = dyn_cast<Instruction>(Val: LHS)) {
    if ((LHSI->getOpcode() == Instruction::And ||
         LHSI->getOpcode() == Instruction::Or ||
         LHSI->getOpcode() == Instruction::Select))
      return isImpliedCondAndOr(LHS: LHSI, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue,
                                Depth);
  }
  return std::nullopt;
}
9816
/// Overload taking the RHS as a full value: handles RHS being an icmp, fcmp,
/// nuw trunc (a non-zero test), a logical and/or, or a negation of any of
/// these.
std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
                                             const DataLayout &DL,
                                             bool LHSIsTrue, unsigned Depth) {
  // LHS ==> RHS by definition
  if (LHS == RHS)
    return LHSIsTrue;

  // Match not
  bool InvertRHS = false;
  if (match(V: RHS, P: m_Not(V: m_Value(V&: RHS)))) {
    if (LHS == RHS)
      return !LHSIsTrue;
    InvertRHS = true;
  }

  // Dispatch on the form of RHS; any result must be un-inverted if RHS was
  // wrapped in a 'not'.
  if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(Val: RHS)) {
    if (auto Implied = isImpliedCondition(
            LHS, RHSPred: RHSCmp->getCmpPredicate(), RHSOp0: RHSCmp->getOperand(i_nocapture: 0),
            RHSOp1: RHSCmp->getOperand(i_nocapture: 1), DL, LHSIsTrue, Depth))
      return InvertRHS ? !*Implied : *Implied;
    return std::nullopt;
  }
  if (const FCmpInst *RHSCmp = dyn_cast<FCmpInst>(Val: RHS)) {
    if (auto Implied = isImpliedCondition(
            LHS, RHSPred: RHSCmp->getPredicate(), RHSOp0: RHSCmp->getOperand(i_nocapture: 0),
            RHSOp1: RHSCmp->getOperand(i_nocapture: 1), DL, LHSIsTrue, Depth))
      return InvertRHS ? !*Implied : *Implied;
    return std::nullopt;
  }

  // A nuw trunc to i1 is equivalent to a "!= 0" test of its operand.
  const Value *V;
  if (match(V: RHS, P: m_NUWTrunc(Op: m_Value(V)))) {
    if (auto Implied = isImpliedCondition(LHS, RHSPred: CmpInst::ICMP_NE, RHSOp0: V,
                                          RHSOp1: ConstantInt::get(Ty: V->getType(), V: 0), DL,
                                          LHSIsTrue, Depth))
      return InvertRHS ? !*Implied : *Implied;
    return std::nullopt;
  }

  if (Depth == MaxAnalysisRecursionDepth)
    return std::nullopt;

  // LHS ==> (RHS1 || RHS2) if LHS ==> RHS1 or LHS ==> RHS2
  // LHS ==> !(RHS1 && RHS2) if LHS ==> !RHS1 or LHS ==> !RHS2
  const Value *RHS1, *RHS2;
  if (match(V: RHS, P: m_LogicalOr(L: m_Value(V&: RHS1), R: m_Value(V&: RHS2)))) {
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS1, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == true)
        return !InvertRHS;
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS2, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == true)
        return !InvertRHS;
  }
  if (match(V: RHS, P: m_LogicalAnd(L: m_Value(V&: RHS1), R: m_Value(V&: RHS2)))) {
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS1, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == false)
        return InvertRHS;
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS2, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == false)
        return InvertRHS;
  }

  return std::nullopt;
}
9885
9886// Returns a pair (Condition, ConditionIsTrue), where Condition is a branch
9887// condition dominating ContextI or nullptr, if no condition is found.
9888static std::pair<Value *, bool>
9889getDomPredecessorCondition(const Instruction *ContextI) {
9890 if (!ContextI || !ContextI->getParent())
9891 return {nullptr, false};
9892
9893 // TODO: This is a poor/cheap way to determine dominance. Should we use a
9894 // dominator tree (eg, from a SimplifyQuery) instead?
9895 const BasicBlock *ContextBB = ContextI->getParent();
9896 const BasicBlock *PredBB = ContextBB->getSinglePredecessor();
9897 if (!PredBB)
9898 return {nullptr, false};
9899
9900 // We need a conditional branch in the predecessor.
9901 Value *PredCond;
9902 BasicBlock *TrueBB, *FalseBB;
9903 if (!match(V: PredBB->getTerminator(), P: m_Br(C: m_Value(V&: PredCond), T&: TrueBB, F&: FalseBB)))
9904 return {nullptr, false};
9905
9906 // The branch should get simplified. Don't bother simplifying this condition.
9907 if (TrueBB == FalseBB)
9908 return {nullptr, false};
9909
9910 assert((TrueBB == ContextBB || FalseBB == ContextBB) &&
9911 "Predecessor block does not point to successor?");
9912
9913 // Is this condition implied by the predecessor condition?
9914 return {PredCond, TrueBB == ContextBB};
9915}
9916
9917std::optional<bool> llvm::isImpliedByDomCondition(const Value *Cond,
9918 const Instruction *ContextI,
9919 const DataLayout &DL) {
9920 assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool");
9921 auto PredCond = getDomPredecessorCondition(ContextI);
9922 if (PredCond.first)
9923 return isImpliedCondition(LHS: PredCond.first, RHS: Cond, DL, LHSIsTrue: PredCond.second);
9924 return std::nullopt;
9925}
9926
9927std::optional<bool> llvm::isImpliedByDomCondition(CmpPredicate Pred,
9928 const Value *LHS,
9929 const Value *RHS,
9930 const Instruction *ContextI,
9931 const DataLayout &DL) {
9932 auto PredCond = getDomPredecessorCondition(ContextI);
9933 if (PredCond.first)
9934 return isImpliedCondition(LHS: PredCond.first, RHSPred: Pred, RHSOp0: LHS, RHSOp1: RHS, DL,
9935 LHSIsTrue: PredCond.second);
9936 return std::nullopt;
9937}
9938
/// Populate [Lower, Upper) with a bound on the values binary operator \p BO
/// can produce, based on a constant operand and (per \p IIQ) its
/// no-wrap/exact flags. Lower and Upper arrive as zero (i.e. the full range)
/// and are only narrowed when something useful is known; the half-open range
/// may wrap and is materialized by the caller via
/// ConstantRange::getNonEmpty. If \p PreferSignedRange is set and both nuw
/// and nsw are present, the signed range is preferred over the unsigned one.
static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower,
                              APInt &Upper, const InstrInfoQuery &IIQ,
                              bool PreferSignedRange) {
  unsigned Width = Lower.getBitWidth();
  const APInt *C;
  switch (BO.getOpcode()) {
  case Instruction::Sub:
    if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      bool HasNSW = IIQ.hasNoSignedWrap(Op: &BO);
      bool HasNUW = IIQ.hasNoUnsignedWrap(Op: &BO);

      // If the caller expects a signed compare, then try to use a signed range.
      // Otherwise if both no-wraps are set, use the unsigned range because it
      // is never larger than the signed range. Example:
      // "sub nuw nsw i8 -2, x" is unsigned [0, 254] vs. signed [-128, 126].
      // "sub nuw nsw i8 2, x" is unsigned [0, 2] vs. signed [-125, 127].
      if (PreferSignedRange && HasNSW && HasNUW)
        HasNUW = false;

      if (HasNUW) {
        // 'sub nuw c, x' produces [0, C].
        Upper = *C + 1;
      } else if (HasNSW) {
        if (C->isNegative()) {
          // 'sub nsw -C, x' produces [SINT_MIN, -C - SINT_MIN].
          Lower = APInt::getSignedMinValue(numBits: Width);
          Upper = *C - APInt::getSignedMaxValue(numBits: Width);
        } else {
          // Note that sub 0, INT_MIN is not NSW. It technically is a signed
          // wrap.
          // 'sub nsw C, x' produces [C - SINT_MAX, SINT_MAX].
          Lower = *C - APInt::getSignedMaxValue(numBits: Width);
          Upper = APInt::getSignedMinValue(numBits: Width);
        }
      }
    }
    break;
  case Instruction::Add:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && !C->isZero()) {
      bool HasNSW = IIQ.hasNoSignedWrap(Op: &BO);
      bool HasNUW = IIQ.hasNoUnsignedWrap(Op: &BO);

      // If the caller expects a signed compare, then try to use a signed
      // range. Otherwise if both no-wraps are set, use the unsigned range
      // because it is never larger than the signed range. Example: "add nuw
      // nsw i8 X, -2" is unsigned [254,255] vs. signed [-128, 125].
      if (PreferSignedRange && HasNSW && HasNUW)
        HasNUW = false;

      if (HasNUW) {
        // 'add nuw x, C' produces [C, UINT_MAX].
        Lower = *C;
      } else if (HasNSW) {
        if (C->isNegative()) {
          // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C].
          Lower = APInt::getSignedMinValue(numBits: Width);
          Upper = APInt::getSignedMaxValue(numBits: Width) + *C + 1;
        } else {
          // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX].
          Lower = APInt::getSignedMinValue(numBits: Width) + *C;
          Upper = APInt::getSignedMaxValue(numBits: Width) + 1;
        }
      }
    }
    break;

  case Instruction::And:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      // 'and x, C' produces [0, C].
      Upper = *C + 1;
    // X & -X is a power of two or zero. So we can cap the value at max power of
    // two.
    if (match(V: BO.getOperand(i_nocapture: 0), P: m_Neg(V: m_Specific(V: BO.getOperand(i_nocapture: 1)))) ||
        match(V: BO.getOperand(i_nocapture: 1), P: m_Neg(V: m_Specific(V: BO.getOperand(i_nocapture: 0)))))
      Upper = APInt::getSignedMinValue(numBits: Width) + 1;
    break;

  case Instruction::Or:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      // 'or x, C' produces [C, UINT_MAX].
      Lower = *C;
    break;

  case Instruction::AShr:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
      // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C].
      Lower = APInt::getSignedMinValue(numBits: Width).ashr(ShiftAmt: *C);
      Upper = APInt::getSignedMaxValue(numBits: Width).ashr(ShiftAmt: *C) + 1;
    } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      // An 'exact' shift cannot drop set bits, so the shift amount is bounded
      // by the number of trailing zeros of C.
      unsigned ShiftAmount = Width - 1;
      if (!C->isZero() && IIQ.isExact(Op: &BO))
        ShiftAmount = C->countr_zero();
      if (C->isNegative()) {
        // 'ashr C, x' produces [C, C >> (Width-1)]
        Lower = *C;
        Upper = C->ashr(ShiftAmt: ShiftAmount) + 1;
      } else {
        // 'ashr C, x' produces [C >> (Width-1), C]
        Lower = C->ashr(ShiftAmt: ShiftAmount);
        Upper = *C + 1;
      }
    }
    break;

  case Instruction::LShr:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
      // 'lshr x, C' produces [0, UINT_MAX >> C].
      Upper = APInt::getAllOnes(numBits: Width).lshr(ShiftAmt: *C) + 1;
    } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      // 'lshr C, x' produces [C >> (Width-1), C].
      // An 'exact' shift cannot drop set bits, so the shift amount is bounded
      // by the number of trailing zeros of C.
      unsigned ShiftAmount = Width - 1;
      if (!C->isZero() && IIQ.isExact(Op: &BO))
        ShiftAmount = C->countr_zero();
      Lower = C->lshr(shiftAmt: ShiftAmount);
      Upper = *C + 1;
    }
    break;

  case Instruction::Shl:
    if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      if (IIQ.hasNoUnsignedWrap(Op: &BO)) {
        // 'shl nuw C, x' produces [C, C << CLZ(C)]
        Lower = *C;
        Upper = Lower.shl(shiftAmt: Lower.countl_zero()) + 1;
      } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw?
        if (C->isNegative()) {
          // 'shl nsw C, x' produces [C << CLO(C)-1, C]
          unsigned ShiftAmount = C->countl_one() - 1;
          Lower = C->shl(shiftAmt: ShiftAmount);
          Upper = *C + 1;
        } else {
          // 'shl nsw C, x' produces [C, C << CLZ(C)-1]
          unsigned ShiftAmount = C->countl_zero() - 1;
          Lower = *C;
          Upper = C->shl(shiftAmt: ShiftAmount) + 1;
        }
      } else {
        // If lowbit is set, value can never be zero.
        if ((*C)[0])
          Lower = APInt::getOneBitSet(numBits: Width, BitNo: 0);
        // If we are shifting a constant the largest it can be is if the longest
        // sequence of consecutive ones is shifted to the highbits (breaking
        // ties for which sequence is higher). At the moment we take a liberal
        // upper bound on this by just popcounting the constant.
        // TODO: There may be a bitwise trick for the longest/highest
        // consecutive sequence of ones (naive method is O(Width) loop).
        Upper = APInt::getHighBitsSet(numBits: Width, hiBitsSet: C->popcount()) + 1;
      }
    } else if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
      // 'shl x, C' produces [0, bits [C, Width) all set], the largest value
      // reachable after shifting left by C.
      Upper = APInt::getBitsSetFrom(numBits: Width, loBit: C->getZExtValue()) + 1;
    }
    break;

  case Instruction::SDiv:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
      APInt IntMin = APInt::getSignedMinValue(numBits: Width);
      APInt IntMax = APInt::getSignedMaxValue(numBits: Width);
      if (C->isAllOnes()) {
        // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX]
        // where C != -1 and C != 0 and C != 1
        Lower = IntMin + 1;
        Upper = IntMax + 1;
      } else if (C->countl_zero() < Width - 1) {
        // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C]
        // where C != -1 and C != 0 and C != 1
        Lower = IntMin.sdiv(RHS: *C);
        Upper = IntMax.sdiv(RHS: *C);
        if (Lower.sgt(RHS: Upper))
          std::swap(a&: Lower, b&: Upper);
        Upper = Upper + 1;
        assert(Upper != Lower && "Upper part of range has wrapped!");
      }
    } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      if (C->isMinSignedValue()) {
        // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
        Lower = *C;
        Upper = Lower.lshr(shiftAmt: 1) + 1;
      } else {
        // 'sdiv C, x' produces [-|C|, |C|].
        Upper = C->abs() + 1;
        Lower = (-Upper) + 1;
      }
    }
    break;

  case Instruction::UDiv:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && !C->isZero()) {
      // 'udiv x, C' produces [0, UINT_MAX / C].
      Upper = APInt::getMaxValue(numBits: Width).udiv(RHS: *C) + 1;
    } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      // 'udiv C, x' produces [0, C].
      Upper = *C + 1;
    }
    break;

  case Instruction::SRem:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
      // 'srem x, C' produces (-|C|, |C|).
      Upper = C->abs();
      Lower = (-Upper) + 1;
    } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      if (C->isNegative()) {
        // 'srem -|C|, x' produces [-|C|, 0].
        Upper = 1;
        Lower = *C;
      } else {
        // 'srem |C|, x' produces [0, |C|].
        Upper = *C + 1;
      }
    }
    break;

  case Instruction::URem:
    if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      // 'urem x, C' produces [0, C).
      Upper = *C;
    else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)))
      // 'urem C, x' produces [0, C].
      Upper = *C + 1;
    break;

  default:
    break;
  }
}
10163
/// Return a conservative range of values intrinsic call \p II can produce,
/// derived from the intrinsic's semantics and any constant arguments.
/// Returns the full range when nothing useful is known. \p UseInstrInfo
/// gates the use of poison-implying arguments (e.g. the is_zero_poison flag
/// of ctlz/cttz).
static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II,
                                          bool UseInstrInfo) {
  unsigned Width = II.getType()->getScalarSizeInBits();
  const APInt *C;
  switch (II.getIntrinsicID()) {
  case Intrinsic::ctlz:
  case Intrinsic::cttz: {
    APInt Upper(Width, Width);
    // If zero input is poison (and we may rely on it), the input is known
    // nonzero and the result is at most Width - 1; otherwise Width itself is
    // reachable.
    if (!UseInstrInfo || !match(V: II.getArgOperand(i: 1), P: m_One()))
      Upper += 1;
    // Maximum of set/clear bits is the bit width.
    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper);
  }
  case Intrinsic::ctpop:
    // Maximum of set/clear bits is the bit width.
    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                      Upper: APInt(Width, Width) + 1);
  case Intrinsic::uadd_sat:
    // uadd.sat(x, C) produces [C, UINT_MAX].
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) ||
        match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: Width));
    break;
  case Intrinsic::sadd_sat:
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) ||
        match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
      if (C->isNegative())
        // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)].
        return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                          Upper: APInt::getSignedMaxValue(numBits: Width) + *C +
                                              1);

      // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX].
      return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width) + *C,
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
    }
    break;
  case Intrinsic::usub_sat:
    // usub.sat(C, x) produces [0, C].
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper: *C + 1);

    // usub.sat(x, C) produces [0, UINT_MAX - C].
    if (match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                        Upper: APInt::getMaxValue(numBits: Width) - *C + 1);
    break;
  case Intrinsic::ssub_sat:
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      if (C->isNegative())
        // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)].
        return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                          Upper: *C - APInt::getSignedMinValue(numBits: Width) +
                                              1);

      // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX].
      return ConstantRange::getNonEmpty(Lower: *C - APInt::getSignedMaxValue(numBits: Width),
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
    } else if (match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
      if (C->isNegative())
        // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]:
        return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width) - *C,
                                          Upper: APInt::getSignedMaxValue(numBits: Width) + 1);

      // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C].
      return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                        Upper: APInt::getSignedMaxValue(numBits: Width) - *C +
                                            1);
    }
    break;
  case Intrinsic::umin:
  case Intrinsic::umax:
  case Intrinsic::smin:
  case Intrinsic::smax:
    // One constant operand bounds the result on one side; the other side is
    // the type's extreme for the respective signedness.
    if (!match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) &&
        !match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      break;

    switch (II.getIntrinsicID()) {
    case Intrinsic::umin:
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper: *C + 1);
    case Intrinsic::umax:
      return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: Width));
    case Intrinsic::smin:
      return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                        Upper: *C + 1);
    case Intrinsic::smax:
      return ConstantRange::getNonEmpty(Lower: *C,
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
    default:
      llvm_unreachable("Must be min/max intrinsic");
    }
    break;
  case Intrinsic::abs:
    // If abs of SIGNED_MIN is poison, then the result is [0..SIGNED_MAX],
    // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
    if (match(V: II.getOperand(i_nocapture: 1), P: m_One()))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);

    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                      Upper: APInt::getSignedMinValue(numBits: Width) + 1);
  case Intrinsic::vscale:
    // Bail on calls not (yet) inserted into a function; the range comes from
    // the function's vscale_range attribute.
    if (!II.getParent() || !II.getFunction())
      break;
    return getVScaleRange(F: II.getFunction(), BitWidth: Width);
  default:
    break;
  }

  return ConstantRange::getFull(BitWidth: Width);
}
10276
/// Return a conservative range for select \p SI when it matches a known
/// min/max/abs/nabs pattern (see matchSelectPattern). For min/max flavors a
/// constant on either arm bounds one side of the result. Returns the full
/// range when no pattern or no constant is found.
static ConstantRange getRangeForSelectPattern(const SelectInst &SI,
                                              const InstrInfoQuery &IIQ) {
  unsigned BitWidth = SI.getType()->getScalarSizeInBits();
  const Value *LHS = nullptr, *RHS = nullptr;
  SelectPatternResult R = matchSelectPattern(V: &SI, LHS, RHS);
  if (R.Flavor == SPF_UNKNOWN)
    return ConstantRange::getFull(BitWidth);

  if (R.Flavor == SelectPatternFlavor::SPF_ABS) {
    // If the negation part of the abs (in RHS) has the NSW flag,
    // then the result of abs(X) is [0..SIGNED_MAX],
    // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
    if (match(V: RHS, P: m_Neg(V: m_Specific(V: LHS))) &&
        IIQ.hasNoSignedWrap(Op: cast<Instruction>(Val: RHS)))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth),
                                        Upper: APInt::getSignedMaxValue(numBits: BitWidth) + 1);

    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth),
                                      Upper: APInt::getSignedMinValue(numBits: BitWidth) + 1);
  }

  if (R.Flavor == SelectPatternFlavor::SPF_NABS) {
    // The result of -abs(X) is <= 0.
    return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: BitWidth),
                                      Upper: APInt(BitWidth, 1));
  }

  // For the min/max flavors we need a constant on one arm to bound the
  // result.
  const APInt *C;
  if (!match(V: LHS, P: m_APInt(Res&: C)) && !match(V: RHS, P: m_APInt(Res&: C)))
    return ConstantRange::getFull(BitWidth);

  switch (R.Flavor) {
  case SPF_UMIN:
    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth), Upper: *C + 1);
  case SPF_UMAX:
    return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: BitWidth));
  case SPF_SMIN:
    return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: BitWidth),
                                      Upper: *C + 1);
  case SPF_SMAX:
    return ConstantRange::getNonEmpty(Lower: *C,
                                      Upper: APInt::getSignedMaxValue(numBits: BitWidth) + 1);
  default:
    return ConstantRange::getFull(BitWidth);
  }
}
10323
10324static void setLimitForFPToI(const Instruction *I, APInt &Lower, APInt &Upper) {
10325 // The maximum representable value of a half is 65504. For floats the maximum
10326 // value is 3.4e38 which requires roughly 129 bits.
10327 unsigned BitWidth = I->getType()->getScalarSizeInBits();
10328 if (!I->getOperand(i: 0)->getType()->getScalarType()->isHalfTy())
10329 return;
10330 if (isa<FPToSIInst>(Val: I) && BitWidth >= 17) {
10331 Lower = APInt(BitWidth, -65504, true);
10332 Upper = APInt(BitWidth, 65505);
10333 }
10334
10335 if (isa<FPToUIInst>(Val: I) && BitWidth >= 16) {
10336 // For a fptoui the lower limit is left as 0.
10337 Upper = APInt(BitWidth, 65505);
10338 }
10339}
10340
ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
                                         bool UseInstrInfo, AssumptionCache *AC,
                                         const Instruction *CtxI,
                                         const DominatorTree *DT,
                                         unsigned Depth) {
  assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");

  // Give up once the recursion limit is reached (the select and assume paths
  // below recurse).
  if (Depth == MaxAnalysisRecursionDepth)
    return ConstantRange::getFull(BitWidth: V->getType()->getScalarSizeInBits());

  // Constants (including vectors) carry an exact range.
  if (auto *C = dyn_cast<Constant>(Val: V))
    return C->toConstantRange();

  unsigned BitWidth = V->getType()->getScalarSizeInBits();
  InstrInfoQuery IIQ(UseInstrInfo);
  ConstantRange CR = ConstantRange::getFull(BitWidth);
  if (auto *BO = dyn_cast<BinaryOperator>(Val: V)) {
    APInt Lower = APInt(BitWidth, 0);
    APInt Upper = APInt(BitWidth, 0);
    // TODO: Return ConstantRange.
    setLimitsForBinOp(BO: *BO, Lower, Upper, IIQ, PreferSignedRange: ForSigned);
    CR = ConstantRange::getNonEmpty(Lower, Upper);
  } else if (auto *II = dyn_cast<IntrinsicInst>(Val: V))
    CR = getRangeForIntrinsic(II: *II, UseInstrInfo);
  else if (auto *SI = dyn_cast<SelectInst>(Val: V)) {
    // A select's range is the union of its arms, further refined when it
    // matches a min/max/abs pattern.
    ConstantRange CRTrue = computeConstantRange(
        V: SI->getTrueValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth: Depth + 1);
    ConstantRange CRFalse = computeConstantRange(
        V: SI->getFalseValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth: Depth + 1);
    CR = CRTrue.unionWith(CR: CRFalse);
    CR = CR.intersectWith(CR: getRangeForSelectPattern(SI: *SI, IIQ));
  } else if (isa<FPToUIInst>(Val: V) || isa<FPToSIInst>(Val: V)) {
    APInt Lower = APInt(BitWidth, 0);
    APInt Upper = APInt(BitWidth, 0);
    // TODO: Return ConstantRange.
    setLimitForFPToI(I: cast<Instruction>(Val: V), Lower, Upper);
    CR = ConstantRange::getNonEmpty(Lower, Upper);
  } else if (const auto *A = dyn_cast<Argument>(Val: V))
    // Arguments may carry a range attribute.
    if (std::optional<ConstantRange> Range = A->getRange())
      CR = *Range;

  // Intersect with any !range metadata or call-site range attribute.
  if (auto *I = dyn_cast<Instruction>(Val: V)) {
    if (auto *Range = IIQ.getMetadata(I, KindID: LLVMContext::MD_range))
      CR = CR.intersectWith(CR: getConstantRangeFromMetadata(RangeMD: *Range));

    if (const auto *CB = dyn_cast<CallBase>(Val: V))
      if (std::optional<ConstantRange> Range = CB->getRange())
        CR = CR.intersectWith(CR: *Range);
  }

  if (CtxI && AC) {
    // Try to restrict the range based on information from assumptions.
    for (auto &AssumeVH : AC->assumptionsFor(V)) {
      if (!AssumeVH)
        continue;
      CallInst *I = cast<CallInst>(Val&: AssumeVH);
      assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
             "Got assumption for the wrong function!");
      assert(I->getIntrinsicID() == Intrinsic::assume &&
             "must be an assume intrinsic");

      // Only assumptions that hold at the context instruction apply.
      if (!isValidAssumeForContext(Inv: I, CxtI: CtxI, DT))
        continue;
      Value *Arg = I->getArgOperand(i: 0);
      ICmpInst *Cmp = dyn_cast<ICmpInst>(Val: Arg);
      // Currently we just use information from comparisons.
      if (!Cmp || Cmp->getOperand(i_nocapture: 0) != V)
        continue;
      // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
      ConstantRange RHS =
          computeConstantRange(V: Cmp->getOperand(i_nocapture: 1), /* ForSigned */ false,
                               UseInstrInfo, AC, CtxI: I, DT, Depth: Depth + 1);
      CR = CR.intersectWith(
          CR: ConstantRange::makeAllowedICmpRegion(Pred: Cmp->getPredicate(), Other: RHS));
    }
  }

  return CR;
}
10420
10421static void
10422addValueAffectedByCondition(Value *V,
10423 function_ref<void(Value *)> InsertAffected) {
10424 assert(V != nullptr);
10425 if (isa<Argument>(Val: V) || isa<GlobalValue>(Val: V)) {
10426 InsertAffected(V);
10427 } else if (auto *I = dyn_cast<Instruction>(Val: V)) {
10428 InsertAffected(V);
10429
10430 // Peek through unary operators to find the source of the condition.
10431 Value *Op;
10432 if (match(V: I, P: m_CombineOr(L: m_PtrToIntOrAddr(Op: m_Value(V&: Op)),
10433 R: m_Trunc(Op: m_Value(V&: Op))))) {
10434 if (isa<Instruction>(Val: Op) || isa<Argument>(Val: Op))
10435 InsertAffected(Op);
10436 }
10437 }
10438}
10439
/// Enumerate values whose analysis results may be refined by knowing the
/// truth of \p Cond, forwarding each candidate to \p InsertAffected.
/// \p IsAssume distinguishes assume(Cond) from a branch condition: for
/// assumes both comparison operands are reported and logical and/or and
/// not are not recursed through (see comments below), while branch
/// conditions are split recursively but only report the LHS of compares
/// against constants.
void llvm::findValuesAffectedByCondition(
    Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
  auto AddAffected = [&InsertAffected](Value *V) {
    addValueAffectedByCondition(V, InsertAffected);
  };

  // For assumes, both sides of a comparison are affected; for branches only
  // a LHS compared against a constant is worth reporting.
  auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
    if (IsAssume) {
      AddAffected(LHS);
      AddAffected(RHS);
    } else if (match(V: RHS, P: m_Constant()))
      AddAffected(LHS);
  };

  SmallVector<Value *, 8> Worklist;
  SmallPtrSet<Value *, 8> Visited;
  Worklist.push_back(Elt: Cond);
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(Ptr: V).second)
      continue;

    CmpPredicate Pred;
    Value *A, *B, *X;

    if (IsAssume) {
      AddAffected(V);
      if (match(V, P: m_Not(V: m_Value(V&: X))))
        AddAffected(X);
    }

    if (match(V, P: m_LogicalOp(L: m_Value(V&: A), R: m_Value(V&: B)))) {
      // assume(A && B) is split to -> assume(A); assume(B);
      // assume(!(A || B)) is split to -> assume(!A); assume(!B);
      // Finally, assume(A || B) / assume(!(A && B)) generally don't provide
      // enough information to be worth handling (intersection of information as
      // opposed to union).
      if (!IsAssume) {
        Worklist.push_back(Elt: A);
        Worklist.push_back(Elt: B);
      }
    } else if (match(V, P: m_ICmp(Pred, L: m_Value(V&: A), R: m_Value(V&: B)))) {
      bool HasRHSC = match(V: B, P: m_ConstantInt());
      if (ICmpInst::isEquality(P: Pred)) {
        AddAffected(A);
        if (IsAssume)
          AddAffected(B);
        if (HasRHSC) {
          Value *Y;
          // (X << C) or (X >>_s C) or (X >>_u C).
          if (match(V: A, P: m_Shift(L: m_Value(V&: X), R: m_ConstantInt())))
            AddAffected(X);
          // (X & C) or (X | C).
          else if (match(V: A, P: m_And(L: m_Value(V&: X), R: m_Value(V&: Y))) ||
                   match(V: A, P: m_Or(L: m_Value(V&: X), R: m_Value(V&: Y)))) {
            AddAffected(X);
            AddAffected(Y);
          }
          // X - Y
          else if (match(V: A, P: m_Sub(L: m_Value(V&: X), R: m_Value(V&: Y)))) {
            AddAffected(X);
            AddAffected(Y);
          }
        }
      } else {
        AddCmpOperands(A, B);
        if (HasRHSC) {
          // Handle (A + C1) u< C2, which is the canonical form of
          // A > C3 && A < C4.
          if (match(V: A, P: m_AddLike(L: m_Value(V&: X), R: m_ConstantInt())))
            AddAffected(X);

          if (ICmpInst::isUnsigned(Pred)) {
            Value *Y;
            // X & Y u> C -> X >u C && Y >u C
            // X | Y u< C -> X u< C && Y u< C
            // X nuw+ Y u< C -> X u< C && Y u< C
            if (match(V: A, P: m_And(L: m_Value(V&: X), R: m_Value(V&: Y))) ||
                match(V: A, P: m_Or(L: m_Value(V&: X), R: m_Value(V&: Y))) ||
                match(V: A, P: m_NUWAdd(L: m_Value(V&: X), R: m_Value(V&: Y)))) {
              AddAffected(X);
              AddAffected(Y);
            }
            // X nuw- Y u> C -> X u> C
            if (match(V: A, P: m_NUWSub(L: m_Value(V&: X), R: m_Value())))
              AddAffected(X);
          }
        }

        // Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
        // by computeKnownFPClass().
        if (match(V: A, P: m_ElementWiseBitCast(Op: m_Value(V&: X)))) {
          if (Pred == ICmpInst::ICMP_SLT && match(V: B, P: m_Zero()))
            InsertAffected(X);
          else if (Pred == ICmpInst::ICMP_SGT && match(V: B, P: m_AllOnes()))
            InsertAffected(X);
        }
      }

      // ctpop(X) compared with a constant constrains X's known bits.
      if (HasRHSC && match(V: A, P: m_Intrinsic<Intrinsic::ctpop>(Op0: m_Value(V&: X))))
        AddAffected(X);
    } else if (match(V, P: m_FCmp(Pred, L: m_Value(V&: A), R: m_Value(V&: B)))) {
      AddCmpOperands(A, B);

      // fcmp fneg(x), y
      // fcmp fabs(x), y
      // fcmp fneg(fabs(x)), y
      if (match(V: A, P: m_FNeg(X: m_Value(V&: A))))
        AddAffected(A);
      if (match(V: A, P: m_FAbs(Op0: m_Value(V&: A))))
        AddAffected(A);

    } else if (match(V, P: m_Intrinsic<Intrinsic::is_fpclass>(Op0: m_Value(V&: A),
                                                           Op1: m_Value()))) {
      // Handle patterns that computeKnownFPClass() support.
      AddAffected(A);
    } else if (!IsAssume && match(V, P: m_Trunc(Op: m_Value(V&: X)))) {
      // Assume is checked here as X is already added above for assumes in
      // addValueAffectedByCondition
      AddAffected(X);
    } else if (!IsAssume && match(V, P: m_Not(V: m_Value(V&: X)))) {
      // Assume is checked here to avoid issues with ephemeral values
      Worklist.push_back(Elt: X);
    }
  }
}
10566
10567const Value *llvm::stripNullTest(const Value *V) {
10568 // (X >> C) or/add (X & mask(C) != 0)
10569 if (const auto *BO = dyn_cast<BinaryOperator>(Val: V)) {
10570 if (BO->getOpcode() == Instruction::Add ||
10571 BO->getOpcode() == Instruction::Or) {
10572 const Value *X;
10573 const APInt *C1, *C2;
10574 if (match(V: BO, P: m_c_BinOp(L: m_LShr(L: m_Value(V&: X), R: m_APInt(Res&: C1)),
10575 R: m_ZExt(Op: m_SpecificICmp(
10576 MatchPred: ICmpInst::ICMP_NE,
10577 L: m_And(L: m_Deferred(V: X), R: m_LowBitMask(V&: C2)),
10578 R: m_Zero())))) &&
10579 C2->popcount() == C1->getZExtValue())
10580 return X;
10581 }
10582 }
10583 return nullptr;
10584}
10585
10586Value *llvm::stripNullTest(Value *V) {
10587 return const_cast<Value *>(stripNullTest(V: const_cast<const Value *>(V)));
10588}
10589
/// Enumerate the immediate-constant values \p V can evaluate to, looking
/// through select and PHI nodes, inserting them into \p Constants. Returns
/// false if enumeration fails: more than \p MaxCount distinct constants
/// would be needed, a non-enumerable value is reached, or (when
/// \p AllowUndefOrPoison is false) a constant may be undef/poison. On
/// failure the contents of \p Constants are unspecified.
bool llvm::collectPossibleValues(const Value *V,
                                 SmallPtrSetImpl<const Constant *> &Constants,
                                 unsigned MaxCount, bool AllowUndefOrPoison) {
  SmallPtrSet<const Instruction *, 8> Visited;
  SmallVector<const Instruction *, 8> Worklist;
  // Record V: a constant goes straight into the result set (subject to the
  // cap), an instruction is queued for expansion. Anything else fails.
  auto Push = [&](const Value *V) -> bool {
    Constant *C;
    if (match(V: const_cast<Value *>(V), P: m_ImmConstant(C))) {
      if (!AllowUndefOrPoison && !isGuaranteedNotToBeUndefOrPoison(V: C))
        return false;
      // Check existence first to avoid unnecessary allocations.
      if (Constants.contains(Ptr: C))
        return true;
      if (Constants.size() == MaxCount)
        return false;
      Constants.insert(Ptr: C);
      return true;
    }

    if (auto *Inst = dyn_cast<Instruction>(Val: V)) {
      if (Visited.insert(Ptr: Inst).second)
        Worklist.push_back(Elt: Inst);
      return true;
    }
    return false;
  };
  if (!Push(V))
    return false;
  // Only selects and PHIs can be expanded into their possible inputs; any
  // other instruction makes the set unenumerable.
  while (!Worklist.empty()) {
    const Instruction *CurInst = Worklist.pop_back_val();
    switch (CurInst->getOpcode()) {
    case Instruction::Select:
      if (!Push(CurInst->getOperand(i: 1)))
        return false;
      if (!Push(CurInst->getOperand(i: 2)))
        return false;
      break;
    case Instruction::PHI:
      for (Value *IncomingValue : cast<PHINode>(Val: CurInst)->incoming_values()) {
        // Fast path for recurrence PHI: a self-reference contributes no new
        // values.
        if (IncomingValue == CurInst)
          continue;
        if (!Push(IncomingValue))
          return false;
      }
      break;
    default:
      return false;
    }
  }
  return true;
}
10642