1//===- ValueTracking.cpp - Walk computations to compute properties --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains routines that help analyze properties that chains of
10// computations have.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/ValueTracking.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/APInt.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/FloatingPointMode.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/ADT/iterator_range.h"
25#include "llvm/Analysis/AliasAnalysis.h"
26#include "llvm/Analysis/AssumeBundleQueries.h"
27#include "llvm/Analysis/AssumptionCache.h"
28#include "llvm/Analysis/ConstantFolding.h"
29#include "llvm/Analysis/DomConditionCache.h"
30#include "llvm/Analysis/FloatingPointPredicateUtils.h"
31#include "llvm/Analysis/GuardUtils.h"
32#include "llvm/Analysis/InstructionSimplify.h"
33#include "llvm/Analysis/Loads.h"
34#include "llvm/Analysis/LoopInfo.h"
35#include "llvm/Analysis/TargetLibraryInfo.h"
36#include "llvm/Analysis/VectorUtils.h"
37#include "llvm/Analysis/WithCache.h"
38#include "llvm/IR/Argument.h"
39#include "llvm/IR/Attributes.h"
40#include "llvm/IR/BasicBlock.h"
41#include "llvm/IR/Constant.h"
42#include "llvm/IR/ConstantFPRange.h"
43#include "llvm/IR/ConstantRange.h"
44#include "llvm/IR/Constants.h"
45#include "llvm/IR/DerivedTypes.h"
46#include "llvm/IR/DiagnosticInfo.h"
47#include "llvm/IR/Dominators.h"
48#include "llvm/IR/EHPersonalities.h"
49#include "llvm/IR/Function.h"
50#include "llvm/IR/GetElementPtrTypeIterator.h"
51#include "llvm/IR/GlobalAlias.h"
52#include "llvm/IR/GlobalValue.h"
53#include "llvm/IR/GlobalVariable.h"
54#include "llvm/IR/InstrTypes.h"
55#include "llvm/IR/Instruction.h"
56#include "llvm/IR/Instructions.h"
57#include "llvm/IR/IntrinsicInst.h"
58#include "llvm/IR/Intrinsics.h"
59#include "llvm/IR/IntrinsicsAArch64.h"
60#include "llvm/IR/IntrinsicsAMDGPU.h"
61#include "llvm/IR/IntrinsicsRISCV.h"
62#include "llvm/IR/IntrinsicsX86.h"
63#include "llvm/IR/LLVMContext.h"
64#include "llvm/IR/Metadata.h"
65#include "llvm/IR/Module.h"
66#include "llvm/IR/Operator.h"
67#include "llvm/IR/PatternMatch.h"
68#include "llvm/IR/Type.h"
69#include "llvm/IR/User.h"
70#include "llvm/IR/Value.h"
71#include "llvm/Support/Casting.h"
72#include "llvm/Support/CommandLine.h"
73#include "llvm/Support/Compiler.h"
74#include "llvm/Support/ErrorHandling.h"
75#include "llvm/Support/KnownBits.h"
76#include "llvm/Support/KnownFPClass.h"
77#include "llvm/Support/MathExtras.h"
78#include "llvm/TargetParser/RISCVTargetParser.h"
79#include <algorithm>
80#include <cassert>
81#include <cstdint>
82#include <optional>
83#include <utility>
84
85using namespace llvm;
86using namespace llvm::PatternMatch;
87
88// Controls the number of uses of the value searched for possible
89// dominating comparisons.
90static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
91 cl::Hidden, cl::init(Val: 20));
92
93/// Maximum number of instructions to check between assume and context
94/// instruction.
95static constexpr unsigned MaxInstrsToCheckForFree = 16;
96
97/// Returns the bitwidth of the given scalar or pointer type. For vector types,
98/// returns the element type's bitwidth.
99static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
100 if (unsigned BitWidth = Ty->getScalarSizeInBits())
101 return BitWidth;
102
103 return DL.getPointerTypeSizeInBits(Ty);
104}
105
106// Given the provided Value and, potentially, a context instruction, return
107// the preferred context instruction (if any).
108static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
109 // If we've been provided with a context instruction, then use that (provided
110 // it has been inserted).
111 if (CxtI && CxtI->getParent())
112 return CxtI;
113
114 // If the value is really an already-inserted instruction, then use that.
115 CxtI = dyn_cast<Instruction>(Val: V);
116 if (CxtI && CxtI->getParent())
117 return CxtI;
118
119 return nullptr;
120}
121
122static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
123 const APInt &DemandedElts,
124 APInt &DemandedLHS, APInt &DemandedRHS) {
125 if (isa<ScalableVectorType>(Val: Shuf->getType())) {
126 assert(DemandedElts == APInt(1,1));
127 DemandedLHS = DemandedRHS = DemandedElts;
128 return true;
129 }
130
131 int NumElts =
132 cast<FixedVectorType>(Val: Shuf->getOperand(i_nocapture: 0)->getType())->getNumElements();
133 return llvm::getShuffleDemandedElts(SrcWidth: NumElts, Mask: Shuf->getShuffleMask(),
134 DemandedElts, DemandedLHS, DemandedRHS);
135}
136
137static void computeKnownBits(const Value *V, const APInt &DemandedElts,
138 KnownBits &Known, const SimplifyQuery &Q,
139 unsigned Depth);
140
141void llvm::computeKnownBits(const Value *V, KnownBits &Known,
142 const SimplifyQuery &Q, unsigned Depth) {
143 // Since the number of lanes in a scalable vector is unknown at compile time,
144 // we track one bit which is implicitly broadcast to all lanes. This means
145 // that all lanes in a scalable vector are considered demanded.
146 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
147 APInt DemandedElts =
148 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
149 ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
150}
151
152void llvm::computeKnownBits(const Value *V, KnownBits &Known,
153 const DataLayout &DL, AssumptionCache *AC,
154 const Instruction *CxtI, const DominatorTree *DT,
155 bool UseInstrInfo, unsigned Depth) {
156 computeKnownBits(V, Known,
157 Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
158 Depth);
159}
160
161KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
162 AssumptionCache *AC, const Instruction *CxtI,
163 const DominatorTree *DT, bool UseInstrInfo,
164 unsigned Depth) {
165 return computeKnownBits(
166 V, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
167}
168
169KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
170 const DataLayout &DL, AssumptionCache *AC,
171 const Instruction *CxtI,
172 const DominatorTree *DT, bool UseInstrInfo,
173 unsigned Depth) {
174 return computeKnownBits(
175 V, DemandedElts,
176 Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
177}
178
179static bool haveNoCommonBitsSetSpecialCases(const Value *LHS, const Value *RHS,
180 const SimplifyQuery &SQ) {
181 // Look for an inverted mask: (X & ~M) op (Y & M).
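  // For example, (X & ~15) and (Y & 15) can never have a set bit in common,
  // since the two masks are disjoint.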
182 {
183 Value *M;
184 if (match(V: LHS, P: m_c_And(L: m_Not(V: m_Value(V&: M)), R: m_Value())) &&
185 match(V: RHS, P: m_c_And(L: m_Specific(V: M), R: m_Value())) &&
186 isGuaranteedNotToBeUndef(V: M, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
187 return true;
188 }
189
190 // X op (Y & ~X)
191 if (match(V: RHS, P: m_c_And(L: m_Not(V: m_Specific(V: LHS)), R: m_Value())) &&
192 isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
193 return true;
194
195 // X op ((X & Y) ^ Y) -- this is the canonical form of the previous pattern
196 // for constant Y.
197 Value *Y;
198 if (match(V: RHS,
199 P: m_c_Xor(L: m_c_And(L: m_Specific(V: LHS), R: m_Value(V&: Y)), R: m_Deferred(V: Y))) &&
200 isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT) &&
201 isGuaranteedNotToBeUndef(V: Y, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
202 return true;
203
204 // Peek through extends to find a 'not' of the other side:
205 // (ext Y) op ext(~Y)
206 if (match(V: LHS, P: m_ZExtOrSExt(Op: m_Value(V&: Y))) &&
207 match(V: RHS, P: m_ZExtOrSExt(Op: m_Not(V: m_Specific(V: Y)))) &&
208 isGuaranteedNotToBeUndef(V: Y, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
209 return true;
210
211 // Look for: (A & B) op ~(A | B)
212 {
213 Value *A, *B;
214 if (match(V: LHS, P: m_And(L: m_Value(V&: A), R: m_Value(V&: B))) &&
215 match(V: RHS, P: m_Not(V: m_c_Or(L: m_Specific(V: A), R: m_Specific(V: B)))) &&
216 isGuaranteedNotToBeUndef(V: A, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT) &&
217 isGuaranteedNotToBeUndef(V: B, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
218 return true;
219 }
220
221 // Look for: (X << V) op (Y >> (BitWidth - V))
222 // or (X >> V) op (Y << (BitWidth - V))
223 {
224 const Value *V;
225 const APInt *R;
226 if (((match(V: RHS, P: m_Shl(L: m_Value(), R: m_Sub(L: m_APInt(Res&: R), R: m_Value(V)))) &&
227 match(V: LHS, P: m_LShr(L: m_Value(), R: m_Specific(V)))) ||
228 (match(V: RHS, P: m_LShr(L: m_Value(), R: m_Sub(L: m_APInt(Res&: R), R: m_Value(V)))) &&
229 match(V: LHS, P: m_Shl(L: m_Value(), R: m_Specific(V))))) &&
230 R->uge(RHS: LHS->getType()->getScalarSizeInBits()))
231 return true;
232 }
233
234 return false;
235}
236
237bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
238 const WithCache<const Value *> &RHSCache,
239 const SimplifyQuery &SQ) {
240 const Value *LHS = LHSCache.getValue();
241 const Value *RHS = RHSCache.getValue();
242
243 assert(LHS->getType() == RHS->getType() &&
244 "LHS and RHS should have the same type");
245 assert(LHS->getType()->isIntOrIntVectorTy() &&
246 "LHS and RHS should be integers");
247
248 if (haveNoCommonBitsSetSpecialCases(LHS, RHS, SQ) ||
249 haveNoCommonBitsSetSpecialCases(LHS: RHS, RHS: LHS, SQ))
250 return true;
251
252 return KnownBits::haveNoCommonBitsSet(LHS: LHSCache.getKnownBits(Q: SQ),
253 RHS: RHSCache.getKnownBits(Q: SQ));
254}
255
256bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
257 return !I->user_empty() &&
258 all_of(Range: I->users(), P: match_fn(P: m_ICmp(L: m_Value(), R: m_Zero())));
259}
260
261bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
262 return !I->user_empty() && all_of(Range: I->users(), P: [](const User *U) {
263 CmpPredicate P;
264 return match(V: U, P: m_ICmp(Pred&: P, L: m_Value(), R: m_Zero())) && ICmpInst::isEquality(P);
265 });
266}
267
268bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
269 bool OrZero, AssumptionCache *AC,
270 const Instruction *CxtI,
271 const DominatorTree *DT, bool UseInstrInfo,
272 unsigned Depth) {
273 return ::isKnownToBeAPowerOfTwo(
274 V, OrZero, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
275 Depth);
276}
277
278static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
279 const SimplifyQuery &Q, unsigned Depth);
280
281bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
282 unsigned Depth) {
283 return computeKnownBits(V, Q: SQ, Depth).isNonNegative();
284}
285
286bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
287 unsigned Depth) {
288 if (auto *CI = dyn_cast<ConstantInt>(Val: V))
289 return CI->getValue().isStrictlyPositive();
290
291 // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
292 // this updated.
293 KnownBits Known = computeKnownBits(V, Q: SQ, Depth);
294 return Known.isNonNegative() &&
295 (Known.isNonZero() || isKnownNonZero(V, Q: SQ, Depth));
296}
297
298bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
299 unsigned Depth) {
300 return computeKnownBits(V, Q: SQ, Depth).isNegative();
301}
302
303static bool isKnownNonEqual(const Value *V1, const Value *V2,
304 const APInt &DemandedElts, const SimplifyQuery &Q,
305 unsigned Depth);
306
307bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
308 const SimplifyQuery &Q, unsigned Depth) {
309 // We don't support looking through casts.
310 if (V1 == V2 || V1->getType() != V2->getType())
311 return false;
312 auto *FVTy = dyn_cast<FixedVectorType>(Val: V1->getType());
313 APInt DemandedElts =
314 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
315 return ::isKnownNonEqual(V1, V2, DemandedElts, Q, Depth);
316}
317
318bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
319 const SimplifyQuery &SQ, unsigned Depth) {
320 KnownBits Known(Mask.getBitWidth());
321 computeKnownBits(V, Known, Q: SQ, Depth);
322 return Mask.isSubsetOf(RHS: Known.Zero);
323}
324
325static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
326 const SimplifyQuery &Q, unsigned Depth);
327
328static unsigned ComputeNumSignBits(const Value *V, const SimplifyQuery &Q,
329 unsigned Depth = 0) {
330 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
331 APInt DemandedElts =
332 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
333 return ComputeNumSignBits(V, DemandedElts, Q, Depth);
334}
335
336unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
337 AssumptionCache *AC, const Instruction *CxtI,
338 const DominatorTree *DT, bool UseInstrInfo,
339 unsigned Depth) {
340 return ::ComputeNumSignBits(
341 V, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
342}
343
344unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
345 AssumptionCache *AC,
346 const Instruction *CxtI,
347 const DominatorTree *DT,
348 unsigned Depth) {
  unsigned SignBits = ComputeNumSignBits(V, DL, AC, CxtI, DT,
                                         /*UseInstrInfo=*/true, Depth);
350 return V->getType()->getScalarSizeInBits() - SignBits + 1;
351}
352
353/// Try to detect the lerp pattern: a * (b - c) + c * d
354/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355///
356/// In that particular case, we can use the following chain of reasoning:
357///
358/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359///
360/// Since that is true for arbitrary a, b, c and d within our constraints, we
361/// can conclude that:
362///
363/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364///
/// Considering that any result of the lerp would be less than or equal to U,
/// it has at least as many leading zeros as U.
///
/// While this is quite a specific situation, it is fairly common in computer
/// graphics in the form of alpha blending.
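///
/// For example, in a typical 8-bit alpha blend computed in 32-bit arithmetic,
/// a, b, c and d are all at most 255, so U <= 255 * 255 = 65025 < 2^16 and the
/// result is known to have at least 16 leading zero bits.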
370///
371/// Modifies given KnownOut in-place with the inferred information.
372static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
373 const APInt &DemandedElts,
374 KnownBits &KnownOut,
375 const SimplifyQuery &Q,
376 unsigned Depth) {
377
378 Type *Ty = Op0->getType();
379 const unsigned BitWidth = Ty->getScalarSizeInBits();
380
381 // Only handle scalar types for now
382 if (Ty->isVectorTy())
383 return;
384
385 // Try to match: a * (b - c) + c * d.
386 // When a == 1 => A == nullptr, the same applies to d/D as well.
387 const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
388 const Instruction *SubBC = nullptr;
389
390 const auto MatchSubBC = [&]() {
391 // (b - c) can have two forms that interest us:
392 //
393 // 1. sub nuw %b, %c
394 // 2. xor %c, %b
395 //
    // For the first case, the nuw flag guarantees our requirement b >= c.
    //
    // The second case might happen when the analysis can infer that b is a mask
    // for c and we can transform the sub operation into an xor (that is usually
    // the case for constant b's). Even though xor is symmetrical, canonicalization
401 // ensures that the constant will be the RHS. We have additional checks
402 // later on to ensure that this xor operation is equivalent to subtraction.
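    //
    // As a sketch of the second form (assuming b is the constant 255 and the
    // analysis already knows c fits in the low byte):
    //
    //   %c = and i32 %v, 255
    //   %s = xor i32 %c, 255   ; same as sub nuw i32 255, %c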
403 return m_Instruction(I&: SubBC, Match: m_CombineOr(L: m_NUWSub(L: m_Value(V&: B), R: m_Value(V&: C)),
404 R: m_Xor(L: m_Value(V&: C), R: m_Value(V&: B))));
405 };
406
407 const auto MatchASubBC = [&]() {
408 // Cases:
409 // - a * (b - c)
410 // - (b - c) * a
411 // - (b - c) <- a implicitly equals 1
412 return m_CombineOr(L: m_c_Mul(L: m_Value(V&: A), R: MatchSubBC()), R: MatchSubBC());
413 };
414
415 const auto MatchCD = [&]() {
416 // Cases:
417 // - d * c
418 // - c * d
419 // - c <- d implicitly equals 1
420 return m_CombineOr(L: m_c_Mul(L: m_Value(V&: D), R: m_Specific(V: C)), R: m_Specific(V: C));
421 };
422
423 const auto Match = [&](const Value *LHS, const Value *RHS) {
    // We use m_Specific(C) in MatchCD, so we have to make sure that C is
    // already bound; match(LHS, MatchASubBC()) absolutely has to evaluate
    // first and return true.
427 //
428 // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
429 return match(V: LHS, P: MatchASubBC()) && match(V: RHS, P: MatchCD());
430 };
431
432 if (!Match(Op0, Op1) && !Match(Op1, Op0))
433 return;
434
435 const auto ComputeKnownBitsOrOne = [&](const Value *V) {
436 // For some of the values we use the convention of leaving
437 // it nullptr to signify an implicit constant 1.
438 return V ? computeKnownBits(V, DemandedElts, Q, Depth: Depth + 1)
439 : KnownBits::makeConstant(C: APInt(BitWidth, 1));
440 };
441
442 // Check that all operands are non-negative
443 const KnownBits KnownA = ComputeKnownBitsOrOne(A);
444 if (!KnownA.isNonNegative())
445 return;
446
447 const KnownBits KnownD = ComputeKnownBitsOrOne(D);
448 if (!KnownD.isNonNegative())
449 return;
450
451 const KnownBits KnownB = computeKnownBits(V: B, DemandedElts, Q, Depth: Depth + 1);
452 if (!KnownB.isNonNegative())
453 return;
454
455 const KnownBits KnownC = computeKnownBits(V: C, DemandedElts, Q, Depth: Depth + 1);
456 if (!KnownC.isNonNegative())
457 return;
458
459 // If we matched subtraction as xor, we need to actually check that xor
460 // is semantically equivalent to subtraction.
461 //
  // For that to be true, b has to be a mask for c, i.e. b's known ones have
  // to cover every bit that can possibly be set in c.
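  //
  // For example, if b is the constant 0xff and c is known to be at most 0xff,
  // then c ^ 0xff == 0xff - c, and treating the xor as the subtraction
  // (b - c) is safe.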
464 if (SubBC->getOpcode() == Instruction::Xor &&
465 !KnownC.getMaxValue().isSubsetOf(RHS: KnownB.getMinValue()))
466 return;
467
468 const APInt MaxA = KnownA.getMaxValue();
469 const APInt MaxD = KnownD.getMaxValue();
470 const APInt MaxAD = APIntOps::umax(A: MaxA, B: MaxD);
471 const APInt MaxB = KnownB.getMaxValue();
472
473 // We can't infer leading zeros info if the upper-bound estimate wraps.
474 bool Overflow;
475 const APInt UpperBound = MaxAD.umul_ov(RHS: MaxB, Overflow);
476
477 if (Overflow)
478 return;
479
  // If we know that x <= y and both are non-negative, then x has at least as
  // many leading zeros as y.
482 const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
483 KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
484}
485
486static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
487 bool NSW, bool NUW,
488 const APInt &DemandedElts,
489 KnownBits &KnownOut, KnownBits &Known2,
490 const SimplifyQuery &Q, unsigned Depth) {
491 computeKnownBits(V: Op1, DemandedElts, Known&: KnownOut, Q, Depth: Depth + 1);
492
493 // If one operand is unknown and we have no nowrap information,
494 // the result will be unknown independently of the second operand.
495 if (KnownOut.isUnknown() && !NSW && !NUW)
496 return;
497
498 computeKnownBits(V: Op0, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
499 KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, LHS: Known2, RHS: KnownOut);
500
501 if (!Add && NSW && !KnownOut.isNonNegative() &&
502 isImpliedByDomCondition(Pred: ICmpInst::ICMP_SLE, LHS: Op1, RHS: Op0, ContextI: Q.CxtI, DL: Q.DL)
503 .value_or(u: false))
504 KnownOut.makeNonNegative();
505
506 if (Add)
507 // Try to match lerp pattern and combine results
508 computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
509}
510
511static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
512 bool NUW, const APInt &DemandedElts,
513 KnownBits &Known, KnownBits &Known2,
514 const SimplifyQuery &Q, unsigned Depth) {
515 computeKnownBits(V: Op1, DemandedElts, Known, Q, Depth: Depth + 1);
516 computeKnownBits(V: Op0, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
517
518 bool isKnownNegative = false;
519 bool isKnownNonNegative = false;
520 // If the multiplication is known not to overflow, compute the sign bit.
521 if (NSW) {
522 if (Op0 == Op1) {
523 // The product of a number with itself is non-negative.
524 isKnownNonNegative = true;
525 } else {
526 bool isKnownNonNegativeOp1 = Known.isNonNegative();
527 bool isKnownNonNegativeOp0 = Known2.isNonNegative();
528 bool isKnownNegativeOp1 = Known.isNegative();
529 bool isKnownNegativeOp0 = Known2.isNegative();
530 // The product of two numbers with the same sign is non-negative.
531 isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
532 (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
533 if (!isKnownNonNegative && NUW) {
534 // mul nuw nsw with a factor > 1 is non-negative.
535 KnownBits One = KnownBits::makeConstant(C: APInt(Known.getBitWidth(), 1));
536 isKnownNonNegative = KnownBits::sgt(LHS: Known, RHS: One).value_or(u: false) ||
537 KnownBits::sgt(LHS: Known2, RHS: One).value_or(u: false);
538 }
539
540 // The product of a negative number and a non-negative number is either
541 // negative or zero.
542 if (!isKnownNonNegative)
543 isKnownNegative =
544 (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
545 Known2.isNonZero()) ||
546 (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known.isNonZero());
547 }
548 }
549
550 bool SelfMultiply = Op0 == Op1;
551 if (SelfMultiply)
552 SelfMultiply &=
553 isGuaranteedNotToBeUndef(V: Op0, AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT, Depth: Depth + 1);
554 Known = KnownBits::mul(LHS: Known, RHS: Known2, NoUndefSelfMultiply: SelfMultiply);
555
556 if (SelfMultiply) {
557 unsigned SignBits = ComputeNumSignBits(V: Op0, DemandedElts, Q, Depth: Depth + 1);
558 unsigned TyBits = Op0->getType()->getScalarSizeInBits();
559 unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
560
561 if (OutValidBits < TyBits) {
562 APInt KnownZeroMask =
563 APInt::getHighBitsSet(numBits: TyBits, hiBitsSet: TyBits - OutValidBits + 1);
564 Known.Zero |= KnownZeroMask;
565 }
566 }
567
568 // Only make use of no-wrap flags if we failed to compute the sign bit
569 // directly. This matters if the multiplication always overflows, in
570 // which case we prefer to follow the result of the direct computation,
571 // though as the program is invoking undefined behaviour we can choose
572 // whatever we like here.
573 if (isKnownNonNegative && !Known.isNegative())
574 Known.makeNonNegative();
575 else if (isKnownNegative && !Known.isNonNegative())
576 Known.makeNegative();
577}
578
579void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
580 KnownBits &Known) {
581 unsigned BitWidth = Known.getBitWidth();
582 unsigned NumRanges = Ranges.getNumOperands() / 2;
583 assert(NumRanges >= 1);
584
585 Known.setAllConflict();
586
587 for (unsigned i = 0; i < NumRanges; ++i) {
588 ConstantInt *Lower =
589 mdconst::extract<ConstantInt>(MD: Ranges.getOperand(I: 2 * i + 0));
590 ConstantInt *Upper =
591 mdconst::extract<ConstantInt>(MD: Ranges.getOperand(I: 2 * i + 1));
592 ConstantRange Range(Lower->getValue(), Upper->getValue());
593 // BitWidth must equal the Ranges BitWidth for the correct number of high
594 // bits to be set.
595 assert(BitWidth == Range.getBitWidth() &&
596 "Known bit width must match range bit width!");
597
598 // The first CommonPrefixBits of all values in Range are equal.
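    // For example, with an i32 !range of [32, 48) the unsigned min and max
    // differ only in the low 4 bits, so the top 28 bits are known and the
    // value has the form 0b...0010xxxx.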
599 unsigned CommonPrefixBits =
600 (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
601 APInt Mask = APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: CommonPrefixBits);
602 APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(width: BitWidth);
603 Known.One &= UnsignedMax & Mask;
604 Known.Zero &= ~UnsignedMax & Mask;
605 }
606}
607
608static bool isEphemeralValueOf(const Instruction *I, const Value *E) {
609 SmallVector<const Instruction *, 16> WorkSet(1, I);
610 SmallPtrSet<const Instruction *, 32> Visited;
611 SmallPtrSet<const Instruction *, 16> EphValues;
612
613 // The instruction defining an assumption's condition itself is always
614 // considered ephemeral to that assumption (even if it has other
615 // non-ephemeral users). See r246696's test case for an example.
616 if (is_contained(Range: I->operands(), Element: E))
617 return true;
618
619 while (!WorkSet.empty()) {
620 const Instruction *V = WorkSet.pop_back_val();
621 if (!Visited.insert(Ptr: V).second)
622 continue;
623
624 // If all uses of this value are ephemeral, then so is this value.
625 if (all_of(Range: V->users(), P: [&](const User *U) {
626 return EphValues.count(Ptr: cast<Instruction>(Val: U));
627 })) {
628 if (V == E)
629 return true;
630
631 if (V == I || (!V->mayHaveSideEffects() && !V->isTerminator())) {
632 EphValues.insert(Ptr: V);
633
634 if (const User *U = dyn_cast<User>(Val: V)) {
635 for (const Use &U : U->operands()) {
636 if (const auto *I = dyn_cast<Instruction>(Val: U.get()))
637 WorkSet.push_back(Elt: I);
638 }
639 }
640 }
641 }
642 }
643
644 return false;
645}
646
647// Is this an intrinsic that cannot be speculated but also cannot trap?
648bool llvm::isAssumeLikeIntrinsic(const Instruction *I) {
649 if (const IntrinsicInst *CI = dyn_cast<IntrinsicInst>(Val: I))
650 return CI->isAssumeLikeIntrinsic();
651
652 return false;
653}
654
655bool llvm::isValidAssumeForContext(const Instruction *Inv,
656 const Instruction *CxtI,
657 const DominatorTree *DT,
658 bool AllowEphemerals) {
659 // There are two restrictions on the use of an assume:
660 // 1. The assume must dominate the context (or the control flow must
661 // reach the assume whenever it reaches the context).
662 // 2. The context must not be in the assume's set of ephemeral values
663 // (otherwise we will use the assume to prove that the condition
664 // feeding the assume is trivially true, thus causing the removal of
665 // the assume).
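  //
  // As an illustration of the second restriction, given
  //   %cond = icmp eq i32 %x, 42
  //   call void @llvm.assume(i1 %cond)
  // the icmp is ephemeral to the assume: using the assume as a fact at %cond
  // would fold the condition to true and delete the very assume it feeds.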
666
667 if (Inv->getParent() == CxtI->getParent()) {
    // If Inv and CxtI are in the same block, check whether the assume (Inv)
    // comes before the context instruction (CxtI).
670 if (Inv->comesBefore(Other: CxtI))
671 return true;
672
673 // Don't let an assume affect itself - this would cause the problems
674 // `isEphemeralValueOf` is trying to prevent, and it would also make
675 // the loop below go out of bounds.
676 if (!AllowEphemerals && Inv == CxtI)
677 return false;
678
679 // The context comes first, but they're both in the same block.
680 // Make sure there is nothing in between that might interrupt
681 // the control flow, not even CxtI itself.
682 // We limit the scan distance between the assume and its context instruction
683 // to avoid a compile-time explosion. This limit is chosen arbitrarily, so
684 // it can be adjusted if needed (could be turned into a cl::opt).
685 auto Range = make_range(x: CxtI->getIterator(), y: Inv->getIterator());
686 if (!isGuaranteedToTransferExecutionToSuccessor(Range, ScanLimit: 15))
687 return false;
688
689 return AllowEphemerals || !isEphemeralValueOf(I: Inv, E: CxtI);
690 }
691
692 // Inv and CxtI are in different blocks.
693 if (DT) {
694 if (DT->dominates(Def: Inv, User: CxtI))
695 return true;
696 } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor() ||
697 Inv->getParent()->isEntryBlock()) {
698 // We don't have a DT, but this trivially dominates.
699 return true;
700 }
701
702 return false;
703}
704
705bool llvm::willNotFreeBetween(const Instruction *Assume,
706 const Instruction *CtxI) {
707 // Helper to check if there are any calls in the range that may free memory.
708 auto hasNoFreeCalls = [](auto Range) {
709 for (const auto &[Idx, I] : enumerate(Range)) {
710 if (Idx > MaxInstrsToCheckForFree)
711 return false;
712 if (const auto *CB = dyn_cast<CallBase>(&I))
713 if (!CB->hasFnAttr(Attribute::NoFree))
714 return false;
715 }
716 return true;
717 };
718
719 // Make sure the current function cannot arrange for another thread to free on
720 // its behalf.
721 if (!CtxI->getFunction()->hasNoSync())
722 return false;
723
724 // Handle cross-block case: CtxI in a successor of Assume's block.
725 const BasicBlock *CtxBB = CtxI->getParent();
726 const BasicBlock *AssumeBB = Assume->getParent();
727 BasicBlock::const_iterator CtxIter = CtxI->getIterator();
728 if (CtxBB != AssumeBB) {
729 if (CtxBB->getSinglePredecessor() != AssumeBB)
730 return false;
731
732 if (!hasNoFreeCalls(make_range(x: CtxBB->begin(), y: CtxIter)))
733 return false;
734
735 CtxIter = AssumeBB->end();
736 } else {
737 // Same block case: check that Assume comes before CtxI.
738 if (!Assume->comesBefore(Other: CtxI))
739 return false;
740 }
741
742 // Check if there are any calls between Assume and CtxIter that may free
743 // memory.
744 return hasNoFreeCalls(make_range(x: Assume->getIterator(), y: CtxIter));
745}
746
747// TODO: cmpExcludesZero misses many cases where `RHS` is non-constant but
748// we still have enough information about `RHS` to conclude non-zero. For
749// example Pred=EQ, RHS=isKnownNonZero. cmpExcludesZero is called in loops
750// so the extra compile time may not be worth it, but possibly a second API
751// should be created for use outside of loops.
752static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) {
753 // v u> y implies v != 0.
754 if (Pred == ICmpInst::ICMP_UGT)
755 return true;
756
757 // Special-case v != 0 to also handle v != null.
758 if (Pred == ICmpInst::ICMP_NE)
759 return match(V: RHS, P: m_Zero());
760
761 // All other predicates - rely on generic ConstantRange handling.
762 const APInt *C;
763 auto Zero = APInt::getZero(numBits: RHS->getType()->getScalarSizeInBits());
764 if (match(V: RHS, P: m_APInt(Res&: C))) {
765 ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(Pred, Other: *C);
766 return !TrueValues.contains(Val: Zero);
767 }
768
769 auto *VC = dyn_cast<ConstantDataVector>(Val: RHS);
770 if (VC == nullptr)
771 return false;
772
773 for (unsigned ElemIdx = 0, NElem = VC->getNumElements(); ElemIdx < NElem;
774 ++ElemIdx) {
775 ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(
776 Pred, Other: VC->getElementAsAPInt(i: ElemIdx));
777 if (TrueValues.contains(Val: Zero))
778 return false;
779 }
780 return true;
781}
782
783static void breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
784 Value *&ValOut, Instruction *&CtxIOut,
785 const PHINode **PhiOut = nullptr) {
786 ValOut = U->get();
787 if (ValOut == PHI)
788 return;
789 CtxIOut = PHI->getIncomingBlock(U: *U)->getTerminator();
790 if (PhiOut)
791 *PhiOut = PHI;
792 Value *V;
  // If the Use is a select of this phi, compute the analysis on the other arm
  // to break the recursion.
795 // TODO: Min/Max
796 if (match(V: ValOut, P: m_Select(C: m_Value(), L: m_Specific(V: PHI), R: m_Value(V))) ||
797 match(V: ValOut, P: m_Select(C: m_Value(), L: m_Value(V), R: m_Specific(V: PHI))))
798 ValOut = V;
799
  // As with the select case above, if the incoming value is a 2-operand phi,
  // compute the analysis on the other incoming value to break the recursion.
802 // TODO: We could handle any number of incoming edges as long as we only have
803 // two unique values.
804 if (auto *IncPhi = dyn_cast<PHINode>(Val: ValOut);
805 IncPhi && IncPhi->getNumIncomingValues() == 2) {
806 for (int Idx = 0; Idx < 2; ++Idx) {
807 if (IncPhi->getIncomingValue(i: Idx) == PHI) {
808 ValOut = IncPhi->getIncomingValue(i: 1 - Idx);
809 if (PhiOut)
810 *PhiOut = IncPhi;
811 CtxIOut = IncPhi->getIncomingBlock(i: 1 - Idx)->getTerminator();
812 break;
813 }
814 }
815 }
816}
817
818static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
819 // Use of assumptions is context-sensitive. If we don't have a context, we
820 // cannot use them!
821 if (!Q.AC || !Q.CxtI)
822 return false;
823
824 for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
825 if (!Elem.Assume)
826 continue;
827
828 AssumeInst *I = cast<AssumeInst>(Val&: Elem.Assume);
829 assert(I->getFunction() == Q.CxtI->getFunction() &&
830 "Got assumption for the wrong function!");
831
832 if (Elem.Index != AssumptionCache::ExprResultIdx) {
833 if (!V->getType()->isPointerTy())
834 continue;
835 if (RetainedKnowledge RK = getKnowledgeFromBundle(
836 Assume&: *I, BOI: I->bundle_op_info_begin()[Elem.Index])) {
837 if (RK.WasOn != V)
838 continue;
839 bool AssumeImpliesNonNull = [&]() {
840 if (RK.AttrKind == Attribute::NonNull)
841 return true;
842
843 if (RK.AttrKind == Attribute::Dereferenceable) {
844 if (NullPointerIsDefined(F: Q.CxtI->getFunction(),
845 AS: V->getType()->getPointerAddressSpace()))
846 return false;
847 assert(RK.IRArgValue &&
848 "Dereferenceable attribute without IR argument?");
849
850 auto *CI = dyn_cast<ConstantInt>(Val: RK.IRArgValue);
851 return CI && !CI->isZero();
852 }
853
854 return false;
855 }();
856 if (AssumeImpliesNonNull && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
857 return true;
858 }
859 continue;
860 }
861
862 // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop once for each value queried, resulting in a
864 // runtime of ~O(#assumes * #values).
865
866 Value *RHS;
867 CmpPredicate Pred;
868 auto m_V = m_CombineOr(L: m_Specific(V), R: m_PtrToInt(Op: m_Specific(V)));
869 if (!match(V: I->getArgOperand(i: 0), P: m_c_ICmp(Pred, L: m_V, R: m_Value(V&: RHS))))
870 continue;
871
872 if (cmpExcludesZero(Pred, RHS) && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
873 return true;
874 }
875
876 return false;
877}
878
879static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
880 Value *LHS, Value *RHS, KnownBits &Known,
881 const SimplifyQuery &Q) {
882 if (RHS->getType()->isPointerTy()) {
883 // Handle comparison of pointer to null explicitly, as it will not be
884 // covered by the m_APInt() logic below.
885 if (LHS == V && match(V: RHS, P: m_Zero())) {
886 switch (Pred) {
887 case ICmpInst::ICMP_EQ:
888 Known.setAllZero();
889 break;
890 case ICmpInst::ICMP_SGE:
891 case ICmpInst::ICMP_SGT:
892 Known.makeNonNegative();
893 break;
894 case ICmpInst::ICMP_SLT:
895 Known.makeNegative();
896 break;
897 default:
898 break;
899 }
900 }
901 return;
902 }
903
904 unsigned BitWidth = Known.getBitWidth();
905 auto m_V =
906 m_CombineOr(L: m_Specific(V), R: m_PtrToIntSameSize(DL: Q.DL, Op: m_Specific(V)));
907
908 Value *Y;
909 const APInt *Mask, *C;
910 if (!match(V: RHS, P: m_APInt(Res&: C)))
911 return;
912
913 uint64_t ShAmt;
914 switch (Pred) {
915 case ICmpInst::ICMP_EQ:
916 // assume(V = C)
917 if (match(V: LHS, P: m_V)) {
918 Known = Known.unionWith(RHS: KnownBits::makeConstant(C: *C));
919 // assume(V & Mask = C)
920 } else if (match(V: LHS, P: m_c_And(L: m_V, R: m_Value(V&: Y)))) {
921 // For one bits in Mask, we can propagate bits from C to V.
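      // For example, assume((v & 0xf0) == 0x30) makes bits 5:4 of v known
      // ones and bits 7:6 known zeros, while the low four bits stay unknown.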
922 Known.One |= *C;
923 if (match(V: Y, P: m_APInt(Res&: Mask)))
924 Known.Zero |= ~*C & *Mask;
925 // assume(V | Mask = C)
926 } else if (match(V: LHS, P: m_c_Or(L: m_V, R: m_Value(V&: Y)))) {
927 // For zero bits in Mask, we can propagate bits from C to V.
928 Known.Zero |= ~*C;
929 if (match(V: Y, P: m_APInt(Res&: Mask)))
930 Known.One |= *C & ~*Mask;
931 // assume(V << ShAmt = C)
932 } else if (match(V: LHS, P: m_Shl(L: m_V, R: m_ConstantInt(V&: ShAmt))) &&
933 ShAmt < BitWidth) {
934 // For those bits in C that are known, we can propagate them to known
935 // bits in V shifted to the right by ShAmt.
936 KnownBits RHSKnown = KnownBits::makeConstant(C: *C);
937 RHSKnown >>= ShAmt;
938 Known = Known.unionWith(RHS: RHSKnown);
939 // assume(V >> ShAmt = C)
940 } else if (match(V: LHS, P: m_Shr(L: m_V, R: m_ConstantInt(V&: ShAmt))) &&
941 ShAmt < BitWidth) {
      // For those bits in C that are known, we can propagate them to known
      // bits in V shifted to the left by ShAmt.
944 KnownBits RHSKnown = KnownBits::makeConstant(C: *C);
945 RHSKnown <<= ShAmt;
946 Known = Known.unionWith(RHS: RHSKnown);
947 }
948 break;
949 case ICmpInst::ICMP_NE: {
950 // assume (V & B != 0) where B is a power of 2
951 const APInt *BPow2;
952 if (C->isZero() && match(V: LHS, P: m_And(L: m_V, R: m_Power2(V&: BPow2))))
953 Known.One |= *BPow2;
954 break;
955 }
956 default: {
957 const APInt *Offset = nullptr;
958 if (match(V: LHS, P: m_CombineOr(L: m_V, R: m_AddLike(L: m_V, R: m_APInt(Res&: Offset))))) {
959 ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, Other: *C);
960 if (Offset)
961 LHSRange = LHSRange.sub(Other: *Offset);
962 Known = Known.unionWith(RHS: LHSRange.toKnownBits());
963 }
964 if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
965 // X & Y u> C -> X u> C && Y u> C
966 // X nuw- Y u> C -> X u> C
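      // For example, for i8 values, (x & y) u> 0xef implies x u> 0xef, i.e.
      // x >= 0xf0, so the top four bits of x are known to be ones.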
967 if (match(V: LHS, P: m_c_And(L: m_V, R: m_Value())) ||
968 match(V: LHS, P: m_NUWSub(L: m_V, R: m_Value())))
969 Known.One.setHighBits(
970 (*C + (Pred == ICmpInst::ICMP_UGT)).countLeadingOnes());
971 }
972 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
973 // X | Y u< C -> X u< C && Y u< C
974 // X nuw+ Y u< C -> X u< C && Y u< C
975 if (match(V: LHS, P: m_c_Or(L: m_V, R: m_Value())) ||
976 match(V: LHS, P: m_c_NUWAdd(L: m_V, R: m_Value()))) {
977 Known.Zero.setHighBits(
978 (*C - (Pred == ICmpInst::ICMP_ULT)).countLeadingZeros());
979 }
980 }
981 } break;
982 }
983}
984
985static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
986 KnownBits &Known,
987 const SimplifyQuery &SQ, bool Invert) {
988 ICmpInst::Predicate Pred =
989 Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
990 Value *LHS = Cmp->getOperand(i_nocapture: 0);
991 Value *RHS = Cmp->getOperand(i_nocapture: 1);
992
993 // Handle icmp pred (trunc V), C
994 if (match(V: LHS, P: m_Trunc(Op: m_Specific(V)))) {
995 KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
996 computeKnownBitsFromCmp(V: LHS, Pred, LHS, RHS, Known&: DstKnown, Q: SQ);
997 if (cast<TruncInst>(Val: LHS)->hasNoUnsignedWrap())
998 Known = Known.unionWith(RHS: DstKnown.zext(BitWidth: Known.getBitWidth()));
999 else
1000 Known = Known.unionWith(RHS: DstKnown.anyext(BitWidth: Known.getBitWidth()));
1001 return;
1002 }
1003
1004 computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, Q: SQ);
1005}
1006
1007static void computeKnownBitsFromCond(const Value *V, Value *Cond,
1008 KnownBits &Known, const SimplifyQuery &SQ,
1009 bool Invert, unsigned Depth) {
1010 Value *A, *B;
1011 if (Depth < MaxAnalysisRecursionDepth &&
1012 match(V: Cond, P: m_LogicalOp(L: m_Value(V&: A), R: m_Value(V&: B)))) {
1013 KnownBits Known2(Known.getBitWidth());
1014 KnownBits Known3(Known.getBitWidth());
1015 computeKnownBitsFromCond(V, Cond: A, Known&: Known2, SQ, Invert, Depth: Depth + 1);
1016 computeKnownBitsFromCond(V, Cond: B, Known&: Known3, SQ, Invert, Depth: Depth + 1);
1017 if (Invert ? match(V: Cond, P: m_LogicalOr(L: m_Value(), R: m_Value()))
1018 : match(V: Cond, P: m_LogicalAnd(L: m_Value(), R: m_Value())))
1019 Known2 = Known2.unionWith(RHS: Known3);
1020 else
1021 Known2 = Known2.intersectWith(RHS: Known3);
1022 Known = Known.unionWith(RHS: Known2);
1023 return;
1024 }
1025
1026 if (auto *Cmp = dyn_cast<ICmpInst>(Val: Cond)) {
1027 computeKnownBitsFromICmpCond(V, Cmp, Known, SQ, Invert);
1028 return;
1029 }
1030
1031 if (match(V: Cond, P: m_Trunc(Op: m_Specific(V)))) {
1032 KnownBits DstKnown(1);
1033 if (Invert) {
1034 DstKnown.setAllZero();
1035 } else {
1036 DstKnown.setAllOnes();
1037 }
1038 if (cast<TruncInst>(Val: Cond)->hasNoUnsignedWrap()) {
1039 Known = Known.unionWith(RHS: DstKnown.zext(BitWidth: Known.getBitWidth()));
1040 return;
1041 }
1042 Known = Known.unionWith(RHS: DstKnown.anyext(BitWidth: Known.getBitWidth()));
1043 return;
1044 }
1045
1046 if (Depth < MaxAnalysisRecursionDepth && match(V: Cond, P: m_Not(V: m_Value(V&: A))))
1047 computeKnownBitsFromCond(V, Cond: A, Known, SQ, Invert: !Invert, Depth: Depth + 1);
1048}
1049
1050void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
1051 const SimplifyQuery &Q, unsigned Depth) {
1052 // Handle injected condition.
1053 if (Q.CC && Q.CC->AffectedValues.contains(Ptr: V))
1054 computeKnownBitsFromCond(V, Cond: Q.CC->Cond, Known, SQ: Q, Invert: Q.CC->Invert, Depth);
1055
1056 if (!Q.CxtI)
1057 return;
1058
1059 if (Q.DC && Q.DT) {
1060 // Handle dominating conditions.
1061 for (BranchInst *BI : Q.DC->conditionsFor(V)) {
1062 BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
1063 if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
1064 computeKnownBitsFromCond(V, Cond: BI->getCondition(), Known, SQ: Q,
1065 /*Invert*/ false, Depth);
1066
1067 BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
1068 if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
1069 computeKnownBitsFromCond(V, Cond: BI->getCondition(), Known, SQ: Q,
1070 /*Invert*/ true, Depth);
1071 }
1072
1073 if (Known.hasConflict())
1074 Known.resetAll();
1075 }
1076
1077 if (!Q.AC)
1078 return;
1079
1080 unsigned BitWidth = Known.getBitWidth();
1081
1082 // Note that the patterns below need to be kept in sync with the code
1083 // in AssumptionCache::updateAffectedValues.
1084
1085 for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
1086 if (!Elem.Assume)
1087 continue;
1088
1089 AssumeInst *I = cast<AssumeInst>(Val&: Elem.Assume);
1090 assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
1091 "Got assumption for the wrong function!");
1092
1093 if (Elem.Index != AssumptionCache::ExprResultIdx) {
1094 if (!V->getType()->isPointerTy())
1095 continue;
1096 if (RetainedKnowledge RK = getKnowledgeFromBundle(
1097 Assume&: *I, BOI: I->bundle_op_info_begin()[Elem.Index])) {
1098 // Allow AllowEphemerals in isValidAssumeForContext, as the CxtI might
1099 // be the producer of the pointer in the bundle. At the moment, align
1100 // assumptions aren't optimized away.
1101 if (RK.WasOn == V && RK.AttrKind == Attribute::Alignment &&
1102 isPowerOf2_64(Value: RK.ArgValue) &&
1103 isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT, /*AllowEphemerals*/ true))
1104 Known.Zero.setLowBits(Log2_64(Value: RK.ArgValue));
1105 }
1106 continue;
1107 }
1108
1109 // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop once for each value queried, resulting in a
1111 // runtime of ~O(#assumes * #values).
1112
1113 Value *Arg = I->getArgOperand(i: 0);
1114
1115 if (Arg == V && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
1116 assert(BitWidth == 1 && "assume operand is not i1?");
1117 (void)BitWidth;
1118 Known.setAllOnes();
1119 return;
1120 }
1121 if (match(V: Arg, P: m_Not(V: m_Specific(V))) &&
1122 isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
1123 assert(BitWidth == 1 && "assume operand is not i1?");
1124 (void)BitWidth;
1125 Known.setAllZero();
1126 return;
1127 }
1128 auto *Trunc = dyn_cast<TruncInst>(Val: Arg);
1129 if (Trunc && Trunc->getOperand(i_nocapture: 0) == V &&
1130 isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
1131 if (Trunc->hasNoUnsignedWrap()) {
1132 Known = KnownBits::makeConstant(C: APInt(BitWidth, 1));
1133 return;
1134 }
1135 Known.One.setBit(0);
1136 return;
1137 }
1138
1139 // The remaining tests are all recursive, so bail out if we hit the limit.
1140 if (Depth == MaxAnalysisRecursionDepth)
1141 continue;
1142
1143 ICmpInst *Cmp = dyn_cast<ICmpInst>(Val: Arg);
1144 if (!Cmp)
1145 continue;
1146
1147 if (!isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
1148 continue;
1149
1150 computeKnownBitsFromICmpCond(V, Cmp, Known, SQ: Q, /*Invert=*/false);
1151 }
1152
1153 // Conflicting assumption: Undefined behavior will occur on this execution
1154 // path.
1155 if (Known.hasConflict())
1156 Known.resetAll();
1157}
1158
/// Compute known bits from a shift operator, including those with a
/// non-constant shift amount. Known is the output of this function. Known2 is a
/// pre-allocated temporary with the same bit width as Known and on return
/// contains the known bits of the shift value source. KF is an
/// operator-specific function that, given the known bits and a shift amount,
/// computes the implied known bits of the shift operator's result for that
/// shift amount. The results from calling KF are conservatively combined for
/// all permitted shift amounts.
1167static void computeKnownBitsFromShiftOperator(
1168 const Operator *I, const APInt &DemandedElts, KnownBits &Known,
1169 KnownBits &Known2, const SimplifyQuery &Q, unsigned Depth,
1170 function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) {
1171 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1172 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1173 // To limit compile-time impact, only query isKnownNonZero() if we know at
1174 // least something about the shift amount.
1175 bool ShAmtNonZero =
1176 Known.isNonZero() ||
1177 (Known.getMaxValue().ult(RHS: Known.getBitWidth()) &&
1178 isKnownNonZero(V: I->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1));
1179 Known = KF(Known2, Known, ShAmtNonZero);
1180}
1181
1182static KnownBits
1183getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
1184 const KnownBits &KnownLHS, const KnownBits &KnownRHS,
1185 const SimplifyQuery &Q, unsigned Depth) {
1186 unsigned BitWidth = KnownLHS.getBitWidth();
1187 KnownBits KnownOut(BitWidth);
1188 bool IsAnd = false;
1189 bool HasKnownOne = !KnownLHS.One.isZero() || !KnownRHS.One.isZero();
1190 Value *X = nullptr, *Y = nullptr;
1191
1192 switch (I->getOpcode()) {
1193 case Instruction::And:
1194 KnownOut = KnownLHS & KnownRHS;
1195 IsAnd = true;
    // and(x, -x) is a common idiom that will clear all but the lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
1199 // TODO: instcombine often reassociates independent `and` which can hide
1200 // this pattern. Try to match and(x, and(-x, y)) / and(and(x, y), -x).
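    // For example, if bit 3 of x is known to be set, the lowest set bit of x
    // is at position 3 or below, so every bit of x & -x above bit 3 is known
    // to be zero.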
1201 if (HasKnownOne && match(V: I, P: m_c_And(L: m_Value(V&: X), R: m_Neg(V: m_Deferred(V: X))))) {
1202 // -(-x) == x so using whichever (LHS/RHS) gets us a better result.
1203 if (KnownLHS.countMaxTrailingZeros() <= KnownRHS.countMaxTrailingZeros())
1204 KnownOut = KnownLHS.blsi();
1205 else
1206 KnownOut = KnownRHS.blsi();
1207 }
1208 break;
1209 case Instruction::Or:
1210 KnownOut = KnownLHS | KnownRHS;
1211 break;
1212 case Instruction::Xor:
1213 KnownOut = KnownLHS ^ KnownRHS;
    // xor(x, x-1) is a common idiom that will clear all but the lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: xor(x, x-1) is often rewritten as xor(x, x-C) where C !=
    // -1, but for the purpose of demanded bits (xor(x, x-C) &
    // Demanded) == (xor(x, x-1) & Demanded). Extend the xor pattern
    // to use arbitrary C if xor(x, x-C) is the same as xor(x, x-1).
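    // For example, with x = 0b1100 we get x - 1 = 0b1011 and
    // x ^ (x - 1) = 0b0111: a mask covering the lowest set bit of x and all
    // bits below it, so a known one in x again bounds the result from above.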
1221 if (HasKnownOne &&
1222 match(V: I, P: m_c_Xor(L: m_Value(V&: X), R: m_Add(L: m_Deferred(V: X), R: m_AllOnes())))) {
1223 const KnownBits &XBits = I->getOperand(i: 0) == X ? KnownLHS : KnownRHS;
1224 KnownOut = XBits.blsmsk();
1225 }
1226 break;
1227 default:
1228 llvm_unreachable("Invalid Op used in 'analyzeKnownBitsFromAndXorOr'");
1229 }
1230
1231 // and(x, add (x, -1)) is a common idiom that always clears the low bit;
1232 // xor/or(x, add (x, -1)) is an idiom that will always set the low bit.
  // Here we handle the more general case of adding any odd number by
1234 // matching the form and/xor/or(x, add(x, y)) where y is odd.
1235 // TODO: This could be generalized to clearing any bit set in y where the
1236 // following bit is known to be unset in y.
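  // For example, with y = 1: x and x + 1 always differ in bit 0, so
  // and(x, x + 1) has bit 0 clear, while or(x, x + 1) and xor(x, x + 1) have
  // bit 0 set.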
1237 if (!KnownOut.Zero[0] && !KnownOut.One[0] &&
1238 (match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_c_Add(L: m_Deferred(V: X), R: m_Value(V&: Y)))) ||
1239 match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_Sub(L: m_Deferred(V: X), R: m_Value(V&: Y)))) ||
1240 match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_Sub(L: m_Value(V&: Y), R: m_Deferred(V: X)))))) {
1241 KnownBits KnownY(BitWidth);
1242 computeKnownBits(V: Y, DemandedElts, Known&: KnownY, Q, Depth: Depth + 1);
1243 if (KnownY.countMinTrailingOnes() > 0) {
1244 if (IsAnd)
1245 KnownOut.Zero.setBit(0);
1246 else
1247 KnownOut.One.setBit(0);
1248 }
1249 }
1250 return KnownOut;
1251}
1252
1253static KnownBits computeKnownBitsForHorizontalOperation(
1254 const Operator *I, const APInt &DemandedElts, const SimplifyQuery &Q,
1255 unsigned Depth,
1256 const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
1257 KnownBitsFunc) {
1258 APInt DemandedEltsLHS, DemandedEltsRHS;
1259 getHorizDemandedEltsForFirstOperand(VectorBitWidth: Q.DL.getTypeSizeInBits(Ty: I->getType()),
1260 DemandedElts, DemandedLHS&: DemandedEltsLHS,
1261 DemandedRHS&: DemandedEltsRHS);
1262
1263 const auto ComputeForSingleOpFunc =
1264 [Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
1265 return KnownBitsFunc(
1266 computeKnownBits(V: Op, DemandedElts: DemandedEltsOp, Q, Depth: Depth + 1),
1267 computeKnownBits(V: Op, DemandedElts: DemandedEltsOp << 1, Q, Depth: Depth + 1));
1268 };
1269
1270 if (DemandedEltsRHS.isZero())
1271 return ComputeForSingleOpFunc(I->getOperand(i: 0), DemandedEltsLHS);
1272 if (DemandedEltsLHS.isZero())
1273 return ComputeForSingleOpFunc(I->getOperand(i: 1), DemandedEltsRHS);
1274
1275 return ComputeForSingleOpFunc(I->getOperand(i: 0), DemandedEltsLHS)
1276 .intersectWith(RHS: ComputeForSingleOpFunc(I->getOperand(i: 1), DemandedEltsRHS));
1277}
1278
1279// Public so this can be used in `SimplifyDemandedUseBits`.
1280KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
1281 const KnownBits &KnownLHS,
1282 const KnownBits &KnownRHS,
1283 const SimplifyQuery &SQ,
1284 unsigned Depth) {
1285 auto *FVTy = dyn_cast<FixedVectorType>(Val: I->getType());
1286 APInt DemandedElts =
1287 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
1288
1289 return getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS, KnownRHS, Q: SQ,
1290 Depth);
1291}
1292
1293ConstantRange llvm::getVScaleRange(const Function *F, unsigned BitWidth) {
1294 Attribute Attr = F->getFnAttribute(Kind: Attribute::VScaleRange);
1295 // Without vscale_range, we only know that vscale is non-zero.
1296 if (!Attr.isValid())
1297 return ConstantRange(APInt(BitWidth, 1), APInt::getZero(numBits: BitWidth));
1298
1299 unsigned AttrMin = Attr.getVScaleRangeMin();
1300 // Minimum is larger than vscale width, result is always poison.
1301 if ((unsigned)llvm::bit_width(Value: AttrMin) > BitWidth)
1302 return ConstantRange::getEmpty(BitWidth);
1303
1304 APInt Min(BitWidth, AttrMin);
1305 std::optional<unsigned> AttrMax = Attr.getVScaleRangeMax();
1306 if (!AttrMax || (unsigned)llvm::bit_width(Value: *AttrMax) > BitWidth)
1307 return ConstantRange(Min, APInt::getZero(numBits: BitWidth));
1308
1309 return ConstantRange(Min, APInt(BitWidth, *AttrMax) + 1);
1310}
1311
1312void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
1313 Value *Arm, bool Invert,
1314 const SimplifyQuery &Q, unsigned Depth) {
1315 // If we have a constant arm, we are done.
1316 if (Known.isConstant())
1317 return;
1318
1319 // See what condition implies about the bits of the select arm.
1320 KnownBits CondRes(Known.getBitWidth());
1321 computeKnownBitsFromCond(V: Arm, Cond, Known&: CondRes, SQ: Q, Invert, Depth: Depth + 1);
1322 // If we don't get any information from the condition, no reason to
1323 // proceed.
1324 if (CondRes.isUnknown())
1325 return;
1326
  // We can have a conflict if the condition is dead. I.e. if we have
  //   (x | 64) < 32 ? (x | 64) : y
  // we will have a conflict at bit 6 from the condition/the `or`.
  // In that case just return. It's not particularly important
  // what we do, as this select is going to be simplified soon.
1332 CondRes = CondRes.unionWith(RHS: Known);
1333 if (CondRes.hasConflict())
1334 return;
1335
1336 // Finally make sure the information we found is valid. This is relatively
1337 // expensive so it's left for the very end.
1338 if (!isGuaranteedNotToBeUndef(V: Arm, AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT, Depth: Depth + 1))
1339 return;
1340
  // Finally, we know we get information from the condition and it's valid,
  // so return it.
1343 Known = CondRes;
1344}
1345
1346// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
1347// Returns the input and lower/upper bounds.
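// For example, smax(smin(In, 255), 0) clamps In to [0, 255] and yields
// CLow = 0 and CHigh = 255.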
1348static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
1349 const APInt *&CLow, const APInt *&CHigh) {
1350 assert(isa<Operator>(Select) &&
1351 cast<Operator>(Select)->getOpcode() == Instruction::Select &&
1352 "Input should be a Select!");
1353
1354 const Value *LHS = nullptr, *RHS = nullptr;
1355 SelectPatternFlavor SPF = matchSelectPattern(V: Select, LHS, RHS).Flavor;
1356 if (SPF != SPF_SMAX && SPF != SPF_SMIN)
1357 return false;
1358
1359 if (!match(V: RHS, P: m_APInt(Res&: CLow)))
1360 return false;
1361
1362 const Value *LHS2 = nullptr, *RHS2 = nullptr;
1363 SelectPatternFlavor SPF2 = matchSelectPattern(V: LHS, LHS&: LHS2, RHS&: RHS2).Flavor;
1364 if (getInverseMinMaxFlavor(SPF) != SPF2)
1365 return false;
1366
1367 if (!match(V: RHS2, P: m_APInt(Res&: CHigh)))
1368 return false;
1369
1370 if (SPF == SPF_SMIN)
1371 std::swap(a&: CLow, b&: CHigh);
1372
1373 In = LHS2;
1374 return CLow->sle(RHS: *CHigh);
1375}
1376
1377static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
1378 const APInt *&CLow,
1379 const APInt *&CHigh) {
1380 assert((II->getIntrinsicID() == Intrinsic::smin ||
1381 II->getIntrinsicID() == Intrinsic::smax) &&
1382 "Must be smin/smax");
1383
1384 Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(MinMaxID: II->getIntrinsicID());
1385 auto *InnerII = dyn_cast<IntrinsicInst>(Val: II->getArgOperand(i: 0));
1386 if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
1387 !match(V: II->getArgOperand(i: 1), P: m_APInt(Res&: CLow)) ||
1388 !match(V: InnerII->getArgOperand(i: 1), P: m_APInt(Res&: CHigh)))
1389 return false;
1390
1391 if (II->getIntrinsicID() == Intrinsic::smin)
1392 std::swap(a&: CLow, b&: CHigh);
1393 return CLow->sle(RHS: *CHigh);
1394}
1395
1396static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
1397 KnownBits &Known) {
1398 const APInt *CLow, *CHigh;
1399 if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
1400 Known = Known.unionWith(
1401 RHS: ConstantRange::getNonEmpty(Lower: *CLow, Upper: *CHigh + 1).toKnownBits());
1402}
1403
1404static void computeKnownBitsFromOperator(const Operator *I,
1405 const APInt &DemandedElts,
1406 KnownBits &Known,
1407 const SimplifyQuery &Q,
1408 unsigned Depth) {
1409 unsigned BitWidth = Known.getBitWidth();
1410
1411 KnownBits Known2(BitWidth);
1412 switch (I->getOpcode()) {
1413 default: break;
1414 case Instruction::Load:
1415 if (MDNode *MD =
1416 Q.IIQ.getMetadata(I: cast<LoadInst>(Val: I), KindID: LLVMContext::MD_range))
1417 computeKnownBitsFromRangeMetadata(Ranges: *MD, Known);
1418 break;
1419 case Instruction::And:
1420 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1421 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1422
1423 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1424 break;
1425 case Instruction::Or:
1426 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1427 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1428
1429 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1430 break;
1431 case Instruction::Xor:
1432 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1433 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1434
1435 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1436 break;
1437 case Instruction::Mul: {
1438 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1439 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1440 computeKnownBitsMul(Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1441 DemandedElts, Known, Known2, Q, Depth);
1442 break;
1443 }
1444 case Instruction::UDiv: {
1445 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1446 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1447 Known =
1448 KnownBits::udiv(LHS: Known, RHS: Known2, Exact: Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)));
1449 break;
1450 }
1451 case Instruction::SDiv: {
1452 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1453 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1454 Known =
1455 KnownBits::sdiv(LHS: Known, RHS: Known2, Exact: Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)));
1456 break;
1457 }
1458 case Instruction::Select: {
1459 auto ComputeForArm = [&](Value *Arm, bool Invert) {
1460 KnownBits Res(Known.getBitWidth());
1461 computeKnownBits(V: Arm, DemandedElts, Known&: Res, Q, Depth: Depth + 1);
1462 adjustKnownBitsForSelectArm(Known&: Res, Cond: I->getOperand(i: 0), Arm, Invert, Q, Depth);
1463 return Res;
1464 };
1465 // Only known if known in both the LHS and RHS.
1466 Known =
1467 ComputeForArm(I->getOperand(i: 1), /*Invert=*/false)
1468 .intersectWith(RHS: ComputeForArm(I->getOperand(i: 2), /*Invert=*/true));
1469 break;
1470 }
1471 case Instruction::FPTrunc:
1472 case Instruction::FPExt:
1473 case Instruction::FPToUI:
1474 case Instruction::FPToSI:
1475 case Instruction::SIToFP:
1476 case Instruction::UIToFP:
1477 break; // Can't work with floating point.
1478 case Instruction::PtrToInt:
1479 case Instruction::PtrToAddr:
1480 case Instruction::IntToPtr:
1481 // Fall through and handle them the same as zext/trunc.
1482 [[fallthrough]];
1483 case Instruction::ZExt:
1484 case Instruction::Trunc: {
1485 Type *SrcTy = I->getOperand(i: 0)->getType();
1486
1487 unsigned SrcBitWidth;
1488 // Note that we handle pointer operands here because of inttoptr/ptrtoint
1489 // which fall through here.
1490 Type *ScalarTy = SrcTy->getScalarType();
1491 SrcBitWidth = ScalarTy->isPointerTy() ?
1492 Q.DL.getPointerTypeSizeInBits(ScalarTy) :
1493 Q.DL.getTypeSizeInBits(Ty: ScalarTy);
1494
1495 assert(SrcBitWidth && "SrcBitWidth can't be zero");
1496 Known = Known.anyextOrTrunc(BitWidth: SrcBitWidth);
1497 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1498 if (auto *Inst = dyn_cast<PossiblyNonNegInst>(Val: I);
1499 Inst && Inst->hasNonNeg() && !Known.isNegative())
1500 Known.makeNonNegative();
1501 Known = Known.zextOrTrunc(BitWidth);
1502 break;
1503 }
1504 case Instruction::BitCast: {
1505 Type *SrcTy = I->getOperand(i: 0)->getType();
1506 if (SrcTy->isIntOrPtrTy() &&
1507 // TODO: For now, not handling conversions like:
1508 // (bitcast i64 %x to <2 x i32>)
1509 !I->getType()->isVectorTy()) {
1510 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
1511 break;
1512 }
1513
1514 const Value *V;
1515 // Handle bitcast from floating point to integer.
1516 if (match(V: I, P: m_ElementWiseBitCast(Op: m_Value(V))) &&
1517 V->getType()->isFPOrFPVectorTy()) {
1518 Type *FPType = V->getType()->getScalarType();
1519 KnownFPClass Result =
1520 computeKnownFPClass(V, DemandedElts, InterestedClasses: fcAllFlags, SQ: Q, Depth: Depth + 1);
1521 FPClassTest FPClasses = Result.KnownFPClasses;
1522
1523 // TODO: Treat it as zero/poison if the use of I is unreachable.
1524 if (FPClasses == fcNone)
1525 break;
1526
1527 if (Result.isKnownNever(Mask: fcNormal | fcSubnormal | fcNan)) {
1528 Known.setAllConflict();
1529
1530 if (FPClasses & fcInf)
1531 Known = Known.intersectWith(RHS: KnownBits::makeConstant(
1532 C: APFloat::getInf(Sem: FPType->getFltSemantics()).bitcastToAPInt()));
1533
1534 if (FPClasses & fcZero)
1535 Known = Known.intersectWith(RHS: KnownBits::makeConstant(
1536 C: APInt::getZero(numBits: FPType->getScalarSizeInBits())));
1537
1538 Known.Zero.clearSignBit();
1539 Known.One.clearSignBit();
1540 }
1541
1542 if (Result.SignBit) {
1543 if (*Result.SignBit)
1544 Known.makeNegative();
1545 else
1546 Known.makeNonNegative();
1547 }
1548
1549 break;
1550 }
1551
1552 // Handle cast from vector integer type to scalar or vector integer.
1553 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcTy);
1554 if (!SrcVecTy || !SrcVecTy->getElementType()->isIntegerTy() ||
1555 !I->getType()->isIntOrIntVectorTy() ||
1556 isa<ScalableVectorType>(Val: I->getType()))
1557 break;
1558
1559 unsigned NumElts = DemandedElts.getBitWidth();
1560 bool IsLE = Q.DL.isLittleEndian();
1561 // Look through a cast from narrow vector elements to wider type.
1562 // Examples: v4i32 -> v2i64, v3i8 -> i24
1563 unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
1564 if (BitWidth % SubBitWidth == 0) {
1565 // Known bits are automatically intersected across demanded elements of a
1566 // vector. So for example, if a bit is computed as known zero, it must be
1567 // zero across all demanded elements of the vector.
1568 //
1569 // For this bitcast, each demanded element of the output is sub-divided
1570 // across a set of smaller vector elements in the source vector. To get
1571 // the known bits for an entire element of the output, compute the known
1572 // bits for each sub-element sequentially. This is done by shifting the
1573 // one-set-bit demanded elements parameter across the sub-elements for
1574 // consecutive calls to computeKnownBits. We are using the demanded
1575 // elements parameter as a mask operator.
1576 //
1577 // The known bits of each sub-element are then inserted into place
1578 // (dependent on endian) to form the full result of known bits.
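// e.g. for a little-endian bitcast <4 x i8> -> <2 x i16>, output element 0 is
// assembled from source element 0 (its low byte) and source element 1 (its
// high byte).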
1579 unsigned SubScale = BitWidth / SubBitWidth;
1580 APInt SubDemandedElts = APInt::getZero(numBits: NumElts * SubScale);
1581 for (unsigned i = 0; i != NumElts; ++i) {
1582 if (DemandedElts[i])
1583 SubDemandedElts.setBit(i * SubScale);
1584 }
1585
1586 KnownBits KnownSrc(SubBitWidth);
1587 for (unsigned i = 0; i != SubScale; ++i) {
1588 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: SubDemandedElts.shl(shiftAmt: i), Known&: KnownSrc, Q,
1589 Depth: Depth + 1);
1590 unsigned ShiftElt = IsLE ? i : SubScale - 1 - i;
1591 Known.insertBits(SubBits: KnownSrc, BitPosition: ShiftElt * SubBitWidth);
1592 }
1593 }
1594 // Look through a cast from wider vector elements to narrow type.
1595 // Examples: v2i64 -> v4i32
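// e.g. for a little-endian bitcast <2 x i64> -> <4 x i32>, output element 0
// is the low 32 bits of source element 0 and output element 1 is its high 32
// bits.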
1596 if (SubBitWidth % BitWidth == 0) {
1597 unsigned SubScale = SubBitWidth / BitWidth;
1598 KnownBits KnownSrc(SubBitWidth);
1599 APInt SubDemandedElts =
1600 APIntOps::ScaleBitMask(A: DemandedElts, NewBitWidth: NumElts / SubScale);
1601 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: SubDemandedElts, Known&: KnownSrc, Q,
1602 Depth: Depth + 1);
1603
1604 Known.setAllConflict();
1605 for (unsigned i = 0; i != NumElts; ++i) {
1606 if (DemandedElts[i]) {
1607 unsigned Shifts = IsLE ? i : NumElts - 1 - i;
1608 unsigned Offset = (Shifts % SubScale) * BitWidth;
1609 Known = Known.intersectWith(RHS: KnownSrc.extractBits(NumBits: BitWidth, BitPosition: Offset));
1610 if (Known.isUnknown())
1611 break;
1612 }
1613 }
1614 }
1615 break;
1616 }
1617 case Instruction::SExt: {
1618 // Compute the bits in the result that are not present in the input.
1619 unsigned SrcBitWidth = I->getOperand(i: 0)->getType()->getScalarSizeInBits();
1620
1621 Known = Known.trunc(BitWidth: SrcBitWidth);
1622 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1623 // If the sign bit of the input is known set or clear, then we know the
1624 // top bits of the result.
1625 Known = Known.sext(BitWidth);
1626 break;
1627 }
1628 case Instruction::Shl: {
1629 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1630 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1631 auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1632 bool ShAmtNonZero) {
1633 return KnownBits::shl(LHS: KnownVal, RHS: KnownAmt, NUW, NSW, ShAmtNonZero);
1634 };
1635 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1636 KF);
1637 // Trailing zeros of a left-shifted constant never decrease.
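// e.g. for (shl i8 24, %x), 24 = 0b00011000 already has three trailing zeros,
// so the low three bits of the result are known to be zero.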
1638 const APInt *C;
1639 if (match(V: I->getOperand(i: 0), P: m_APInt(Res&: C)))
1640 Known.Zero.setLowBits(C->countr_zero());
1641 break;
1642 }
1643 case Instruction::LShr: {
1644 bool Exact = Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I));
1645 auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1646 bool ShAmtNonZero) {
1647 return KnownBits::lshr(LHS: KnownVal, RHS: KnownAmt, ShAmtNonZero, Exact);
1648 };
1649 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1650 KF);
1651 // Leading zeros of a right-shifted constant never decrease.
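// e.g. for (lshr i8 24, %x), 24 = 0b00011000 has three leading zeros, so the
// top three bits of the result are known to be zero.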
1652 const APInt *C;
1653 if (match(V: I->getOperand(i: 0), P: m_APInt(Res&: C)))
1654 Known.Zero.setHighBits(C->countl_zero());
1655 break;
1656 }
1657 case Instruction::AShr: {
1658 bool Exact = Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I));
1659 auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1660 bool ShAmtNonZero) {
1661 return KnownBits::ashr(LHS: KnownVal, RHS: KnownAmt, ShAmtNonZero, Exact);
1662 };
1663 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1664 KF);
1665 break;
1666 }
1667 case Instruction::Sub: {
1668 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1669 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1670 computeKnownBitsAddSub(Add: false, Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1671 DemandedElts, KnownOut&: Known, Known2, Q, Depth);
1672 break;
1673 }
1674 case Instruction::Add: {
1675 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1676 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1677 computeKnownBitsAddSub(Add: true, Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1678 DemandedElts, KnownOut&: Known, Known2, Q, Depth);
1679 break;
1680 }
1681 case Instruction::SRem:
1682 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1683 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1684 Known = KnownBits::srem(LHS: Known, RHS: Known2);
1685 break;
1686
1687 case Instruction::URem:
1688 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1689 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1690 Known = KnownBits::urem(LHS: Known, RHS: Known2);
1691 break;
1692 case Instruction::Alloca:
1693 Known.Zero.setLowBits(Log2(A: cast<AllocaInst>(Val: I)->getAlign()));
1694 break;
1695 case Instruction::GetElementPtr: {
1696 // Analyze all of the subscripts of this getelementptr instruction
1697 // to determine if we can prove known low zero bits.
1698 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
1699 // Accumulate the constant indices in a separate variable
1700 // to minimize the number of calls to computeForAddSub.
1701 unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Ty: I->getType());
1702 APInt AccConstIndices(IndexWidth, 0);
1703
1704 auto AddIndexToKnown = [&](KnownBits IndexBits) {
1705 if (IndexWidth == BitWidth) {
1706 // Note that inbounds does *not* guarantee nsw for the addition, as only
1707 // the offset is signed, while the base address is unsigned.
1708 Known = KnownBits::add(LHS: Known, RHS: IndexBits);
1709 } else {
1710 // If the index width is smaller than the pointer width, only add the
1711 // value to the low bits.
1712 assert(IndexWidth < BitWidth &&
1713 "Index width can't be larger than pointer width");
1714 Known.insertBits(SubBits: KnownBits::add(LHS: Known.trunc(BitWidth: IndexWidth), RHS: IndexBits), BitPosition: 0);
1715 }
1716 };
1717
1718 gep_type_iterator GTI = gep_type_begin(GEP: I);
1719 for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
1720 // Known bits can only degrade as indices are added; short-circuit once nothing is known.
1721 if (Known.isUnknown())
1722 break;
1723
1724 Value *Index = I->getOperand(i);
1725
1726 // Handle the case when the index is zero.
1727 Constant *CIndex = dyn_cast<Constant>(Val: Index);
1728 if (CIndex && CIndex->isZeroValue())
1729 continue;
1730
1731 if (StructType *STy = GTI.getStructTypeOrNull()) {
1732 // Handle struct member offset arithmetic.
1733
1734 assert(CIndex &&
1735 "Access to structure field must be known at compile time");
1736
1737 if (CIndex->getType()->isVectorTy())
1738 Index = CIndex->getSplatValue();
1739
1740 unsigned Idx = cast<ConstantInt>(Val: Index)->getZExtValue();
1741 const StructLayout *SL = Q.DL.getStructLayout(Ty: STy);
1742 uint64_t Offset = SL->getElementOffset(Idx);
1743 AccConstIndices += Offset;
1744 continue;
1745 }
1746
1747 // Handle array index arithmetic.
1748 Type *IndexedTy = GTI.getIndexedType();
1749 if (!IndexedTy->isSized()) {
1750 Known.resetAll();
1751 break;
1752 }
1753
1754 TypeSize Stride = GTI.getSequentialElementStride(DL: Q.DL);
1755 uint64_t StrideInBytes = Stride.getKnownMinValue();
1756 if (!Stride.isScalable()) {
1757 // Fast path for constant offset.
1758 if (auto *CI = dyn_cast<ConstantInt>(Val: Index)) {
1759 AccConstIndices +=
1760 CI->getValue().sextOrTrunc(width: IndexWidth) * StrideInBytes;
1761 continue;
1762 }
1763 }
1764
1765 KnownBits IndexBits =
1766 computeKnownBits(V: Index, Q, Depth: Depth + 1).sextOrTrunc(BitWidth: IndexWidth);
1767 KnownBits ScalingFactor(IndexWidth);
1768 // Multiply by current sizeof type.
1769 // &A[i] == A + i * sizeof(*A[i]).
1770 if (Stride.isScalable()) {
1771 // For scalable types the only thing we know about sizeof is
1772 // that this is a multiple of the minimum size.
1773 ScalingFactor.Zero.setLowBits(llvm::countr_zero(Val: StrideInBytes));
1774 } else {
1775 ScalingFactor =
1776 KnownBits::makeConstant(C: APInt(IndexWidth, StrideInBytes));
1777 }
1778 AddIndexToKnown(KnownBits::mul(LHS: IndexBits, RHS: ScalingFactor));
1779 }
1780 if (!Known.isUnknown() && !AccConstIndices.isZero())
1781 AddIndexToKnown(KnownBits::makeConstant(C: AccConstIndices));
1782 break;
1783 }
1784 case Instruction::PHI: {
1785 const PHINode *P = cast<PHINode>(Val: I);
1786 BinaryOperator *BO = nullptr;
1787 Value *R = nullptr, *L = nullptr;
1788 if (matchSimpleRecurrence(P, BO, Start&: R, Step&: L)) {
1789 // Handle the case of a simple two-predecessor recurrence PHI.
1790 // There's a lot more that could theoretically be done here, but
1791 // this is sufficient to catch some interesting cases.
1792 unsigned Opcode = BO->getOpcode();
1793
1794 switch (Opcode) {
1795 // If this is a shift recurrence, we know the bits being shifted in. We
1796 // can combine that with information about the start value of the
1797 // recurrence to conclude facts about the result. If this is a udiv
1798 // recurrence, we know that the result can never exceed either the
1799 // numerator or the start value, whichever is greater.
1800 case Instruction::LShr:
1801 case Instruction::AShr:
1802 case Instruction::Shl:
1803 case Instruction::UDiv:
1804 if (BO->getOperand(i_nocapture: 0) != I)
1805 break;
1806 [[fallthrough]];
1807
1808 // For a urem recurrence, the result can never exceed the start value. The
1809 // phi could either be the numerator or the denominator.
1810 case Instruction::URem: {
1811 // We have matched a recurrence of the form:
1812 // %iv = [R, %entry], [%iv.next, %backedge]
1813 // %iv.next = shift_op %iv, L
1814
1815 // Recurse with the phi context to avoid concern about whether facts
1816 // inferred hold at original context instruction. TODO: It may be
1817 // correct to use the original context. If warranted, explore and
1818 // add sufficient tests to cover.
1819 SimplifyQuery RecQ = Q.getWithoutCondContext();
1820 RecQ.CxtI = P;
1821 computeKnownBits(V: R, DemandedElts, Known&: Known2, Q: RecQ, Depth: Depth + 1);
1822 switch (Opcode) {
1823 case Instruction::Shl:
1824 // A shl recurrence will only increase the trailing zeros.
1825 Known.Zero.setLowBits(Known2.countMinTrailingZeros());
1826 break;
1827 case Instruction::LShr:
1828 case Instruction::UDiv:
1829 case Instruction::URem:
1830 // lshr, udiv, and urem recurrences will preserve the leading zeros of
1831 // the start value.
1832 Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1833 break;
1834 case Instruction::AShr:
1835 // An ashr recurrence will extend the initial sign bit
1836 Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1837 Known.One.setHighBits(Known2.countMinLeadingOnes());
1838 break;
1839 }
1840 break;
1841 }
1842
1843 // Check for operations that have the property that if
1844 // both their operands have low zero bits, the result
1845 // will have low zero bits.
1846 case Instruction::Add:
1847 case Instruction::Sub:
1848 case Instruction::And:
1849 case Instruction::Or:
1850 case Instruction::Mul: {
1851 // Change the context instruction to the "edge" that flows into the
1852 // phi. This is important because that is where the value is actually
1853 // "evaluated" even though it is used later somewhere else. (see also
1854 // D69571).
1855 SimplifyQuery RecQ = Q.getWithoutCondContext();
1856
1857 unsigned OpNum = P->getOperand(i_nocapture: 0) == R ? 0 : 1;
1858 Instruction *RInst = P->getIncomingBlock(i: OpNum)->getTerminator();
1859 Instruction *LInst = P->getIncomingBlock(i: 1 - OpNum)->getTerminator();
1860
1861 // Ok, we have a PHI of the form L op= R. Check for low
1862 // zero bits.
1863 RecQ.CxtI = RInst;
1864 computeKnownBits(V: R, DemandedElts, Known&: Known2, Q: RecQ, Depth: Depth + 1);
1865
1866 // We need to take the minimum number of known bits
1867 KnownBits Known3(BitWidth);
1868 RecQ.CxtI = LInst;
1869 computeKnownBits(V: L, DemandedElts, Known&: Known3, Q: RecQ, Depth: Depth + 1);
1870
1871 Known.Zero.setLowBits(std::min(a: Known2.countMinTrailingZeros(),
1872 b: Known3.countMinTrailingZeros()));
1873
1874 auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(Val: BO);
1875 if (!OverflowOp || !Q.IIQ.hasNoSignedWrap(Op: OverflowOp))
1876 break;
1877
1878 switch (Opcode) {
1879 // If initial value of recurrence is nonnegative, and we are adding
1880 // a nonnegative number with nsw, the result can only be nonnegative
1881 // or poison value regardless of the number of times we execute the
1882 // add in phi recurrence. If initial value is negative and we are
1883 // adding a negative number with nsw, the result can only be
1884 // negative or poison value. Similar arguments apply to sub and mul.
1885 //
1886 // (add non-negative, non-negative) --> non-negative
1887 // (add negative, negative) --> negative
1888 case Instruction::Add: {
1889 if (Known2.isNonNegative() && Known3.isNonNegative())
1890 Known.makeNonNegative();
1891 else if (Known2.isNegative() && Known3.isNegative())
1892 Known.makeNegative();
1893 break;
1894 }
1895
1896 // (sub nsw non-negative, negative) --> non-negative
1897 // (sub nsw negative, non-negative) --> negative
1898 case Instruction::Sub: {
1899 if (BO->getOperand(i_nocapture: 0) != I)
1900 break;
1901 if (Known2.isNonNegative() && Known3.isNegative())
1902 Known.makeNonNegative();
1903 else if (Known2.isNegative() && Known3.isNonNegative())
1904 Known.makeNegative();
1905 break;
1906 }
1907
1908 // (mul nsw non-negative, non-negative) --> non-negative
1909 case Instruction::Mul:
1910 if (Known2.isNonNegative() && Known3.isNonNegative())
1911 Known.makeNonNegative();
1912 break;
1913
1914 default:
1915 break;
1916 }
1917 break;
1918 }
1919
1920 default:
1921 break;
1922 }
1923 }
1924
1925 // Unreachable blocks may have zero-operand PHI nodes.
1926 if (P->getNumIncomingValues() == 0)
1927 break;
1928
1929 // Otherwise intersect the known bits of all incoming values,
1930 // taking conservative care to avoid excessive recursion.
1931 if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) {
1932 // Skip if every incoming value references the PHI node itself.
1933 if (isa_and_nonnull<UndefValue>(Val: P->hasConstantValue()))
1934 break;
1935
1936 Known.setAllConflict();
1937 for (const Use &U : P->operands()) {
1938 Value *IncValue;
1939 const PHINode *CxtPhi;
1940 Instruction *CxtI;
1941 breakSelfRecursivePHI(U: &U, PHI: P, ValOut&: IncValue, CtxIOut&: CxtI, PhiOut: &CxtPhi);
1942 // Skip direct self references.
1943 if (IncValue == P)
1944 continue;
1945
1946 // Change the context instruction to the "edge" that flows into the
1947 // phi. This is important because that is where the value is actually
1948 // "evaluated" even though it is used later somewhere else. (see also
1949 // D69571).
1950 SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(I: CxtI);
1951
1952 Known2 = KnownBits(BitWidth);
1953
1954 // Recurse, but cap the recursion to one level, because we don't
1955 // want to waste time spinning around in loops.
1956 // TODO: See if we can base recursion limiter on number of incoming phi
1957 // edges so we don't overly clamp analysis.
1958 computeKnownBits(V: IncValue, DemandedElts, Known&: Known2, Q: RecQ,
1959 Depth: MaxAnalysisRecursionDepth - 1);
1960
1961 // See if we can further use a conditional branch into the phi
1962 // to help us determine the range of the value.
1963 if (!Known2.isConstant()) {
1964 CmpPredicate Pred;
1965 const APInt *RHSC;
1966 BasicBlock *TrueSucc, *FalseSucc;
1967 // TODO: Use RHS Value and compute range from its known bits.
1968 if (match(V: RecQ.CxtI,
1969 P: m_Br(C: m_c_ICmp(Pred, L: m_Specific(V: IncValue), R: m_APInt(Res&: RHSC)),
1970 T: m_BasicBlock(V&: TrueSucc), F: m_BasicBlock(V&: FalseSucc)))) {
1971 // Check for cases of duplicate successors.
1972 if ((TrueSucc == CxtPhi->getParent()) !=
1973 (FalseSucc == CxtPhi->getParent())) {
1974 // If we're using the false successor, invert the predicate.
1975 if (FalseSucc == CxtPhi->getParent())
1976 Pred = CmpInst::getInversePredicate(pred: Pred);
1977 // Get the knownbits implied by the incoming phi condition.
1978 auto CR = ConstantRange::makeExactICmpRegion(Pred, Other: *RHSC);
1979 KnownBits KnownUnion = Known2.unionWith(RHS: CR.toKnownBits());
1980 // We can have conflicts here if we are analyzing dead code (it is
1981 // impossible for us to reach this BB based on the icmp).
1982 if (KnownUnion.hasConflict()) {
1983 // No reason to continue analyzing in a known dead region, so
1984 // just resetAll and break. This will cause us to also exit the
1985 // outer loop.
1986 Known.resetAll();
1987 break;
1988 }
1989 Known2 = KnownUnion;
1990 }
1991 }
1992 }
1993
1994 Known = Known.intersectWith(RHS: Known2);
1995 // If all bits have been ruled out, there's no need to check
1996 // more operands.
1997 if (Known.isUnknown())
1998 break;
1999 }
2000 }
2001 break;
2002 }
2003 case Instruction::Call:
2004 case Instruction::Invoke: {
2005 // If range metadata is attached to this call, set known bits from that,
2006 // and then intersect with known bits based on other properties of the
2007 // function.
2008 if (MDNode *MD =
2009 Q.IIQ.getMetadata(I: cast<Instruction>(Val: I), KindID: LLVMContext::MD_range))
2010 computeKnownBitsFromRangeMetadata(Ranges: *MD, Known);
2011
2012 const auto *CB = cast<CallBase>(Val: I);
2013
2014 if (std::optional<ConstantRange> Range = CB->getRange())
2015 Known = Known.unionWith(RHS: Range->toKnownBits());
2016
2017 if (const Value *RV = CB->getReturnedArgOperand()) {
2018 if (RV->getType() == I->getType()) {
2019 computeKnownBits(V: RV, Known&: Known2, Q, Depth: Depth + 1);
2020 Known = Known.unionWith(RHS: Known2);
2021 // If the function doesn't return properly for all input values
2022 // (e.g. unreachable exits) then there might be conflicts between the
2023 // argument value and the range metadata. Simply discard the known bits
2024 // in case of conflicts.
2025 if (Known.hasConflict())
2026 Known.resetAll();
2027 }
2028 }
2029 if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I)) {
2030 switch (II->getIntrinsicID()) {
2031 default:
2032 break;
2033 case Intrinsic::abs: {
2034 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2035 bool IntMinIsPoison = match(V: II->getArgOperand(i: 1), P: m_One());
2036 Known = Known.unionWith(RHS: Known2.abs(IntMinIsPoison));
2037 break;
2038 }
2039 case Intrinsic::bitreverse:
2040 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2041 Known = Known.unionWith(RHS: Known2.reverseBits());
2042 break;
2043 case Intrinsic::bswap:
2044 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2045 Known = Known.unionWith(RHS: Known2.byteSwap());
2046 break;
2047 case Intrinsic::ctlz: {
2048 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2049 // If we have a known 1, its position is our upper bound.
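// e.g. if bit 28 of an i32 value is known to be one, ctlz returns at most 3,
// which fits in two bits, so all higher result bits are known to be zero.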
2050 unsigned PossibleLZ = Known2.countMaxLeadingZeros();
2051 // If this call is poison for 0 input, the result will be less than 2^n.
2052 if (II->getArgOperand(i: 1) == ConstantInt::getTrue(Context&: II->getContext()))
2053 PossibleLZ = std::min(a: PossibleLZ, b: BitWidth - 1);
2054 unsigned LowBits = llvm::bit_width(Value: PossibleLZ);
2055 Known.Zero.setBitsFrom(LowBits);
2056 break;
2057 }
2058 case Intrinsic::cttz: {
2059 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2060 // If we have a known 1, its position is our upper bound.
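// e.g. if bit 3 of an i32 value is known to be one, cttz returns at most 3,
// which fits in two bits, so all higher result bits are known to be zero.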
2061 unsigned PossibleTZ = Known2.countMaxTrailingZeros();
2062 // If this call is poison for 0 input, the result will be less than 2^n.
2063 if (II->getArgOperand(i: 1) == ConstantInt::getTrue(Context&: II->getContext()))
2064 PossibleTZ = std::min(a: PossibleTZ, b: BitWidth - 1);
2065 unsigned LowBits = llvm::bit_width(Value: PossibleTZ);
2066 Known.Zero.setBitsFrom(LowBits);
2067 break;
2068 }
2069 case Intrinsic::ctpop: {
2070 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2071 // We can bound the space the count needs. Also, bits known to be zero
2072 // can't contribute to the population.
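// e.g. if at most 5 bits of the input can be set, the population count is at
// most 5, which fits in three bits, so all higher result bits are known zero.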
2073 unsigned BitsPossiblySet = Known2.countMaxPopulation();
2074 unsigned LowBits = llvm::bit_width(Value: BitsPossiblySet);
2075 Known.Zero.setBitsFrom(LowBits);
2076 // TODO: we could bound KnownOne using the lower bound on the number
2077 // of bits which might be set provided by popcnt KnownOne2.
2078 break;
2079 }
2080 case Intrinsic::fshr:
2081 case Intrinsic::fshl: {
2082 const APInt *SA;
2083 if (!match(V: I->getOperand(i: 2), P: m_APInt(Res&: SA)))
2084 break;
2085
2086 // Normalize to funnel shift left.
2087 uint64_t ShiftAmt = SA->urem(RHS: BitWidth);
2088 if (II->getIntrinsicID() == Intrinsic::fshr)
2089 ShiftAmt = BitWidth - ShiftAmt;
2090
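// A funnel shift concatenates Op0 (high half) and Op1 (low half); the result
// combines the bits of Op0 shifted left with the bits of Op1 shifted right.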
2091 KnownBits Known3(BitWidth);
2092 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2093 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known3, Q, Depth: Depth + 1);
2094
2095 Known2 <<= ShiftAmt;
2096 Known3 >>= BitWidth - ShiftAmt;
2097 Known = Known2.unionWith(RHS: Known3);
2098 break;
2099 }
2100 case Intrinsic::clmul:
2101 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2102 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2103 Known = KnownBits::clmul(LHS: Known, RHS: Known2);
2104 break;
2105 case Intrinsic::uadd_sat:
2106 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2107 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2108 Known = KnownBits::uadd_sat(LHS: Known, RHS: Known2);
2109 break;
2110 case Intrinsic::usub_sat:
2111 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2112 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2113 Known = KnownBits::usub_sat(LHS: Known, RHS: Known2);
2114 break;
2115 case Intrinsic::sadd_sat:
2116 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2117 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2118 Known = KnownBits::sadd_sat(LHS: Known, RHS: Known2);
2119 break;
2120 case Intrinsic::ssub_sat:
2121 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2122 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2123 Known = KnownBits::ssub_sat(LHS: Known, RHS: Known2);
2124 break;
2125 // Vec reverse preserves bits from input vec.
2126 case Intrinsic::vector_reverse:
2127 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: DemandedElts.reverseBits(), Known, Q,
2128 Depth: Depth + 1);
2129 break;
2130 // For min/max/and/or reductions, any bit common to each element in the
2131 // input vector is known in the output.
2132 case Intrinsic::vector_reduce_and:
2133 case Intrinsic::vector_reduce_or:
2134 case Intrinsic::vector_reduce_umax:
2135 case Intrinsic::vector_reduce_umin:
2136 case Intrinsic::vector_reduce_smax:
2137 case Intrinsic::vector_reduce_smin:
2138 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2139 break;
2140 case Intrinsic::vector_reduce_xor: {
2141 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2142 // The zeros common to all elements are zero in the output.
2143 // If the number of elements is odd, then the common ones remain. If the
2144 // number of elements is even, then the common ones become zeros.
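// e.g. xor-reducing a <2 x i8> vector whose elements both have their low two
// bits known to be one yields a result whose low two bits are known zero.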
2145 auto *VecTy = cast<VectorType>(Val: I->getOperand(i: 0)->getType());
2146 // Even, so the ones become zeros.
2147 bool EvenCnt = VecTy->getElementCount().isKnownEven();
2148 if (EvenCnt)
2149 Known.Zero |= Known.One;
2150 // The element count may be even, so we need to clear the common ones.
2151 if (VecTy->isScalableTy() || EvenCnt)
2152 Known.One.clearAllBits();
2153 break;
2154 }
2155 case Intrinsic::vector_reduce_add: {
2156 auto *VecTy = dyn_cast<FixedVectorType>(Val: I->getOperand(i: 0)->getType());
2157 if (!VecTy)
2158 break;
2159 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2160 Known = Known.reduceAdd(NumElts: VecTy->getNumElements());
2161 break;
2162 }
2163 case Intrinsic::umin:
2164 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2165 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2166 Known = KnownBits::umin(LHS: Known, RHS: Known2);
2167 break;
2168 case Intrinsic::umax:
2169 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2170 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2171 Known = KnownBits::umax(LHS: Known, RHS: Known2);
2172 break;
2173 case Intrinsic::smin:
2174 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2175 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2176 Known = KnownBits::smin(LHS: Known, RHS: Known2);
2177 unionWithMinMaxIntrinsicClamp(II, Known);
2178 break;
2179 case Intrinsic::smax:
2180 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2181 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2182 Known = KnownBits::smax(LHS: Known, RHS: Known2);
2183 unionWithMinMaxIntrinsicClamp(II, Known);
2184 break;
2185 case Intrinsic::ptrmask: {
2186 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2187
2188 const Value *Mask = I->getOperand(i: 1);
2189 Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
2190 computeKnownBits(V: Mask, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2191 // TODO: 1-extend would be more precise.
2192 Known &= Known2.anyextOrTrunc(BitWidth);
2193 break;
2194 }
2195 case Intrinsic::x86_sse2_pmulh_w:
2196 case Intrinsic::x86_avx2_pmulh_w:
2197 case Intrinsic::x86_avx512_pmulh_w_512:
2198 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2199 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2200 Known = KnownBits::mulhs(LHS: Known, RHS: Known2);
2201 break;
2202 case Intrinsic::x86_sse2_pmulhu_w:
2203 case Intrinsic::x86_avx2_pmulhu_w:
2204 case Intrinsic::x86_avx512_pmulhu_w_512:
2205 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2206 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2207 Known = KnownBits::mulhu(LHS: Known, RHS: Known2);
2208 break;
2209 case Intrinsic::x86_sse42_crc32_64_64:
2210 Known.Zero.setBitsFrom(32);
2211 break;
2212 case Intrinsic::x86_ssse3_phadd_d_128:
2213 case Intrinsic::x86_ssse3_phadd_w_128:
2214 case Intrinsic::x86_avx2_phadd_d:
2215 case Intrinsic::x86_avx2_phadd_w: {
2216 Known = computeKnownBitsForHorizontalOperation(
2217 I, DemandedElts, Q, Depth,
2218 KnownBitsFunc: [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
2219 return KnownBits::add(LHS: KnownLHS, RHS: KnownRHS);
2220 });
2221 break;
2222 }
2223 case Intrinsic::x86_ssse3_phadd_sw_128:
2224 case Intrinsic::x86_avx2_phadd_sw: {
2225 Known = computeKnownBitsForHorizontalOperation(
2226 I, DemandedElts, Q, Depth, KnownBitsFunc: KnownBits::sadd_sat);
2227 break;
2228 }
2229 case Intrinsic::x86_ssse3_phsub_d_128:
2230 case Intrinsic::x86_ssse3_phsub_w_128:
2231 case Intrinsic::x86_avx2_phsub_d:
2232 case Intrinsic::x86_avx2_phsub_w: {
2233 Known = computeKnownBitsForHorizontalOperation(
2234 I, DemandedElts, Q, Depth,
2235 KnownBitsFunc: [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
2236 return KnownBits::sub(LHS: KnownLHS, RHS: KnownRHS);
2237 });
2238 break;
2239 }
2240 case Intrinsic::x86_ssse3_phsub_sw_128:
2241 case Intrinsic::x86_avx2_phsub_sw: {
2242 Known = computeKnownBitsForHorizontalOperation(
2243 I, DemandedElts, Q, Depth, KnownBitsFunc: KnownBits::ssub_sat);
2244 break;
2245 }
2246 case Intrinsic::riscv_vsetvli:
2247 case Intrinsic::riscv_vsetvlimax: {
2248 bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;
2249 const ConstantRange Range = getVScaleRange(F: II->getFunction(), BitWidth);
2250 uint64_t SEW = RISCVVType::decodeVSEW(
2251 VSEW: cast<ConstantInt>(Val: II->getArgOperand(i: HasAVL))->getZExtValue());
2252 RISCVVType::VLMUL VLMUL = static_cast<RISCVVType::VLMUL>(
2253 cast<ConstantInt>(Val: II->getArgOperand(i: 1 + HasAVL))->getZExtValue());
2254 uint64_t MaxVLEN =
2255 Range.getUnsignedMax().getZExtValue() * RISCV::RVVBitsPerBlock;
2256 uint64_t MaxVL = MaxVLEN / RISCVVType::getSEWLMULRatio(SEW, VLMul: VLMUL);
2257
2258 // The result of vsetvli must not be larger than AVL.
2259 if (HasAVL)
2260 if (auto *CI = dyn_cast<ConstantInt>(Val: II->getArgOperand(i: 0)))
2261 MaxVL = std::min(a: MaxVL, b: CI->getZExtValue());
2262
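// e.g. if MaxVL is 64, the result needs at most 7 bits, so bits 7 and above
// are known to be zero.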
2263 unsigned KnownZeroFirstBit = Log2_32(Value: MaxVL) + 1;
2264 if (BitWidth > KnownZeroFirstBit)
2265 Known.Zero.setBitsFrom(KnownZeroFirstBit);
2266 break;
2267 }
2268 case Intrinsic::vscale: {
2269 if (!II->getParent() || !II->getFunction())
2270 break;
2271
2272 Known = getVScaleRange(F: II->getFunction(), BitWidth).toKnownBits();
2273 break;
2274 }
2275 }
2276 }
2277 break;
2278 }
2279 case Instruction::ShuffleVector: {
2280 if (auto *Splat = getSplatValue(V: I)) {
2281 computeKnownBits(V: Splat, Known, Q, Depth: Depth + 1);
2282 break;
2283 }
2284
2285 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: I);
2286 // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
2287 if (!Shuf) {
2288 Known.resetAll();
2289 return;
2290 }
2291 // For undef elements, we don't know anything about the common state of
2292 // the shuffle result.
2293 APInt DemandedLHS, DemandedRHS;
2294 if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
2295 Known.resetAll();
2296 return;
2297 }
2298 Known.setAllConflict();
2299 if (!!DemandedLHS) {
2300 const Value *LHS = Shuf->getOperand(i_nocapture: 0);
2301 computeKnownBits(V: LHS, DemandedElts: DemandedLHS, Known, Q, Depth: Depth + 1);
2302 // If we don't know any bits, early out.
2303 if (Known.isUnknown())
2304 break;
2305 }
2306 if (!!DemandedRHS) {
2307 const Value *RHS = Shuf->getOperand(i_nocapture: 1);
2308 computeKnownBits(V: RHS, DemandedElts: DemandedRHS, Known&: Known2, Q, Depth: Depth + 1);
2309 Known = Known.intersectWith(RHS: Known2);
2310 }
2311 break;
2312 }
2313 case Instruction::InsertElement: {
2314 if (isa<ScalableVectorType>(Val: I->getType())) {
2315 Known.resetAll();
2316 return;
2317 }
2318 const Value *Vec = I->getOperand(i: 0);
2319 const Value *Elt = I->getOperand(i: 1);
2320 auto *CIdx = dyn_cast<ConstantInt>(Val: I->getOperand(i: 2));
2321 unsigned NumElts = DemandedElts.getBitWidth();
2322 APInt DemandedVecElts = DemandedElts;
2323 bool NeedsElt = true;
2324 // If we know the index we are inserting to, clear that bit from the Vec check.
2325 if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
2326 DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
2327 NeedsElt = DemandedElts[CIdx->getZExtValue()];
2328 }
2329
2330 Known.setAllConflict();
2331 if (NeedsElt) {
2332 computeKnownBits(V: Elt, Known, Q, Depth: Depth + 1);
2333 // If we don't know any bits, early out.
2334 if (Known.isUnknown())
2335 break;
2336 }
2337
2338 if (!DemandedVecElts.isZero()) {
2339 computeKnownBits(V: Vec, DemandedElts: DemandedVecElts, Known&: Known2, Q, Depth: Depth + 1);
2340 Known = Known.intersectWith(RHS: Known2);
2341 }
2342 break;
2343 }
2344 case Instruction::ExtractElement: {
2345 // Look through extract element. If the index is non-constant or
2346 // out-of-range, demand all elements; otherwise demand just the extracted element.
2347 const Value *Vec = I->getOperand(i: 0);
2348 const Value *Idx = I->getOperand(i: 1);
2349 auto *CIdx = dyn_cast<ConstantInt>(Val: Idx);
2350 if (isa<ScalableVectorType>(Val: Vec->getType())) {
2351 // FIXME: there's probably *something* we can do with scalable vectors
2352 Known.resetAll();
2353 break;
2354 }
2355 unsigned NumElts = cast<FixedVectorType>(Val: Vec->getType())->getNumElements();
2356 APInt DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
2357 if (CIdx && CIdx->getValue().ult(RHS: NumElts))
2358 DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
2359 computeKnownBits(V: Vec, DemandedElts: DemandedVecElts, Known, Q, Depth: Depth + 1);
2360 break;
2361 }
2362 case Instruction::ExtractValue:
2363 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I->getOperand(i: 0))) {
2364 const ExtractValueInst *EVI = cast<ExtractValueInst>(Val: I);
2365 if (EVI->getNumIndices() != 1) break;
2366 if (EVI->getIndices()[0] == 0) {
2367 switch (II->getIntrinsicID()) {
2368 default: break;
2369 case Intrinsic::uadd_with_overflow:
2370 case Intrinsic::sadd_with_overflow:
2371 computeKnownBitsAddSub(
2372 Add: true, Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), /*NSW=*/false,
2373 /* NUW=*/false, DemandedElts, KnownOut&: Known, Known2, Q, Depth);
2374 break;
2375 case Intrinsic::usub_with_overflow:
2376 case Intrinsic::ssub_with_overflow:
2377 computeKnownBitsAddSub(
2378 Add: false, Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), /*NSW=*/false,
2379 /* NUW=*/false, DemandedElts, KnownOut&: Known, Known2, Q, Depth);
2380 break;
2381 case Intrinsic::umul_with_overflow:
2382 case Intrinsic::smul_with_overflow:
2383 computeKnownBitsMul(Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), NSW: false,
2384 NUW: false, DemandedElts, Known, Known2, Q, Depth);
2385 break;
2386 }
2387 }
2388 }
2389 break;
2390 case Instruction::Freeze:
2391 if (isGuaranteedNotToBePoison(V: I->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
2392 Depth: Depth + 1))
2393 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2394 break;
2395 }
2396}
2397
2398/// Determine which bits of V are known to be either zero or one and return
2399/// them.
2400KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
2401 const SimplifyQuery &Q, unsigned Depth) {
2402 KnownBits Known(getBitWidth(Ty: V->getType(), DL: Q.DL));
2403 ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
2404 return Known;
2405}
2406
2407/// Determine which bits of V are known to be either zero or one and return
2408/// them.
2409KnownBits llvm::computeKnownBits(const Value *V, const SimplifyQuery &Q,
2410 unsigned Depth) {
2411 KnownBits Known(getBitWidth(Ty: V->getType(), DL: Q.DL));
2412 computeKnownBits(V, Known, Q, Depth);
2413 return Known;
2414}
2415
2416/// Determine which bits of V are known to be either zero or one and return
2417/// them in the Known bit set.
2418///
2419/// NOTE: we cannot consider 'undef' to be "IsZero" here. The problem is that
2420/// we cannot optimize based on the assumption that it is zero without changing
2421/// it to be an explicit zero. If we don't change it to zero, other code could
2422 /// be optimized based on the contradictory assumption that it is non-zero.
2423/// Because instcombine aggressively folds operations with undef args anyway,
2424/// this won't lose us code quality.
2425///
2426/// This function is defined on values with integer type, values with pointer
2427/// type, and vectors of integers. In the case
2428 /// where V is a vector, the known zero and known one values are the
2429/// same width as the vector element, and the bit is set only if it is true
2430/// for all of the demanded elements in the vector specified by DemandedElts.
2431void computeKnownBits(const Value *V, const APInt &DemandedElts,
2432 KnownBits &Known, const SimplifyQuery &Q,
2433 unsigned Depth) {
2434 if (!DemandedElts) {
2435 // No demanded elts, better to assume we don't know anything.
2436 Known.resetAll();
2437 return;
2438 }
2439
2440 assert(V && "No Value?");
2441 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
2442
2443#ifndef NDEBUG
2444 Type *Ty = V->getType();
2445 unsigned BitWidth = Known.getBitWidth();
2446
2447 assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
2448 "Not integer or pointer type!");
2449
2450 if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
2451 assert(
2452 FVTy->getNumElements() == DemandedElts.getBitWidth() &&
2453 "DemandedElt width should equal the fixed vector number of elements");
2454 } else {
2455 assert(DemandedElts == APInt(1, 1) &&
2456 "DemandedElt width should be 1 for scalars or scalable vectors");
2457 }
2458
2459 Type *ScalarTy = Ty->getScalarType();
2460 if (ScalarTy->isPointerTy()) {
2461 assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
2462 "V and Known should have same BitWidth");
2463 } else {
2464 assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
2465 "V and Known should have same BitWidth");
2466 }
2467#endif
2468
2469 const APInt *C;
2470 if (match(V, P: m_APInt(Res&: C))) {
2471 // We know all of the bits for a scalar constant or a splat vector constant!
2472 Known = KnownBits::makeConstant(C: *C);
2473 return;
2474 }
2475 // Null and aggregate-zero are all-zeros.
2476 if (isa<ConstantPointerNull>(Val: V) || isa<ConstantAggregateZero>(Val: V)) {
2477 Known.setAllZero();
2478 return;
2479 }
2480 // Handle a constant vector by taking the intersection of the known bits of
2481 // each element.
2482 if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(Val: V)) {
2483 assert(!isa<ScalableVectorType>(V->getType()));
2484 // We know that CDV must be a vector of integers. Take the intersection of
2485 // each element.
2486 Known.setAllConflict();
2487 for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
2488 if (!DemandedElts[i])
2489 continue;
2490 APInt Elt = CDV->getElementAsAPInt(i);
2491 Known.Zero &= ~Elt;
2492 Known.One &= Elt;
2493 }
2494 if (Known.hasConflict())
2495 Known.resetAll();
2496 return;
2497 }
2498
2499 if (const auto *CV = dyn_cast<ConstantVector>(Val: V)) {
2500 assert(!isa<ScalableVectorType>(V->getType()));
2501 // We know that CV must be a vector of integers. Take the intersection of
2502 // each element.
2503 Known.setAllConflict();
2504 for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
2505 if (!DemandedElts[i])
2506 continue;
2507 Constant *Element = CV->getAggregateElement(Elt: i);
2508 if (isa<PoisonValue>(Val: Element))
2509 continue;
2510 auto *ElementCI = dyn_cast_or_null<ConstantInt>(Val: Element);
2511 if (!ElementCI) {
2512 Known.resetAll();
2513 return;
2514 }
2515 const APInt &Elt = ElementCI->getValue();
2516 Known.Zero &= ~Elt;
2517 Known.One &= Elt;
2518 }
2519 if (Known.hasConflict())
2520 Known.resetAll();
2521 return;
2522 }
2523
2524 // Start out not knowing anything.
2525 Known.resetAll();
2526
2527 // We can't imply anything about undefs.
2528 if (isa<UndefValue>(Val: V))
2529 return;
2530
2531 // There's no point in looking through other users of ConstantData for
2532 // assumptions. Confirm that we've handled them all.
2533 assert(!isa<ConstantData>(V) && "Unhandled constant data!");
2534
2535 if (const auto *A = dyn_cast<Argument>(Val: V))
2536 if (std::optional<ConstantRange> Range = A->getRange())
2537 Known = Range->toKnownBits();
2538
2539 // All recursive calls that increase depth must come after this.
2540 if (Depth == MaxAnalysisRecursionDepth)
2541 return;
2542
2543 // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
2544 // the bits of its aliasee.
2545 if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(Val: V)) {
2546 if (!GA->isInterposable())
2547 computeKnownBits(V: GA->getAliasee(), Known, Q, Depth: Depth + 1);
2548 return;
2549 }
2550
2551 if (const Operator *I = dyn_cast<Operator>(Val: V))
2552 computeKnownBitsFromOperator(I, DemandedElts, Known, Q, Depth);
2553 else if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: V)) {
2554 if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange())
2555 Known = CR->toKnownBits();
2556 }
2557
2558 // Aligned pointers have trailing zeros - refine Known.Zero set
2559 if (isa<PointerType>(Val: V->getType())) {
2560 Align Alignment = V->getPointerAlignment(DL: Q.DL);
2561 Known.Zero.setLowBits(Log2(A: Alignment));
2562 }
2563
2564 // computeKnownBitsFromContext strictly refines Known.
2565 // Therefore, we run them after computeKnownBitsFromOperator.
2566
2567 // Check whether we can determine known bits from context such as assumes.
2568 computeKnownBitsFromContext(V, Known, Q, Depth);
2569}
2570
2571 /// Try to detect a recurrence in which the value of the induction variable is
2572/// always a power of two (or zero).
2573static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
2574 SimplifyQuery &Q, unsigned Depth) {
2575 BinaryOperator *BO = nullptr;
2576 Value *Start = nullptr, *Step = nullptr;
2577 if (!matchSimpleRecurrence(P: PN, BO, Start, Step))
2578 return false;
2579
2580 // Initial value must be a power of two.
2581 for (const Use &U : PN->operands()) {
2582 if (U.get() == Start) {
2583 // Initial value comes from a different BB, need to adjust context
2584 // instruction for analysis.
2585 Q.CxtI = PN->getIncomingBlock(U)->getTerminator();
2586 if (!isKnownToBeAPowerOfTwo(V: Start, OrZero, Q, Depth))
2587 return false;
2588 }
2589 }
2590
2591 // Except for Mul, the induction variable must be on the left side of the
2592 // increment expression, otherwise its value can be arbitrary.
2593 if (BO->getOpcode() != Instruction::Mul && BO->getOperand(i_nocapture: 1) != Step)
2594 return false;
2595
2596 Q.CxtI = BO->getParent()->getTerminator();
2597 switch (BO->getOpcode()) {
2598 case Instruction::Mul:
2599 // Power of two is closed under multiplication.
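// e.g. %iv.next = mul nuw i32 %iv, %step remains a power of two as long as
// %step is also a power of two.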
2600 return (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: BO) ||
2601 Q.IIQ.hasNoSignedWrap(Op: BO)) &&
2602 isKnownToBeAPowerOfTwo(V: Step, OrZero, Q, Depth);
2603 case Instruction::SDiv:
2604 // Start value must not be signmask for signed division, so simply being a
2605 // power of two is not sufficient, and it has to be a constant.
2606 if (!match(V: Start, P: m_Power2()) || match(V: Start, P: m_SignMask()))
2607 return false;
2608 [[fallthrough]];
2609 case Instruction::UDiv:
2610 // Divisor must be a power of two.
2611 // If OrZero is false, cannot guarantee induction variable is non-zero after
2612 // division, same for Shr, unless it is exact division.
2613 return (OrZero || Q.IIQ.isExact(Op: BO)) &&
2614 isKnownToBeAPowerOfTwo(V: Step, OrZero: false, Q, Depth);
2615 case Instruction::Shl:
2616 return OrZero || Q.IIQ.hasNoUnsignedWrap(Op: BO) || Q.IIQ.hasNoSignedWrap(Op: BO);
2617 case Instruction::AShr:
2618 if (!match(V: Start, P: m_Power2()) || match(V: Start, P: m_SignMask()))
2619 return false;
2620 [[fallthrough]];
2621 case Instruction::LShr:
2622 return OrZero || Q.IIQ.isExact(Op: BO);
2623 default:
2624 return false;
2625 }
2626}
2627
2628/// Return true if we can infer that \p V is known to be a power of 2 from
2629/// dominating condition \p Cond (e.g., ctpop(V) == 1).
2630static bool isImpliedToBeAPowerOfTwoFromCond(const Value *V, bool OrZero,
2631 const Value *Cond,
2632 bool CondIsTrue) {
2633 CmpPredicate Pred;
2634 const APInt *RHSC;
2635 if (!match(V: Cond, P: m_ICmp(Pred, L: m_Intrinsic<Intrinsic::ctpop>(Op0: m_Specific(V)),
2636 R: m_APInt(Res&: RHSC))))
2637 return false;
2638 if (!CondIsTrue)
2639 Pred = ICmpInst::getInversePredicate(pred: Pred);
2640 // ctpop(V) u< 2
2641 if (OrZero && Pred == ICmpInst::ICMP_ULT && *RHSC == 2)
2642 return true;
2643 // ctpop(V) == 1
2644 return Pred == ICmpInst::ICMP_EQ && *RHSC == 1;
2645}
2646
2647/// Return true if the given value is known to have exactly one
2648/// bit set when defined. For vectors return true if every element is known to
2649/// be a power of two when defined. Supports values with integer or pointer
2650/// types and vectors of integers.
2651bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero,
2652 const SimplifyQuery &Q, unsigned Depth) {
2653 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
2654
2655 if (isa<Constant>(Val: V))
2656 return OrZero ? match(V, P: m_Power2OrZero()) : match(V, P: m_Power2());
2657
2658 // i1 is by definition a power of 2 or zero.
2659 if (OrZero && V->getType()->getScalarSizeInBits() == 1)
2660 return true;
2661
2662 // Try to infer from assumptions.
2663 if (Q.AC && Q.CxtI) {
2664 for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
2665 if (!AssumeVH)
2666 continue;
2667 CallInst *I = cast<CallInst>(Val&: AssumeVH);
2668 if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond: I->getArgOperand(i: 0),
2669 /*CondIsTrue=*/true) &&
2670 isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
2671 return true;
2672 }
2673 }
2674
2675 // Handle dominating conditions.
2676 if (Q.DC && Q.CxtI && Q.DT) {
2677 for (BranchInst *BI : Q.DC->conditionsFor(V)) {
2678 Value *Cond = BI->getCondition();
2679
2680 BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
2681 if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
2682 /*CondIsTrue=*/true) &&
2683 Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
2684 return true;
2685
2686 BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
2687 if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
2688 /*CondIsTrue=*/false) &&
2689 Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
2690 return true;
2691 }
2692 }
2693
2694 auto *I = dyn_cast<Instruction>(Val: V);
2695 if (!I)
2696 return false;
2697
2698 if (Q.CxtI && match(V, P: m_VScale())) {
2699 const Function *F = Q.CxtI->getFunction();
2700 // The vscale_range indicates vscale is a power-of-two.
2701 return F->hasFnAttribute(Kind: Attribute::VScaleRange);
2702 }
2703
2704 // 1 << X is clearly a power of two if the one is not shifted off the end. If
2705 // it is shifted off the end then the result is undefined.
2706 if (match(V: I, P: m_Shl(L: m_One(), R: m_Value())))
2707 return true;
2708
2709 // (signmask) >>l X is clearly a power of two if the one is not shifted off
2710 // the bottom. If it is shifted off the bottom then the result is undefined.
2711 if (match(V: I, P: m_LShr(L: m_SignMask(), R: m_Value())))
2712 return true;
2713
2714 // The remaining tests are all recursive, so bail out if we hit the limit.
2715 if (Depth++ == MaxAnalysisRecursionDepth)
2716 return false;
2717
2718 switch (I->getOpcode()) {
2719 case Instruction::ZExt:
2720 return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
2721 case Instruction::Trunc:
2722 return OrZero && isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
2723 case Instruction::Shl:
2724 if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: I) || Q.IIQ.hasNoSignedWrap(Op: I))
2725 return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
2726 return false;
2727 case Instruction::LShr:
2728 if (OrZero || Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)))
2729 return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
2730 return false;
2731 case Instruction::UDiv:
2732 if (Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)))
2733 return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
2734 return false;
2735 case Instruction::Mul:
2736 return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth) &&
2737 isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth) &&
2738 (OrZero || isKnownNonZero(V: I, Q, Depth));
2739 case Instruction::And:
2740 // A power of two and'd with anything is a power of two or zero.
2741 if (OrZero &&
2742 (isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), /*OrZero*/ true, Q, Depth) ||
2743 isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), /*OrZero*/ true, Q, Depth)))
2744 return true;
2745 // X & (-X) is always a power of two or zero.
2746 if (match(V: I->getOperand(i: 0), P: m_Neg(V: m_Specific(V: I->getOperand(i: 1)))) ||
2747 match(V: I->getOperand(i: 1), P: m_Neg(V: m_Specific(V: I->getOperand(i: 0)))))
2748 return OrZero || isKnownNonZero(V: I->getOperand(i: 0), Q, Depth);
2749 return false;
2750 case Instruction::Add: {
2751 // Adding a power-of-two or zero to the same power-of-two or zero yields
2752 // either the original power-of-two, a larger power-of-two or zero.
2753 const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(Val: V);
2754 if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: VOBO) ||
2755 Q.IIQ.hasNoSignedWrap(Op: VOBO)) {
2756 if (match(V: I->getOperand(i: 0),
2757 P: m_c_And(L: m_Specific(V: I->getOperand(i: 1)), R: m_Value())) &&
2758 isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth))
2759 return true;
2760 if (match(V: I->getOperand(i: 1),
2761 P: m_c_And(L: m_Specific(V: I->getOperand(i: 0)), R: m_Value())) &&
2762 isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth))
2763 return true;
2764
2765 unsigned BitWidth = V->getType()->getScalarSizeInBits();
2766 KnownBits LHSBits(BitWidth);
2767 computeKnownBits(V: I->getOperand(i: 0), Known&: LHSBits, Q, Depth);
2768
2769 KnownBits RHSBits(BitWidth);
2770 computeKnownBits(V: I->getOperand(i: 1), Known&: RHSBits, Q, Depth);
2771 // If i8 V is a power of two or zero:
2772 // ZeroBits: 1 1 1 0 1 1 1 1
2773 // ~ZeroBits: 0 0 0 1 0 0 0 0
2774 if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2())
2775 // If OrZero isn't set, we cannot give back a zero result.
2776 // Make sure either the LHS or RHS has a bit set.
2777 if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue())
2778 return true;
2779 }
2780
2781 // LShr(UINT_MAX, Y) + 1 is a power of two (if add is nuw) or zero.
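// e.g. for i8: (lshr 255, 3) = 31, and 31 + 1 = 32 is a power of two.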
2782 if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: VOBO))
2783 if (match(V: I, P: m_Add(L: m_LShr(L: m_AllOnes(), R: m_Value()), R: m_One())))
2784 return true;
2785 return false;
2786 }
2787 case Instruction::Select:
2788 return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth) &&
2789 isKnownToBeAPowerOfTwo(V: I->getOperand(i: 2), OrZero, Q, Depth);
2790 case Instruction::PHI: {
2791 // A PHI node is power of two if all incoming values are power of two, or if
2792 // it is an induction variable where in each step its value is a power of
2793 // two.
2794 auto *PN = cast<PHINode>(Val: I);
2795 SimplifyQuery RecQ = Q.getWithoutCondContext();
2796
2797 // Check if it is an induction variable and always power of two.
2798 if (isPowerOfTwoRecurrence(PN, OrZero, Q&: RecQ, Depth))
2799 return true;
2800
2801 // Recursively check all incoming values. Limit recursion to 2 levels, so
2802 // that search complexity is limited to number of operands^2.
2803 unsigned NewDepth = std::max(a: Depth, b: MaxAnalysisRecursionDepth - 1);
2804 return llvm::all_of(Range: PN->operands(), P: [&](const Use &U) {
2805 // Value is power of 2 if it is coming from PHI node itself by induction.
2806 if (U.get() == PN)
2807 return true;
2808
2809 // Change the context instruction to the incoming block where it is
2810 // evaluated.
2811 RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
2812 return isKnownToBeAPowerOfTwo(V: U.get(), OrZero, Q: RecQ, Depth: NewDepth);
2813 });
2814 }
2815 case Instruction::Invoke:
2816 case Instruction::Call: {
2817 if (auto *II = dyn_cast<IntrinsicInst>(Val: I)) {
2818 switch (II->getIntrinsicID()) {
2819 case Intrinsic::umax:
2820 case Intrinsic::smax:
2821 case Intrinsic::umin:
2822 case Intrinsic::smin:
2823 return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 1), OrZero, Q, Depth) &&
2824 isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
2825 // bswap/bitreverse just move bits around, but don't change any 1s/0s,
2826 // thus they don't change pow2/non-pow2 status.
2827 case Intrinsic::bitreverse:
2828 case Intrinsic::bswap:
2829 return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
2830 case Intrinsic::fshr:
2831 case Intrinsic::fshl:
2832 // If Op0 == Op1, this is a rotate. is_pow2(rotate(x, y)) == is_pow2(x)
2833 if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1))
2834 return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
2835 break;
2836 default:
2837 break;
2838 }
2839 }
2840 return false;
2841 }
2842 default:
2843 return false;
2844 }
2845}
2846
2847/// Test whether a GEP's result is known to be non-null.
2848///
2849/// Uses properties inherent in a GEP to try to determine whether it is known
2850/// to be non-null.
2851///
2852/// Currently this routine does not support vector GEPs.
2853static bool isGEPKnownNonNull(const GEPOperator *GEP, const SimplifyQuery &Q,
2854 unsigned Depth) {
2855 const Function *F = nullptr;
2856 if (const Instruction *I = dyn_cast<Instruction>(Val: GEP))
2857 F = I->getFunction();
2858
2859 // If the GEP is nuw, or is inbounds with null being an invalid address, then
2860 // the GEP may be null iff the base pointer is null and the offset is zero.
2861 if (!GEP->hasNoUnsignedWrap() &&
2862 !(GEP->isInBounds() &&
2863 !NullPointerIsDefined(F, AS: GEP->getPointerAddressSpace())))
2864 return false;
2865
2866 // FIXME: Support vector-GEPs.
2867 assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP");
2868
2869 // If the base pointer is non-null, we cannot walk to a null address with an
2870 // inbounds GEP in address space zero.
2871 if (isKnownNonZero(V: GEP->getPointerOperand(), Q, Depth))
2872 return true;
2873
2874 // Walk the GEP operands and see if any operand introduces a non-zero offset.
2875 // If so, then the GEP cannot produce a null pointer, as doing so would
2876 // inherently violate the inbounds contract within address space zero.
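  // For example, once such a GEP indexes into the second field of { i32, i32 },
  // the constant field offset of 4 alone already rules out a null result.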
2877 for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
2878 GTI != GTE; ++GTI) {
2879 // Struct types are easy -- they must always be indexed by a constant.
2880 if (StructType *STy = GTI.getStructTypeOrNull()) {
2881 ConstantInt *OpC = cast<ConstantInt>(Val: GTI.getOperand());
2882 unsigned ElementIdx = OpC->getZExtValue();
2883 const StructLayout *SL = Q.DL.getStructLayout(Ty: STy);
2884 uint64_t ElementOffset = SL->getElementOffset(Idx: ElementIdx);
2885 if (ElementOffset > 0)
2886 return true;
2887 continue;
2888 }
2889
2890 // If we have a zero-sized type, the index doesn't matter. Keep looping.
2891 if (GTI.getSequentialElementStride(DL: Q.DL).isZero())
2892 continue;
2893
2894 // Fast path the constant operand case both for efficiency and so we don't
2895 // increment Depth when just zipping down an all-constant GEP.
2896 if (ConstantInt *OpC = dyn_cast<ConstantInt>(Val: GTI.getOperand())) {
2897 if (!OpC->isZero())
2898 return true;
2899 continue;
2900 }
2901
2902 // We post-increment Depth here because while isKnownNonZero increments it
2903 // as well, when we pop back up that increment won't persist. We don't want
2904 // to recurse 10k times just because we have 10k GEP operands. We don't
2905 // bail completely out because we want to handle constant GEPs regardless
2906 // of depth.
2907 if (Depth++ >= MaxAnalysisRecursionDepth)
2908 continue;
2909
2910 if (isKnownNonZero(V: GTI.getOperand(), Q, Depth))
2911 return true;
2912 }
2913
2914 return false;
2915}
2916
2917static bool isKnownNonNullFromDominatingCondition(const Value *V,
2918 const Instruction *CtxI,
2919 const DominatorTree *DT) {
2920 assert(!isa<Constant>(V) && "Called for constant?");
2921
2922 if (!CtxI || !DT)
2923 return false;
2924
2925 unsigned NumUsesExplored = 0;
2926 for (auto &U : V->uses()) {
2927 // Avoid massive lists
2928 if (NumUsesExplored >= DomConditionsMaxUses)
2929 break;
2930 NumUsesExplored++;
2931
2932 const Instruction *UI = cast<Instruction>(Val: U.getUser());
2933 // If the value is used as an argument to a call or invoke, then argument
2934 // attributes may provide an answer about null-ness.
2935 if (V->getType()->isPointerTy()) {
2936 if (const auto *CB = dyn_cast<CallBase>(Val: UI)) {
2937 if (CB->isArgOperand(U: &U) &&
2938 CB->paramHasNonNullAttr(ArgNo: CB->getArgOperandNo(U: &U),
2939 /*AllowUndefOrPoison=*/false) &&
2940 DT->dominates(Def: CB, User: CtxI))
2941 return true;
2942 }
2943 }
2944
2945 // If the value is used as a load/store pointer operand, it must be non-null.
2946 if (V == getLoadStorePointerOperand(V: UI)) {
2947 if (!NullPointerIsDefined(F: UI->getFunction(),
2948 AS: V->getType()->getPointerAddressSpace()) &&
2949 DT->dominates(Def: UI, User: CtxI))
2950 return true;
2951 }
2952
2953 if ((match(V: UI, P: m_IDiv(L: m_Value(), R: m_Specific(V))) ||
2954 match(V: UI, P: m_IRem(L: m_Value(), R: m_Specific(V)))) &&
2955 isValidAssumeForContext(Inv: UI, CxtI: CtxI, DT))
2956 return true;
2957
2958 // Consider only compare instructions uniquely controlling a branch
2959 Value *RHS;
2960 CmpPredicate Pred;
2961 if (!match(V: UI, P: m_c_ICmp(Pred, L: m_Specific(V), R: m_Value(V&: RHS))))
2962 continue;
2963
2964 bool NonNullIfTrue;
2965 if (cmpExcludesZero(Pred, RHS))
2966 NonNullIfTrue = true;
2967 else if (cmpExcludesZero(Pred: CmpInst::getInversePredicate(pred: Pred), RHS))
2968 NonNullIfTrue = false;
2969 else
2970 continue;
2971
2972 SmallVector<const User *, 4> WorkList;
2973 SmallPtrSet<const User *, 4> Visited;
2974 for (const auto *CmpU : UI->users()) {
2975 assert(WorkList.empty() && "Should be!");
2976 if (Visited.insert(Ptr: CmpU).second)
2977 WorkList.push_back(Elt: CmpU);
2978
2979 while (!WorkList.empty()) {
2980 auto *Curr = WorkList.pop_back_val();
2981
2982 // If a user is an AND, add all its users to the work list. We only
2983 // propagate "pred != null" condition through AND because it is only
2984 // correct to assume that all conditions of AND are met in true branch.
2985 // TODO: Support similar logic of OR and EQ predicate?
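// e.g. for `br i1 (and (icmp ne ptr %p, null), %c), %taken, %other` (names
// illustrative), both conjuncts hold in %taken, so %p is known non-null in
// blocks dominated by that edge.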
2986 if (NonNullIfTrue)
2987 if (match(V: Curr, P: m_LogicalAnd(L: m_Value(), R: m_Value()))) {
2988 for (const auto *CurrU : Curr->users())
2989 if (Visited.insert(Ptr: CurrU).second)
2990 WorkList.push_back(Elt: CurrU);
2991 continue;
2992 }
2993
2994 if (const BranchInst *BI = dyn_cast<BranchInst>(Val: Curr)) {
2995 assert(BI->isConditional() && "uses a comparison!");
2996
2997 BasicBlock *NonNullSuccessor =
2998 BI->getSuccessor(i: NonNullIfTrue ? 0 : 1);
2999 BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor);
3000 if (Edge.isSingleEdge() && DT->dominates(BBE: Edge, BB: CtxI->getParent()))
3001 return true;
3002 } else if (NonNullIfTrue && isGuard(U: Curr) &&
3003 DT->dominates(Def: cast<Instruction>(Val: Curr), User: CtxI)) {
3004 return true;
3005 }
3006 }
3007 }
3008 }
3009
3010 return false;
3011}
3012
3013/// Does the 'Range' metadata (which must be a valid MD_range operand list)
3014/// ensure that the value it's attached to is never equal to 'Value'? The
3015/// ranges are given as pairs of (inclusive lower, exclusive upper) bounds.
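/// For example, !range !{i32 1, i32 256} describes the half-open interval
/// [1, 256), which excludes 0, so this returns true for Value == 0.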
3016static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) {
3017 const unsigned NumRanges = Ranges->getNumOperands() / 2;
3018 assert(NumRanges >= 1);
3019 for (unsigned i = 0; i < NumRanges; ++i) {
3020 ConstantInt *Lower =
3021 mdconst::extract<ConstantInt>(MD: Ranges->getOperand(I: 2 * i + 0));
3022 ConstantInt *Upper =
3023 mdconst::extract<ConstantInt>(MD: Ranges->getOperand(I: 2 * i + 1));
3024 ConstantRange Range(Lower->getValue(), Upper->getValue());
3025 if (Range.contains(Val: Value))
3026 return false;
3027 }
3028 return true;
3029}
3030
3031/// Try to detect a recurrence that monotonically increases/decreases from a
3032/// non-zero starting value. These are common as induction variables.
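/// e.g. (names illustrative) %iv = phi i64 [ 1, %entry ], [ %iv.next, %loop ]
/// with %iv.next = add nuw i64 %iv, 2 starts at a non-zero value and, being
/// nuw, can never wrap back around to zero.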
3033static bool isNonZeroRecurrence(const PHINode *PN) {
3034 BinaryOperator *BO = nullptr;
3035 Value *Start = nullptr, *Step = nullptr;
3036 const APInt *StartC, *StepC;
3037 if (!matchSimpleRecurrence(P: PN, BO, Start, Step) ||
3038 !match(V: Start, P: m_APInt(Res&: StartC)) || StartC->isZero())
3039 return false;
3040
3041 switch (BO->getOpcode()) {
3042 case Instruction::Add:
3043 // Starting from non-zero and stepping away from zero can never wrap back
3044 // to zero.
3045 return BO->hasNoUnsignedWrap() ||
3046 (BO->hasNoSignedWrap() && match(V: Step, P: m_APInt(Res&: StepC)) &&
3047 StartC->isNegative() == StepC->isNegative());
3048 case Instruction::Mul:
3049 return (BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) &&
3050 match(V: Step, P: m_APInt(Res&: StepC)) && !StepC->isZero();
3051 case Instruction::Shl:
3052 return BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap();
3053 case Instruction::AShr:
3054 case Instruction::LShr:
3055 return BO->isExact();
3056 default:
3057 return false;
3058 }
3059}
3060
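/// Return true if one of Op0/Op1 is a zext or sext of (icmp eq Other, 0),
/// where Other is the respective other operand.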
3061static bool matchOpWithOpEqZero(Value *Op0, Value *Op1) {
3062 return match(V: Op0, P: m_ZExtOrSExt(Op: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ,
3063 L: m_Specific(V: Op1), R: m_Zero()))) ||
3064 match(V: Op1, P: m_ZExtOrSExt(Op: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ,
3065 L: m_Specific(V: Op0), R: m_Zero())));
3066}
3067
3068static bool isNonZeroAdd(const APInt &DemandedElts, const SimplifyQuery &Q,
3069 unsigned BitWidth, Value *X, Value *Y, bool NSW,
3070 bool NUW, unsigned Depth) {
3071 // (X + (X != 0)) is non zero
3072 if (matchOpWithOpEqZero(Op0: X, Op1: Y))
3073 return true;
3074
3075 if (NUW)
3076 return isKnownNonZero(V: Y, DemandedElts, Q, Depth) ||
3077 isKnownNonZero(V: X, DemandedElts, Q, Depth);
3078
3079 KnownBits XKnown = computeKnownBits(V: X, DemandedElts, Q, Depth);
3080 KnownBits YKnown = computeKnownBits(V: Y, DemandedElts, Q, Depth);
3081
3082 // If X and Y are both non-negative (as signed values) then their sum is not
3083 // zero unless both X and Y are zero.
3084 if (XKnown.isNonNegative() && YKnown.isNonNegative())
3085 if (isKnownNonZero(V: Y, DemandedElts, Q, Depth) ||
3086 isKnownNonZero(V: X, DemandedElts, Q, Depth))
3087 return true;
3088
3089 // If X and Y are both negative (as signed values) then their sum is not
3090 // zero unless both X and Y equal INT_MIN.
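  // e.g. for i8 the only pair of negative values summing to zero (mod 256) is
  // -128 + -128; any other pair sums to something in [-255, -2].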
3091 if (XKnown.isNegative() && YKnown.isNegative()) {
3092 APInt Mask = APInt::getSignedMaxValue(numBits: BitWidth);
3093 // The sign bit of X is set. If some other bit is set then X is not equal
3094 // to INT_MIN.
3095 if (XKnown.One.intersects(RHS: Mask))
3096 return true;
3097 // The sign bit of Y is set. If some other bit is set then Y is not equal
3098 // to INT_MIN.
3099 if (YKnown.One.intersects(RHS: Mask))
3100 return true;
3101 }
3102
3103 // The sum of a non-negative number and a power of two is not zero.
3104 if (XKnown.isNonNegative() &&
3105 isKnownToBeAPowerOfTwo(V: Y, /*OrZero*/ false, Q, Depth))
3106 return true;
3107 if (YKnown.isNonNegative() &&
3108 isKnownToBeAPowerOfTwo(V: X, /*OrZero*/ false, Q, Depth))
3109 return true;
3110
3111 return KnownBits::add(LHS: XKnown, RHS: YKnown, NSW, NUW).isNonZero();
3112}
3113
3114static bool isNonZeroSub(const APInt &DemandedElts, const SimplifyQuery &Q,
3115 unsigned BitWidth, Value *X, Value *Y,
3116 unsigned Depth) {
3117 // (X - (X != 0)) is non zero
3118 // ((X != 0) - X) is non zero
3119 if (matchOpWithOpEqZero(Op0: X, Op1: Y))
3120 return true;
3121
3122 // TODO: Move this case into isKnownNonEqual().
3123 if (auto *C = dyn_cast<Constant>(Val: X))
3124 if (C->isNullValue() && isKnownNonZero(V: Y, DemandedElts, Q, Depth))
3125 return true;
3126
3127 return ::isKnownNonEqual(V1: X, V2: Y, DemandedElts, Q, Depth);
3128}
3129
3130static bool isNonZeroMul(const APInt &DemandedElts, const SimplifyQuery &Q,
3131 unsigned BitWidth, Value *X, Value *Y, bool NSW,
3132 bool NUW, unsigned Depth) {
3133 // If X and Y are non-zero then so is X * Y as long as the multiplication
3134 // does not overflow.
3135 if (NSW || NUW)
3136 return isKnownNonZero(V: X, DemandedElts, Q, Depth) &&
3137 isKnownNonZero(V: Y, DemandedElts, Q, Depth);
3138
3139 // If either X or Y is odd, then if the other is non-zero the result can't
3140 // be zero.
3141 KnownBits XKnown = computeKnownBits(V: X, DemandedElts, Q, Depth);
3142 if (XKnown.One[0])
3143 return isKnownNonZero(V: Y, DemandedElts, Q, Depth);
3144
3145 KnownBits YKnown = computeKnownBits(V: Y, DemandedElts, Q, Depth);
3146 if (YKnown.One[0])
3147 return XKnown.isNonZero() || isKnownNonZero(V: X, DemandedElts, Q, Depth);
3148
3149 // If there exists any subset of X (sX) and subset of Y (sY) s.t. sX * sY is
3150 // non-zero, then X * Y is non-zero. We can find sX and sY by just taking
3151 // the lowest known One of X and Y. If they are non-zero, the result
3152 // must be non-zero. We can check if LSB(X) * LSB(Y) != 0 by checking that
3153 // X.countMaxTrailingZeros() + Y.countMaxTrailingZeros() < BitWidth.
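  // e.g. for i8, if X is known to have a set bit at position <= 2 and Y at
  // position <= 3, the product's lowest set bit is at position <= 5 < 8, so
  // the product cannot be zero.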
3154 return (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros()) <
3155 BitWidth;
3156}
3157
3158static bool isNonZeroShift(const Operator *I, const APInt &DemandedElts,
3159 const SimplifyQuery &Q, const KnownBits &KnownVal,
3160 unsigned Depth) {
3161 auto ShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
3162 switch (I->getOpcode()) {
3163 case Instruction::Shl:
3164 return Lhs.shl(ShiftAmt: Rhs);
3165 case Instruction::LShr:
3166 return Lhs.lshr(ShiftAmt: Rhs);
3167 case Instruction::AShr:
3168 return Lhs.ashr(ShiftAmt: Rhs);
3169 default:
3170 llvm_unreachable("Unknown Shift Opcode");
3171 }
3172 };
3173
3174 auto InvShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
3175 switch (I->getOpcode()) {
3176 case Instruction::Shl:
3177 return Lhs.lshr(ShiftAmt: Rhs);
3178 case Instruction::LShr:
3179 case Instruction::AShr:
3180 return Lhs.shl(ShiftAmt: Rhs);
3181 default:
3182 llvm_unreachable("Unknown Shift Opcode");
3183 }
3184 };
3185
3186 if (KnownVal.isUnknown())
3187 return false;
3188
3189 KnownBits KnownCnt =
3190 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Q, Depth);
3191 APInt MaxShift = KnownCnt.getMaxValue();
3192 unsigned NumBits = KnownVal.getBitWidth();
3193 if (MaxShift.uge(RHS: NumBits))
3194 return false;
3195
3196 if (!ShiftOp(KnownVal.One, MaxShift).isZero())
3197 return true;
3198
3199 // If all of the bits shifted out are known to be zero, and Val is known
3200 // non-zero then at least one non-zero bit must remain.
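  // e.g. for a shl whose amount is known to be at most 3, if the top 3 bits of
  // the value are known zero, the shift can only discard known-zero bits, so a
  // non-zero value stays non-zero.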
3201 if (InvShiftOp(KnownVal.Zero, NumBits - MaxShift)
3202 .eq(RHS: InvShiftOp(APInt::getAllOnes(numBits: NumBits), NumBits - MaxShift)) &&
3203 isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth))
3204 return true;
3205
3206 return false;
3207}
3208
3209static bool isKnownNonZeroFromOperator(const Operator *I,
3210 const APInt &DemandedElts,
3211 const SimplifyQuery &Q, unsigned Depth) {
3212 unsigned BitWidth = getBitWidth(Ty: I->getType()->getScalarType(), DL: Q.DL);
3213 switch (I->getOpcode()) {
3214 case Instruction::Alloca:
3215 // Alloca never returns null, malloc might.
3216 return I->getType()->getPointerAddressSpace() == 0;
3217 case Instruction::GetElementPtr:
3218 if (I->getType()->isPointerTy())
3219 return isGEPKnownNonNull(GEP: cast<GEPOperator>(Val: I), Q, Depth);
3220 break;
3221 case Instruction::BitCast: {
3222 // We need to be a bit careful here. We can only peek through the bitcast
3223 // if the scalar size of elements in the operand is smaller than, and a
3224 // divisor of, the size they are casting to. Take three cases:
3225 //
3226 // 1) Unsafe:
3227 // bitcast <2 x i16> %NonZero to <4 x i8>
3228 //
3229 // %NonZero can have 2 non-zero i16 elements, but isKnownNonZero on a
3230 // <4 x i8> requires that all 4 i8 elements be non-zero which isn't
3231 // guaranteed (imagine just sign bit set in the 2 i16 elements).
3232 //
3233 // 2) Unsafe:
3234 // bitcast <4 x i3> %NonZero to <3 x i4>
3235 //
3236 // Even though the scalar size of the src (`i3`) is smaller than the
3237 // scalar size of the dst `i4`, because `i4` is not a multiple of `i3`,
3238 // it's possible for the `3 x i4` elements to be zero because there are
3239 // some elements in the destination that don't contain any full src
3240 // element.
3241 //
3242 // 3) Safe:
3243 // bitcast <4 x i8> %NonZero to <2 x i16>
3244 //
3245 // This is always safe as non-zero in the 4 i8 elements implies
3246 // non-zero in the combination of any two adjacent ones. Since i16 is a
3247 // multiple of i8, each i16 is guaranteed to contain 2 full i8 elements.
3248 // This all implies the 2 i16 elements are non-zero.
3249 Type *FromTy = I->getOperand(i: 0)->getType();
3250 if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
3251 (BitWidth % getBitWidth(Ty: FromTy->getScalarType(), DL: Q.DL)) == 0)
3252 return isKnownNonZero(V: I->getOperand(i: 0), Q, Depth);
3253 } break;
3254 case Instruction::IntToPtr:
3255 // Note that we have to take special care to avoid looking through
3256 // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well
3257 // as casts that can alter the value, e.g., AddrSpaceCasts.
3258 if (!isa<ScalableVectorType>(Val: I->getType()) &&
3259 Q.DL.getTypeSizeInBits(Ty: I->getOperand(i: 0)->getType()).getFixedValue() <=
3260 Q.DL.getTypeSizeInBits(Ty: I->getType()).getFixedValue())
3261 return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3262 break;
3263 case Instruction::PtrToAddr:
3264 // isKnownNonZero() for pointers refers to the address bits being non-zero,
3265 // so we can directly forward.
3266 return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3267 case Instruction::PtrToInt:
3268 // For ptrtoint, make sure the result size is >= the address size. If the
3269 // address is non-zero, its zero-extension into a wider result is also non-zero.
3270 if (Q.DL.getAddressSizeInBits(Ty: I->getOperand(i: 0)->getType()) <=
3271 I->getType()->getScalarSizeInBits())
3272 return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3273 break;
3274 case Instruction::Trunc:
3275 // nuw/nsw trunc preserves zero/non-zero status of input.
3276 if (auto *TI = dyn_cast<TruncInst>(Val: I))
3277 if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
3278 return isKnownNonZero(V: TI->getOperand(i_nocapture: 0), DemandedElts, Q, Depth);
3279 break;
3280
3281 // x ^ y != 0 exactly when x - y != 0, so the Xor case can reuse exactly
3282 // the same checks as the Sub case.
3283 case Instruction::Xor:
3284 case Instruction::Sub:
3285 return isNonZeroSub(DemandedElts, Q, BitWidth, X: I->getOperand(i: 0),
3286 Y: I->getOperand(i: 1), Depth);
3287 case Instruction::Or:
3288 // (X | (X != 0)) is non zero
3289 if (matchOpWithOpEqZero(Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1)))
3290 return true;
3291 // X | Y != 0 if X != Y.
3292 if (isKnownNonEqual(V1: I->getOperand(i: 0), V2: I->getOperand(i: 1), DemandedElts, Q,
3293 Depth))
3294 return true;
3295 // X | Y != 0 if X != 0 or Y != 0.
3296 return isKnownNonZero(V: I->getOperand(i: 1), DemandedElts, Q, Depth) ||
3297 isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3298 case Instruction::SExt:
3299 case Instruction::ZExt:
3300 // ext X != 0 if X != 0.
3301 return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3302
3303 case Instruction::Shl: {
3304 // shl nsw/nuw can't remove any non-zero bits.
3305 const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(Val: I);
3306 if (Q.IIQ.hasNoUnsignedWrap(Op: BO) || Q.IIQ.hasNoSignedWrap(Op: BO))
3307 return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3308
3309 // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined
3310 // if the lowest bit is shifted off the end.
3311 KnownBits Known(BitWidth);
3312 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth);
3313 if (Known.One[0])
3314 return true;
3315
3316 return isNonZeroShift(I, DemandedElts, Q, KnownVal: Known, Depth);
3317 }
3318 case Instruction::LShr:
3319 case Instruction::AShr: {
3320 // shr exact can only shift out zero bits.
3321 const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(Val: I);
3322 if (BO->isExact())
3323 return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3324
3325 // shr X, Y != 0 if X is negative. Note that the value of the shift is not
3326 // defined if the sign bit is shifted off the end.
3327 KnownBits Known =
3328 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3329 if (Known.isNegative())
3330 return true;
3331
3332 return isNonZeroShift(I, DemandedElts, Q, KnownVal: Known, Depth);
3333 }
3334 case Instruction::UDiv:
3335 case Instruction::SDiv: {
3336 // X / Y
3337 // div exact can only produce a zero if the dividend is zero.
3338 if (cast<PossiblyExactOperator>(Val: I)->isExact())
3339 return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3340
3341 KnownBits XKnown =
3342 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
3343 // If X is fully unknown we won't be able to figure anything out, so don't
3344 // bother computing known bits for Y.
3345 if (XKnown.isUnknown())
3346 return false;
3347
3348 KnownBits YKnown =
3349 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Q, Depth);
3350 if (I->getOpcode() == Instruction::SDiv) {
3351 // For signed division need to compare abs value of the operands.
3352 XKnown = XKnown.abs(/*IntMinIsPoison*/ false);
3353 YKnown = YKnown.abs(/*IntMinIsPoison*/ false);
3354 }
3355 // If X u>= Y then div is non zero (0/0 is UB).
3356 std::optional<bool> XUgeY = KnownBits::uge(LHS: XKnown, RHS: YKnown);
3357 // If X is totally unknown or X u< Y we won't be able to prove non-zero
3358 // with compute known bits so just return early.
3359 return XUgeY && *XUgeY;
3360 }
3361 case Instruction::Add: {
3362 // X + Y.
3363
3364 // If the add has the nuw flag, then the result is non-zero if either X or
3365 // Y is non-zero.
3366 auto *BO = cast<OverflowingBinaryOperator>(Val: I);
3367 return isNonZeroAdd(DemandedElts, Q, BitWidth, X: I->getOperand(i: 0),
3368 Y: I->getOperand(i: 1), NSW: Q.IIQ.hasNoSignedWrap(Op: BO),
3369 NUW: Q.IIQ.hasNoUnsignedWrap(Op: BO), Depth);
3370 }
3371 case Instruction::Mul: {
3372 const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(Val: I);
3373 return isNonZeroMul(DemandedElts, Q, BitWidth, X: I->getOperand(i: 0),
3374 Y: I->getOperand(i: 1), NSW: Q.IIQ.hasNoSignedWrap(Op: BO),
3375 NUW: Q.IIQ.hasNoUnsignedWrap(Op: BO), Depth);
3376 }
3377 case Instruction::Select: {
3378 // (C ? X : Y) != 0 if X != 0 and Y != 0.
3379
3380 // First check if the arm is non-zero using `isKnownNonZero`. If that fails,
3381 // then see if the select condition implies the arm is non-zero. For example
3382 // (X != 0 ? X : Y), we know the true arm is non-zero as the `X` "return" is
3383 // dominated by `X != 0`.
3384 auto SelectArmIsNonZero = [&](bool IsTrueArm) {
3385 Value *Op;
3386 Op = IsTrueArm ? I->getOperand(i: 1) : I->getOperand(i: 2);
3387 // Op is trivially non-zero.
3388 if (isKnownNonZero(V: Op, DemandedElts, Q, Depth))
3389 return true;
3390
3391 // The condition of the select dominates the true/false arm. Check if the
3392 // condition implies that a given arm is non-zero.
3393 Value *X;
3394 CmpPredicate Pred;
3395 if (!match(V: I->getOperand(i: 0), P: m_c_ICmp(Pred, L: m_Specific(V: Op), R: m_Value(V&: X))))
3396 return false;
3397
3398 if (!IsTrueArm)
3399 Pred = ICmpInst::getInversePredicate(pred: Pred);
3400
3401 return cmpExcludesZero(Pred, RHS: X);
3402 };
3403
3404 if (SelectArmIsNonZero(/* IsTrueArm */ true) &&
3405 SelectArmIsNonZero(/* IsTrueArm */ false))
3406 return true;
3407 break;
3408 }
3409 case Instruction::PHI: {
3410 auto *PN = cast<PHINode>(Val: I);
3411 if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN))
3412 return true;
3413
3414 // Check if all incoming values are non-zero using recursion.
3415 SimplifyQuery RecQ = Q.getWithoutCondContext();
3416 unsigned NewDepth = std::max(a: Depth, b: MaxAnalysisRecursionDepth - 1);
3417 return llvm::all_of(Range: PN->operands(), P: [&](const Use &U) {
3418 if (U.get() == PN)
3419 return true;
3420 RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
3421 // Check if the branch on the phi excludes zero.
3422 CmpPredicate Pred;
3423 Value *X;
3424 BasicBlock *TrueSucc, *FalseSucc;
3425 if (match(V: RecQ.CxtI,
3426 P: m_Br(C: m_c_ICmp(Pred, L: m_Specific(V: U.get()), R: m_Value(V&: X)),
3427 T: m_BasicBlock(V&: TrueSucc), F: m_BasicBlock(V&: FalseSucc)))) {
3428 // Check for cases of duplicate successors.
3429 if ((TrueSucc == PN->getParent()) != (FalseSucc == PN->getParent())) {
3430 // If we're using the false successor, invert the predicate.
3431 if (FalseSucc == PN->getParent())
3432 Pred = CmpInst::getInversePredicate(pred: Pred);
3433 if (cmpExcludesZero(Pred, RHS: X))
3434 return true;
3435 }
3436 }
3437 // Finally recurse on the edge and check it directly.
3438 return isKnownNonZero(V: U.get(), DemandedElts, Q: RecQ, Depth: NewDepth);
3439 });
3440 }
3441 case Instruction::InsertElement: {
3442 if (isa<ScalableVectorType>(Val: I->getType()))
3443 break;
3444
3445 const Value *Vec = I->getOperand(i: 0);
3446 const Value *Elt = I->getOperand(i: 1);
3447 auto *CIdx = dyn_cast<ConstantInt>(Val: I->getOperand(i: 2));
3448
3449 unsigned NumElts = DemandedElts.getBitWidth();
3450 APInt DemandedVecElts = DemandedElts;
3451 bool SkipElt = false;
3452 // If we know the index we are inserting to, clear it from the Vec check.
3453 if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
3454 DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
3455 SkipElt = !DemandedElts[CIdx->getZExtValue()];
3456 }
3457
3458 // Result is non-zero if Elt is non-zero and the rest of the demanded elts
3459 // in Vec are non-zero.
3460 return (SkipElt || isKnownNonZero(V: Elt, Q, Depth)) &&
3461 (DemandedVecElts.isZero() ||
3462 isKnownNonZero(V: Vec, DemandedElts: DemandedVecElts, Q, Depth));
3463 }
3464 case Instruction::ExtractElement:
3465 if (const auto *EEI = dyn_cast<ExtractElementInst>(Val: I)) {
3466 const Value *Vec = EEI->getVectorOperand();
3467 const Value *Idx = EEI->getIndexOperand();
3468 auto *CIdx = dyn_cast<ConstantInt>(Val: Idx);
3469 if (auto *VecTy = dyn_cast<FixedVectorType>(Val: Vec->getType())) {
3470 unsigned NumElts = VecTy->getNumElements();
3471 APInt DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
3472 if (CIdx && CIdx->getValue().ult(RHS: NumElts))
3473 DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
3474 return isKnownNonZero(V: Vec, DemandedElts: DemandedVecElts, Q, Depth);
3475 }
3476 }
3477 break;
3478 case Instruction::ShuffleVector: {
3479 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: I);
3480 if (!Shuf)
3481 break;
3482 APInt DemandedLHS, DemandedRHS;
3483 // For undef elements, we don't know anything about the common state of
3484 // the shuffle result.
3485 if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
3486 break;
3487 // If demanded elements for both vecs are non-zero, the shuffle is non-zero.
3488 return (DemandedRHS.isZero() ||
3489 isKnownNonZero(V: Shuf->getOperand(i_nocapture: 1), DemandedElts: DemandedRHS, Q, Depth)) &&
3490 (DemandedLHS.isZero() ||
3491 isKnownNonZero(V: Shuf->getOperand(i_nocapture: 0), DemandedElts: DemandedLHS, Q, Depth));
3492 }
3493 case Instruction::Freeze:
3494 return isKnownNonZero(V: I->getOperand(i: 0), Q, Depth) &&
3495 isGuaranteedNotToBePoison(V: I->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
3496 Depth);
3497 case Instruction::Load: {
3498 auto *LI = cast<LoadInst>(Val: I);
3499 // A load tagged with nonnull metadata, or with dereferenceable metadata in
3500 // an address space where null is not a valid address, is never null.
3501 if (auto *PtrT = dyn_cast<PointerType>(Val: I->getType())) {
3502 if (Q.IIQ.getMetadata(I: LI, KindID: LLVMContext::MD_nonnull) ||
3503 (Q.IIQ.getMetadata(I: LI, KindID: LLVMContext::MD_dereferenceable) &&
3504 !NullPointerIsDefined(F: LI->getFunction(), AS: PtrT->getAddressSpace())))
3505 return true;
3506 } else if (MDNode *Ranges = Q.IIQ.getMetadata(I: LI, KindID: LLVMContext::MD_range)) {
3507 return rangeMetadataExcludesValue(Ranges, Value: APInt::getZero(numBits: BitWidth));
3508 }
3509
3510 // No need to fall through to computeKnownBits as range metadata is already
3511 // handled in isKnownNonZero.
3512 return false;
3513 }
3514 case Instruction::ExtractValue: {
3515 const WithOverflowInst *WO;
3516 if (match(V: I, P: m_ExtractValue<0>(V: m_WithOverflowInst(I&: WO)))) {
3517 switch (WO->getBinaryOp()) {
3518 default:
3519 break;
3520 case Instruction::Add:
3521 return isNonZeroAdd(DemandedElts, Q, BitWidth, X: WO->getArgOperand(i: 0),
3522 Y: WO->getArgOperand(i: 1),
3523 /*NSW=*/false,
3524 /*NUW=*/false, Depth);
3525 case Instruction::Sub:
3526 return isNonZeroSub(DemandedElts, Q, BitWidth, X: WO->getArgOperand(i: 0),
3527 Y: WO->getArgOperand(i: 1), Depth);
3528 case Instruction::Mul:
3529 return isNonZeroMul(DemandedElts, Q, BitWidth, X: WO->getArgOperand(i: 0),
3530 Y: WO->getArgOperand(i: 1),
3531 /*NSW=*/false, /*NUW=*/false, Depth);
3532 break;
3533 }
3534 }
3535 break;
3536 }
3537 case Instruction::Call:
3538 case Instruction::Invoke: {
3539 const auto *Call = cast<CallBase>(Val: I);
3540 if (I->getType()->isPointerTy()) {
3541 if (Call->isReturnNonNull())
3542 return true;
3543 if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, MustPreserveNullness: true))
3544 return isKnownNonZero(V: RP, Q, Depth);
3545 } else {
3546 if (MDNode *Ranges = Q.IIQ.getMetadata(I: Call, KindID: LLVMContext::MD_range))
3547 return rangeMetadataExcludesValue(Ranges, Value: APInt::getZero(numBits: BitWidth));
3548 if (std::optional<ConstantRange> Range = Call->getRange()) {
3549 const APInt ZeroValue(Range->getBitWidth(), 0);
3550 if (!Range->contains(Val: ZeroValue))
3551 return true;
3552 }
3553 if (const Value *RV = Call->getReturnedArgOperand())
3554 if (RV->getType() == I->getType() && isKnownNonZero(V: RV, Q, Depth))
3555 return true;
3556 }
3557
3558 if (auto *II = dyn_cast<IntrinsicInst>(Val: I)) {
3559 switch (II->getIntrinsicID()) {
3560 case Intrinsic::sshl_sat:
3561 case Intrinsic::ushl_sat:
3562 case Intrinsic::abs:
3563 case Intrinsic::bitreverse:
3564 case Intrinsic::bswap:
3565 case Intrinsic::ctpop:
3566 return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
3567 // NB: We don't handle usub_sat here because in any case where we could
3568 // prove it non-zero, InstCombine will have folded it to `sub nuw` already.
3569 case Intrinsic::ssub_sat:
3570 // For most types, if x != y then ssub.sat x, y != 0. But
3571 // ssub.sat.i1 0, -1 = 0, because 1 saturates to 0. This means
3572 // isNonZeroSub will do the wrong thing for ssub.sat.i1.
3573 if (BitWidth == 1)
3574 return false;
3575 return isNonZeroSub(DemandedElts, Q, BitWidth, X: II->getArgOperand(i: 0),
3576 Y: II->getArgOperand(i: 1), Depth);
3577 case Intrinsic::sadd_sat:
3578 return isNonZeroAdd(DemandedElts, Q, BitWidth, X: II->getArgOperand(i: 0),
3579 Y: II->getArgOperand(i: 1),
3580 /*NSW=*/true, /* NUW=*/false, Depth);
3581 // Vec reverse preserves zero/non-zero status from input vec.
3582 case Intrinsic::vector_reverse:
3583 return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts: DemandedElts.reverseBits(),
3584 Q, Depth);
3585 // umax/umin/smax/smin/or of all non-zero elements is always non-zero.
3586 case Intrinsic::vector_reduce_or:
3587 case Intrinsic::vector_reduce_umax:
3588 case Intrinsic::vector_reduce_umin:
3589 case Intrinsic::vector_reduce_smax:
3590 case Intrinsic::vector_reduce_smin:
3591 return isKnownNonZero(V: II->getArgOperand(i: 0), Q, Depth);
3592 case Intrinsic::umax:
3593 case Intrinsic::uadd_sat:
3594 // umax(X, (X != 0)) is non zero
3595 // X +usat (X != 0) is non zero
3596 if (matchOpWithOpEqZero(Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1)))
3597 return true;
3598
3599 return isKnownNonZero(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth) ||
3600 isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
3601 case Intrinsic::smax: {
3602 // If either arg is strictly positive the result is non-zero. Otherwise
3603 // the result is non-zero if both ops are non-zero.
3604 auto IsNonZero = [&](Value *Op, std::optional<bool> &OpNonZero,
3605 const KnownBits &OpKnown) {
3606 if (!OpNonZero.has_value())
3607 OpNonZero = OpKnown.isNonZero() ||
3608 isKnownNonZero(V: Op, DemandedElts, Q, Depth);
3609 return *OpNonZero;
3610 };
3611 // Avoid re-computing isKnownNonZero.
3612 std::optional<bool> Op0NonZero, Op1NonZero;
3613 KnownBits Op1Known =
3614 computeKnownBits(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth);
3615 if (Op1Known.isNonNegative() &&
3616 IsNonZero(II->getArgOperand(i: 1), Op1NonZero, Op1Known))
3617 return true;
3618 KnownBits Op0Known =
3619 computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
3620 if (Op0Known.isNonNegative() &&
3621 IsNonZero(II->getArgOperand(i: 0), Op0NonZero, Op0Known))
3622 return true;
3623 return IsNonZero(II->getArgOperand(i: 1), Op1NonZero, Op1Known) &&
3624 IsNonZero(II->getArgOperand(i: 0), Op0NonZero, Op0Known);
3625 }
3626 case Intrinsic::smin: {
3627 // If either arg is negative the result is non-zero. Otherwise
3628 // the result is non-zero if both ops are non-zero.
3629 KnownBits Op1Known =
3630 computeKnownBits(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth);
3631 if (Op1Known.isNegative())
3632 return true;
3633 KnownBits Op0Known =
3634 computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
3635 if (Op0Known.isNegative())
3636 return true;
3637
3638 if (Op1Known.isNonZero() && Op0Known.isNonZero())
3639 return true;
3640 }
3641 [[fallthrough]];
3642 case Intrinsic::umin:
3643 return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth) &&
3644 isKnownNonZero(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth);
3645 case Intrinsic::cttz:
3646 return computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth)
3647 .Zero[0];
3648 case Intrinsic::ctlz:
3649 return computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth)
3650 .isNonNegative();
3651 case Intrinsic::fshr:
3652 case Intrinsic::fshl:
3653 // If Op0 == Op1, this is a rotate. rotate(x, y) != 0 iff x != 0.
3654 if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1))
3655 return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
3656 break;
3657 case Intrinsic::vscale:
3658 return true;
3659 case Intrinsic::experimental_get_vector_length:
3660 return isKnownNonZero(V: I->getOperand(i: 0), Q, Depth);
3661 default:
3662 break;
3663 }
3664 break;
3665 }
3666
3667 return false;
3668 }
3669 }
3670
3671 KnownBits Known(BitWidth);
3672 computeKnownBits(V: I, DemandedElts, Known, Q, Depth);
3673 return Known.One != 0;
3674}
3675
3676/// Return true if the given value is known to be non-zero when defined. For
3677/// vectors, return true if every demanded element is known to be non-zero when
3678/// defined. For pointers, if the context instruction and dominator tree are
3679/// specified, perform context-sensitive analysis and return true if the
3680/// pointer couldn't possibly be null at the specified instruction.
3681/// Supports values with integer or pointer type and vectors of integers.
3682bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
3683 const SimplifyQuery &Q, unsigned Depth) {
3684 Type *Ty = V->getType();
3685
3686#ifndef NDEBUG
3687 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
3688
3689 if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
3690 assert(
3691 FVTy->getNumElements() == DemandedElts.getBitWidth() &&
3692 "DemandedElt width should equal the fixed vector number of elements");
3693 } else {
3694 assert(DemandedElts == APInt(1, 1) &&
3695 "DemandedElt width should be 1 for scalars");
3696 }
3697#endif
3698
3699 if (auto *C = dyn_cast<Constant>(Val: V)) {
3700 if (C->isNullValue())
3701 return false;
3702 if (isa<ConstantInt>(Val: C))
3703 // Must be non-zero due to null test above.
3704 return true;
3705
3706 // For constant vectors, check that all elements are poison or known
3707 // non-zero to determine that the whole vector is known non-zero.
3708 if (auto *VecTy = dyn_cast<FixedVectorType>(Val: Ty)) {
3709 for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
3710 if (!DemandedElts[i])
3711 continue;
3712 Constant *Elt = C->getAggregateElement(Elt: i);
3713 if (!Elt || Elt->isNullValue())
3714 return false;
3715 if (!isa<PoisonValue>(Val: Elt) && !isa<ConstantInt>(Val: Elt))
3716 return false;
3717 }
3718 return true;
3719 }
3720
3721 // Constant ptrauth can be null iff the base pointer can be.
3722 if (auto *CPA = dyn_cast<ConstantPtrAuth>(Val: V))
3723 return isKnownNonZero(V: CPA->getPointer(), DemandedElts, Q, Depth);
3724
3725 // A global variable in address space 0 is non-null unless it is extern weak
3726 // or an absolute symbol reference. Other address spaces may have null as a
3727 // valid address for a global, so we can't assume anything.
3728 if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: V)) {
3729 if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() &&
3730 GV->getType()->getAddressSpace() == 0)
3731 return true;
3732 }
3733
3734 // For constant expressions, fall through to the Operator code below.
3735 if (!isa<ConstantExpr>(Val: V))
3736 return false;
3737 }
3738
3739 if (const auto *A = dyn_cast<Argument>(Val: V))
3740 if (std::optional<ConstantRange> Range = A->getRange()) {
3741 const APInt ZeroValue(Range->getBitWidth(), 0);
3742 if (!Range->contains(Val: ZeroValue))
3743 return true;
3744 }
3745
3746 if (!isa<Constant>(Val: V) && isKnownNonZeroFromAssume(V, Q))
3747 return true;
3748
3749 // Some of the tests below are recursive, so bail out if we hit the limit.
3750 if (Depth++ >= MaxAnalysisRecursionDepth)
3751 return false;
3752
3753 // Check for pointer simplifications.
3754
3755 if (PointerType *PtrTy = dyn_cast<PointerType>(Val: Ty)) {
3756 // A byval or inalloca argument cannot be null in an address space where
3757 // null is not a valid address. A nonnull argument is assumed never null.
3758 if (const Argument *A = dyn_cast<Argument>(Val: V)) {
3759 if (((A->hasPassPointeeByValueCopyAttr() &&
3760 !NullPointerIsDefined(F: A->getParent(), AS: PtrTy->getAddressSpace())) ||
3761 A->hasNonNullAttr()))
3762 return true;
3763 }
3764 }
3765
3766 if (const auto *I = dyn_cast<Operator>(Val: V))
3767 if (isKnownNonZeroFromOperator(I, DemandedElts, Q, Depth))
3768 return true;
3769
3770 if (!isa<Constant>(Val: V) &&
3771 isKnownNonNullFromDominatingCondition(V, CtxI: Q.CxtI, DT: Q.DT))
3772 return true;
3773
3774 if (const Value *Stripped = stripNullTest(V))
3775 return isKnownNonZero(V: Stripped, DemandedElts, Q, Depth);
3776
3777 return false;
3778}
3779
3780bool llvm::isKnownNonZero(const Value *V, const SimplifyQuery &Q,
3781 unsigned Depth) {
3782 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
3783 APInt DemandedElts =
3784 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
3785 return ::isKnownNonZero(V, DemandedElts, Q, Depth);
3786}
3787
3788/// If the pair of operators are the same invertible function, return the
3789/// operands of the function corresponding to each input. Otherwise,
3790/// return std::nullopt. An invertible function is one that is 1-to-1 and maps
3791/// every input value to exactly one output value. This is equivalent to
3792/// saying that Op1 and Op2 are equal exactly when the specified pair of
3793/// operands are equal, (except that Op1 and Op2 may be poison more often.)
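/// e.g. (add %x, %c) and (add %y, %c) compare equal exactly when %x == %y, so
/// the pair (%x, %y) is returned and can be analyzed in place of the adds.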
3794static std::optional<std::pair<Value*, Value*>>
3795getInvertibleOperands(const Operator *Op1,
3796 const Operator *Op2) {
3797 if (Op1->getOpcode() != Op2->getOpcode())
3798 return std::nullopt;
3799
3800 auto getOperands = [&](unsigned OpNum) -> auto {
3801 return std::make_pair(x: Op1->getOperand(i: OpNum), y: Op2->getOperand(i: OpNum));
3802 };
3803
3804 switch (Op1->getOpcode()) {
3805 default:
3806 break;
3807 case Instruction::Or:
3808 if (!cast<PossiblyDisjointInst>(Val: Op1)->isDisjoint() ||
3809 !cast<PossiblyDisjointInst>(Val: Op2)->isDisjoint())
3810 break;
3811 [[fallthrough]];
3812 case Instruction::Xor:
3813 case Instruction::Add: {
3814 Value *Other;
3815 if (match(V: Op2, P: m_c_BinOp(L: m_Specific(V: Op1->getOperand(i: 0)), R: m_Value(V&: Other))))
3816 return std::make_pair(x: Op1->getOperand(i: 1), y&: Other);
3817 if (match(V: Op2, P: m_c_BinOp(L: m_Specific(V: Op1->getOperand(i: 1)), R: m_Value(V&: Other))))
3818 return std::make_pair(x: Op1->getOperand(i: 0), y&: Other);
3819 break;
3820 }
3821 case Instruction::Sub:
3822 if (Op1->getOperand(i: 0) == Op2->getOperand(i: 0))
3823 return getOperands(1);
3824 if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1))
3825 return getOperands(0);
3826 break;
3827 case Instruction::Mul: {
3828 // invertible if A * B == (A * B) mod 2^N where A and B are integers
3829 // and N is the bitwidth. The nsw case is non-obvious, but proven by
3830 // alive2: https://alive2.llvm.org/ce/z/Z6D5qK
3831 auto *OBO1 = cast<OverflowingBinaryOperator>(Val: Op1);
3832 auto *OBO2 = cast<OverflowingBinaryOperator>(Val: Op2);
3833 if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
3834 (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
3835 break;
3836
3837 // Assume operand order has been canonicalized
3838 if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1) &&
3839 isa<ConstantInt>(Val: Op1->getOperand(i: 1)) &&
3840 !cast<ConstantInt>(Val: Op1->getOperand(i: 1))->isZero())
3841 return getOperands(0);
3842 break;
3843 }
3844 case Instruction::Shl: {
3845 // Same as multiplies, with the difference that we don't need to check
3846 // for a non-zero multiply. Shifts always multiply by non-zero.
3847 auto *OBO1 = cast<OverflowingBinaryOperator>(Val: Op1);
3848 auto *OBO2 = cast<OverflowingBinaryOperator>(Val: Op2);
3849 if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
3850 (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
3851 break;
3852
3853 if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1))
3854 return getOperands(0);
3855 break;
3856 }
3857 case Instruction::AShr:
3858 case Instruction::LShr: {
3859 auto *PEO1 = cast<PossiblyExactOperator>(Val: Op1);
3860 auto *PEO2 = cast<PossiblyExactOperator>(Val: Op2);
3861 if (!PEO1->isExact() || !PEO2->isExact())
3862 break;
3863
3864 if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1))
3865 return getOperands(0);
3866 break;
3867 }
3868 case Instruction::SExt:
3869 case Instruction::ZExt:
3870 if (Op1->getOperand(i: 0)->getType() == Op2->getOperand(i: 0)->getType())
3871 return getOperands(0);
3872 break;
3873 case Instruction::PHI: {
3874 const PHINode *PN1 = cast<PHINode>(Val: Op1);
3875 const PHINode *PN2 = cast<PHINode>(Val: Op2);
3876
3877 // If PN1 and PN2 are both recurrences, can we prove the entire recurrences
3878 // are a single invertible function of the start values? Note that repeated
3879 // application of an invertible function is also invertible
3880 BinaryOperator *BO1 = nullptr;
3881 Value *Start1 = nullptr, *Step1 = nullptr;
3882 BinaryOperator *BO2 = nullptr;
3883 Value *Start2 = nullptr, *Step2 = nullptr;
3884 if (PN1->getParent() != PN2->getParent() ||
3885 !matchSimpleRecurrence(P: PN1, BO&: BO1, Start&: Start1, Step&: Step1) ||
3886 !matchSimpleRecurrence(P: PN2, BO&: BO2, Start&: Start2, Step&: Step2))
3887 break;
3888
3889 auto Values = getInvertibleOperands(Op1: cast<Operator>(Val: BO1),
3890 Op2: cast<Operator>(Val: BO2));
3891 if (!Values)
3892 break;
3893
3894 // We have to be careful of mutually defined recurrences here. Ex:
3895 // * X_i = X_(i-1) OP Y_(i-1), and Y_i = X_(i-1) OP V
3896 // * X_i = Y_i = X_(i-1) OP Y_(i-1)
3897 // The invertibility of these is complicated, and not worth reasoning
3898 // about (yet?).
3899 if (Values->first != PN1 || Values->second != PN2)
3900 break;
3901
3902 return std::make_pair(x&: Start1, y&: Start2);
3903 }
3904 }
3905 return std::nullopt;
3906}
3907
3908/// Return true if V1 == (binop V2, X), where X is known non-zero.
3909/// Only handle a small subset of binops where (binop V2, X) with non-zero X
3910/// implies V2 != V1.
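/// e.g. if V1 = add V2, X and X is known non-zero, then V1 and V2 must differ,
/// since adding a non-zero value always changes the result modulo 2^N.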
3911static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
3912 const APInt &DemandedElts,
3913 const SimplifyQuery &Q, unsigned Depth) {
3914 const BinaryOperator *BO = dyn_cast<BinaryOperator>(Val: V1);
3915 if (!BO)
3916 return false;
3917 switch (BO->getOpcode()) {
3918 default:
3919 break;
3920 case Instruction::Or:
3921 if (!cast<PossiblyDisjointInst>(Val: V1)->isDisjoint())
3922 break;
3923 [[fallthrough]];
3924 case Instruction::Xor:
3925 case Instruction::Add:
3926 Value *Op = nullptr;
3927 if (V2 == BO->getOperand(i_nocapture: 0))
3928 Op = BO->getOperand(i_nocapture: 1);
3929 else if (V2 == BO->getOperand(i_nocapture: 1))
3930 Op = BO->getOperand(i_nocapture: 0);
3931 else
3932 return false;
3933 return isKnownNonZero(V: Op, DemandedElts, Q, Depth: Depth + 1);
3934 }
3935 return false;
3936}
3937
3938/// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
3939/// the multiplication is nuw or nsw.
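/// e.g. if V2 = mul nuw V1, 3 and V1 is known non-zero, then V2 == 3 * V1
/// without wrapping, which cannot equal V1.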
3940static bool isNonEqualMul(const Value *V1, const Value *V2,
3941 const APInt &DemandedElts, const SimplifyQuery &Q,
3942 unsigned Depth) {
3943 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Val: V2)) {
3944 const APInt *C;
3945 return match(V: OBO, P: m_Mul(L: m_Specific(V: V1), R: m_APInt(Res&: C))) &&
3946 (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3947 !C->isZero() && !C->isOne() &&
3948 isKnownNonZero(V: V1, DemandedElts, Q, Depth: Depth + 1);
3949 }
3950 return false;
3951}
3952
3953/// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
3954/// the shift is nuw or nsw.
3955static bool isNonEqualShl(const Value *V1, const Value *V2,
3956 const APInt &DemandedElts, const SimplifyQuery &Q,
3957 unsigned Depth) {
3958 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Val: V2)) {
3959 const APInt *C;
3960 return match(V: OBO, P: m_Shl(L: m_Specific(V: V1), R: m_APInt(Res&: C))) &&
3961 (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3962 !C->isZero() && isKnownNonZero(V: V1, DemandedElts, Q, Depth: Depth + 1);
3963 }
3964 return false;
3965}
3966
3967static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
3968 const APInt &DemandedElts, const SimplifyQuery &Q,
3969 unsigned Depth) {
3970 // Check two PHIs are in same block.
3971 if (PN1->getParent() != PN2->getParent())
3972 return false;
3973
3974 SmallPtrSet<const BasicBlock *, 8> VisitedBBs;
3975 bool UsedFullRecursion = false;
3976 for (const BasicBlock *IncomBB : PN1->blocks()) {
3977 if (!VisitedBBs.insert(Ptr: IncomBB).second)
3978 continue; // Don't reprocess blocks that we have dealt with already.
3979 const Value *IV1 = PN1->getIncomingValueForBlock(BB: IncomBB);
3980 const Value *IV2 = PN2->getIncomingValueForBlock(BB: IncomBB);
3981 const APInt *C1, *C2;
3982 if (match(V: IV1, P: m_APInt(Res&: C1)) && match(V: IV2, P: m_APInt(Res&: C2)) && *C1 != *C2)
3983 continue;
3984
3985 // Only one pair of phi operands is allowed for full recursion.
3986 if (UsedFullRecursion)
3987 return false;
3988
3989 SimplifyQuery RecQ = Q.getWithoutCondContext();
3990 RecQ.CxtI = IncomBB->getTerminator();
3991 if (!isKnownNonEqual(V1: IV1, V2: IV2, DemandedElts, Q: RecQ, Depth: Depth + 1))
3992 return false;
3993 UsedFullRecursion = true;
3994 }
3995 return true;
3996}
3997
3998static bool isNonEqualSelect(const Value *V1, const Value *V2,
3999 const APInt &DemandedElts, const SimplifyQuery &Q,
4000 unsigned Depth) {
4001 const SelectInst *SI1 = dyn_cast<SelectInst>(Val: V1);
4002 if (!SI1)
4003 return false;
4004
4005 if (const SelectInst *SI2 = dyn_cast<SelectInst>(Val: V2)) {
4006 const Value *Cond1 = SI1->getCondition();
4007 const Value *Cond2 = SI2->getCondition();
4008 if (Cond1 == Cond2)
4009 return isKnownNonEqual(V1: SI1->getTrueValue(), V2: SI2->getTrueValue(),
4010 DemandedElts, Q, Depth: Depth + 1) &&
4011 isKnownNonEqual(V1: SI1->getFalseValue(), V2: SI2->getFalseValue(),
4012 DemandedElts, Q, Depth: Depth + 1);
4013 }
4014 return isKnownNonEqual(V1: SI1->getTrueValue(), V2, DemandedElts, Q, Depth: Depth + 1) &&
4015 isKnownNonEqual(V1: SI1->getFalseValue(), V2, DemandedElts, Q, Depth: Depth + 1);
4016}
4017
4018// Check whether A is a GEP that is also the incoming value of a PHI in a
4019// loop, and B is either a plain pointer or another GEP. If the PHI has two
4020// incoming values, one being the recursive GEP A and the other a pointer at
4021// the same base as B with the same or a higher offset, then as long as the
4022// recursive GEP's offset is positive the loop only moves the pointer past B.
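// e.g. (names illustrative) with %p = phi ptr [ %start, %entry ], [ %a, %loop ]
// and %a = getelementptr inbounds i8, ptr %p, i64 4: if %start shares B's base
// at an offset >= B's, every iteration moves %a strictly past B, so %a != B.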
4023static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
4024 const SimplifyQuery &Q) {
4025 if (!A->getType()->isPointerTy() || !B->getType()->isPointerTy())
4026 return false;
4027
4028 auto *GEPA = dyn_cast<GEPOperator>(Val: A);
4029 if (!GEPA || GEPA->getNumIndices() != 1 || !isa<Constant>(Val: GEPA->idx_begin()))
4030 return false;
4031
4032 // Handle 2 incoming PHI values with one being a recursive GEP.
4033 auto *PN = dyn_cast<PHINode>(Val: GEPA->getPointerOperand());
4034 if (!PN || PN->getNumIncomingValues() != 2)
4035 return false;
4036
4037 // Search for the recursive GEP as an incoming operand, and record that as
4038 // Step.
4039 Value *Start = nullptr;
4040 Value *Step = const_cast<Value *>(A);
4041 if (PN->getIncomingValue(i: 0) == Step)
4042 Start = PN->getIncomingValue(i: 1);
4043 else if (PN->getIncomingValue(i: 1) == Step)
4044 Start = PN->getIncomingValue(i: 0);
4045 else
4046 return false;
4047
4048 // Other incoming node base should match the B base.
4049 // StartOffset >= OffsetB && StepOffset > 0?
4050 // StartOffset <= OffsetB && StepOffset < 0?
4051 // Is non-equal if above are true.
4052 // We use stripAndAccumulateInBoundsConstantOffsets to restrict the
4053 // optimisation to inbounds GEPs only.
4054 unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Ty: Start->getType());
4055 APInt StartOffset(IndexWidth, 0);
4056 Start = Start->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: StartOffset);
4057 APInt StepOffset(IndexWidth, 0);
4058 Step = Step->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: StepOffset);
4059
4060 // Check if Base Pointer of Step matches the PHI.
4061 if (Step != PN)
4062 return false;
4063 APInt OffsetB(IndexWidth, 0);
4064 B = B->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: OffsetB);
4065 return Start == B &&
4066 ((StartOffset.sge(RHS: OffsetB) && StepOffset.isStrictlyPositive()) ||
4067 (StartOffset.sle(RHS: OffsetB) && StepOffset.isNegative()));
4068}
4069
4070static bool isKnownNonEqualFromContext(const Value *V1, const Value *V2,
4071 const SimplifyQuery &Q, unsigned Depth) {
4072 if (!Q.CxtI)
4073 return false;
4074
4075 // Try to infer NonEqual based on information from dominating conditions.
4076 if (Q.DC && Q.DT) {
4077 auto IsKnownNonEqualFromDominatingCondition = [&](const Value *V) {
4078 for (BranchInst *BI : Q.DC->conditionsFor(V)) {
4079 Value *Cond = BI->getCondition();
4080 BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
4081 if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()) &&
4082 isImpliedCondition(LHS: Cond, RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
4083 /*LHSIsTrue=*/true, Depth)
4084 .value_or(u: false))
4085 return true;
4086
4087 BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
4088 if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()) &&
4089 isImpliedCondition(LHS: Cond, RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
4090 /*LHSIsTrue=*/false, Depth)
4091 .value_or(u: false))
4092 return true;
4093 }
4094
4095 return false;
4096 };
4097
4098 if (IsKnownNonEqualFromDominatingCondition(V1) ||
4099 IsKnownNonEqualFromDominatingCondition(V2))
4100 return true;
4101 }
4102
4103 if (!Q.AC)
4104 return false;
4105
4106 // Try to infer NonEqual based on information from assumptions.
4107 for (auto &AssumeVH : Q.AC->assumptionsFor(V: V1)) {
4108 if (!AssumeVH)
4109 continue;
4110 CallInst *I = cast<CallInst>(Val&: AssumeVH);
4111
4112 assert(I->getFunction() == Q.CxtI->getFunction() &&
4113 "Got assumption for the wrong function!");
4114 assert(I->getIntrinsicID() == Intrinsic::assume &&
4115 "must be an assume intrinsic");
4116
4117 if (isImpliedCondition(LHS: I->getArgOperand(i: 0), RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
4118 /*LHSIsTrue=*/true, Depth)
4119 .value_or(u: false) &&
4120 isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
4121 return true;
4122 }
4123
4124 return false;
4125}
4126
4127/// Return true if it is known that V1 != V2.
4128static bool isKnownNonEqual(const Value *V1, const Value *V2,
4129 const APInt &DemandedElts, const SimplifyQuery &Q,
4130 unsigned Depth) {
4131 if (V1 == V2)
4132 return false;
4133 if (V1->getType() != V2->getType())
4134 // We can't look through casts yet.
4135 return false;
4136
4137 if (Depth >= MaxAnalysisRecursionDepth)
4138 return false;
4139
4140 // See if we can recurse through (exactly one of) our operands. This
4141 // requires our operation be 1-to-1 and map every input value to exactly
4142 // one output value. Such an operation is invertible.
4143 auto *O1 = dyn_cast<Operator>(Val: V1);
4144 auto *O2 = dyn_cast<Operator>(Val: V2);
4145 if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
4146 if (auto Values = getInvertibleOperands(Op1: O1, Op2: O2))
4147 return isKnownNonEqual(V1: Values->first, V2: Values->second, DemandedElts, Q,
4148 Depth: Depth + 1);
4149
4150 if (const PHINode *PN1 = dyn_cast<PHINode>(Val: V1)) {
4151 const PHINode *PN2 = cast<PHINode>(Val: V2);
4152 // FIXME: This is missing a generalization to handle the case where one is
4153 // a PHI and another one isn't.
4154 if (isNonEqualPHIs(PN1, PN2, DemandedElts, Q, Depth))
4155 return true;
4156 };
4157 }
4158
4159 if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Q, Depth) ||
4160 isModifyingBinopOfNonZero(V1: V2, V2: V1, DemandedElts, Q, Depth))
4161 return true;
4162
4163 if (isNonEqualMul(V1, V2, DemandedElts, Q, Depth) ||
4164 isNonEqualMul(V1: V2, V2: V1, DemandedElts, Q, Depth))
4165 return true;
4166
4167 if (isNonEqualShl(V1, V2, DemandedElts, Q, Depth) ||
4168 isNonEqualShl(V1: V2, V2: V1, DemandedElts, Q, Depth))
4169 return true;
4170
4171 if (V1->getType()->isIntOrIntVectorTy()) {
4172 // Are any known bits in V1 contradictory to known bits in V2? If V1
4173 // has a known zero where V2 has a known one, they must not be equal.
4174 KnownBits Known1 = computeKnownBits(V: V1, DemandedElts, Q, Depth);
4175 if (!Known1.isUnknown()) {
4176 KnownBits Known2 = computeKnownBits(V: V2, DemandedElts, Q, Depth);
4177 if (Known1.Zero.intersects(RHS: Known2.One) ||
4178 Known2.Zero.intersects(RHS: Known1.One))
4179 return true;
4180 }
4181 }
4182
4183 if (isNonEqualSelect(V1, V2, DemandedElts, Q, Depth) ||
4184 isNonEqualSelect(V1: V2, V2: V1, DemandedElts, Q, Depth))
4185 return true;
4186
4187 if (isNonEqualPointersWithRecursiveGEP(A: V1, B: V2, Q) ||
4188 isNonEqualPointersWithRecursiveGEP(A: V2, B: V1, Q))
4189 return true;
4190
4191 Value *A, *B;
4192 // PtrToInts are NonEqual if their Ptrs are NonEqual.
4193 // Check PtrToInt type matches the pointer size.
4194 if (match(V: V1, P: m_PtrToIntSameSize(DL: Q.DL, Op: m_Value(V&: A))) &&
4195 match(V: V2, P: m_PtrToIntSameSize(DL: Q.DL, Op: m_Value(V&: B))))
4196 return isKnownNonEqual(V1: A, V2: B, DemandedElts, Q, Depth: Depth + 1);
4197
4198 if (isKnownNonEqualFromContext(V1, V2, Q, Depth))
4199 return true;
4200
4201 return false;
4202}
4203
4204/// For vector constants, loop over the elements and find the constant with the
4205/// minimum number of sign bits. Return 0 if the value is not a vector constant
4206/// or if any element was not analyzed; otherwise, return the count for the
4207/// element with the minimum number of sign bits.
4208static unsigned computeNumSignBitsVectorConstant(const Value *V,
4209 const APInt &DemandedElts,
4210 unsigned TyBits) {
4211 const auto *CV = dyn_cast<Constant>(Val: V);
4212 if (!CV || !isa<FixedVectorType>(Val: CV->getType()))
4213 return 0;
4214
4215 unsigned MinSignBits = TyBits;
4216 unsigned NumElts = cast<FixedVectorType>(Val: CV->getType())->getNumElements();
4217 for (unsigned i = 0; i != NumElts; ++i) {
4218 if (!DemandedElts[i])
4219 continue;
4220 // If we find a non-ConstantInt, bail out.
4221 auto *Elt = dyn_cast_or_null<ConstantInt>(Val: CV->getAggregateElement(Elt: i));
4222 if (!Elt)
4223 return 0;
4224
4225 MinSignBits = std::min(a: MinSignBits, b: Elt->getValue().getNumSignBits());
4226 }
4227
4228 return MinSignBits;
4229}
4230
4231static unsigned ComputeNumSignBitsImpl(const Value *V,
4232 const APInt &DemandedElts,
4233 const SimplifyQuery &Q, unsigned Depth);
4234
4235static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
4236 const SimplifyQuery &Q, unsigned Depth) {
4237 unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Q, Depth);
4238 assert(Result > 0 && "At least one sign bit needs to be present!");
4239 return Result;
4240}
4241
4242/// Return the number of times the sign bit of the register is replicated into
4243/// the other bits. We know that at least 1 bit is always equal to the sign bit
4244/// (itself), but other cases can give us information. For example, immediately
4245/// after an "ashr X, 2", we know that the top 3 bits are all equal to each
4246/// other, so we return 3. For vectors, return the number of sign bits for the
4247/// vector element with the minimum number of known sign bits of the demanded
4248/// elements in the vector specified by DemandedElts.
4249static unsigned ComputeNumSignBitsImpl(const Value *V,
4250 const APInt &DemandedElts,
4251 const SimplifyQuery &Q, unsigned Depth) {
4252 Type *Ty = V->getType();
4253#ifndef NDEBUG
4254 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
4255
4256 if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
4257 assert(
4258 FVTy->getNumElements() == DemandedElts.getBitWidth() &&
4259 "DemandedElt width should equal the fixed vector number of elements");
4260 } else {
4261 assert(DemandedElts == APInt(1, 1) &&
4262 "DemandedElt width should be 1 for scalars");
4263 }
4264#endif
4265
4266 // We return the minimum number of sign bits that are guaranteed to be present
4267 // in V, so for undef we have to conservatively return 1. We don't have the
4268 // same behavior for poison though -- that's a FIXME today.
4269
4270 Type *ScalarTy = Ty->getScalarType();
4271 unsigned TyBits = ScalarTy->isPointerTy() ?
4272 Q.DL.getPointerTypeSizeInBits(ScalarTy) :
4273 Q.DL.getTypeSizeInBits(Ty: ScalarTy);
4274
4275 unsigned Tmp, Tmp2;
4276 unsigned FirstAnswer = 1;
4277
4278 // Note that ConstantInt is handled by the general computeKnownBits case
4279 // below.
4280
4281 if (Depth == MaxAnalysisRecursionDepth)
4282 return 1;
4283
4284 if (auto *U = dyn_cast<Operator>(Val: V)) {
4285 switch (Operator::getOpcode(V)) {
4286 default: break;
4287 case Instruction::BitCast: {
4288 Value *Src = U->getOperand(i: 0);
4289 Type *SrcTy = Src->getType();
4290
4291 // Skip if the source type is not an integer or integer vector type;
4292 // only integer-like sources are handled here.
4293 if (!SrcTy->isIntOrIntVectorTy())
4294 break;
4295
4296 unsigned SrcBits = SrcTy->getScalarSizeInBits();
4297
4298 // Bitcast 'large element' scalar/vector to 'small element' vector.
4299 if ((SrcBits % TyBits) != 0)
4300 break;
4301
4302 // Only proceed if the destination type is a fixed-size vector
4303 if (isa<FixedVectorType>(Val: Ty)) {
4304 // Fast case - the sign splat can simply be split across the small elements.
4305 // This works for both vector and scalar sources.
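// For example, bitcasting an i64 known to be 0 or -1 (64 sign bits) to
// <2 x i32> yields elements that are each 0 or -1, i.e. 32 sign bits each.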
4306 Tmp = ComputeNumSignBits(V: Src, Q, Depth: Depth + 1);
4307 if (Tmp == SrcBits)
4308 return TyBits;
4309 }
4310 break;
4311 }
4312 case Instruction::SExt:
4313 Tmp = TyBits - U->getOperand(i: 0)->getType()->getScalarSizeInBits();
4314 return ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1) +
4315 Tmp;
4316
4317 case Instruction::SDiv: {
4318 const APInt *Denominator;
4319 // sdiv X, C -> adds log(C) sign bits.
4320 if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: Denominator))) {
4321
4322 // Ignore non-positive denominator.
4323 if (!Denominator->isStrictlyPositive())
4324 break;
4325
4326 // Calculate the incoming numerator bits.
4327 unsigned NumBits =
4328 ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4329
4330 // Add floor(log(C)) bits to the numerator bits.
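// For example, an i32 numerator with 20 sign bits divided by 16
// (logBase2 = 4) has at least min(32, 20 + 4) = 24 sign bits.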
4331 return std::min(a: TyBits, b: NumBits + Denominator->logBase2());
4332 }
4333 break;
4334 }
4335
4336 case Instruction::SRem: {
4337 Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4338
4339 const APInt *Denominator;
4340 // srem X, C -> we know that the result is within [-C+1,C) when C is a
4341 // positive constant. This lets us put a lower bound on the number of sign
4342 // bits.
4343 if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: Denominator))) {
4344
4345 // Ignore non-positive denominator.
4346 if (Denominator->isStrictlyPositive()) {
4347 // Calculate the leading sign bit constraints by examining the
4348 // denominator. Given that the denominator is positive, there are two
4349 // cases:
4350 //
4351 // 1. The numerator is positive. The result range is [0,C) and
4352 // [0,C) u< (1 << ceilLogBase2(C)).
4353 //
4354 // 2. The numerator is negative. Then the result range is (-C,0] and
4355 // integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
4356 //
4357 // Thus a lower bound on the number of sign bits is `TyBits -
4358 // ceilLogBase2(C)`.
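// For example, for i8 `srem X, 10`, ceilLogBase2(10) = 4, so the result
// lies in (-10, 10) and has at least 8 - 4 = 4 sign bits.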
4359
4360 unsigned ResBits = TyBits - Denominator->ceilLogBase2();
4361 Tmp = std::max(a: Tmp, b: ResBits);
4362 }
4363 }
4364 return Tmp;
4365 }
4366
4367 case Instruction::AShr: {
4368 Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4369 // ashr X, C -> adds C sign bits. Vectors too.
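// For example, `ashr i32 %x, 5` has at least 1 + 5 = 6 sign bits even if
// nothing is known about %x, capped at the type width.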
4370 const APInt *ShAmt;
4371 if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: ShAmt))) {
4372 if (ShAmt->uge(RHS: TyBits))
4373 break; // Bad shift.
4374 unsigned ShAmtLimited = ShAmt->getZExtValue();
4375 Tmp += ShAmtLimited;
4376 if (Tmp > TyBits) Tmp = TyBits;
4377 }
4378 return Tmp;
4379 }
4380 case Instruction::Shl: {
4381 const APInt *ShAmt;
4382 Value *X = nullptr;
4383 if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: ShAmt))) {
4384 // shl destroys sign bits.
4385 if (ShAmt->uge(RHS: TyBits))
4386 break; // Bad shift.
4387 // We can look through a zext (more or less treating it as a sext) if
4388 // all extended bits are shifted out.
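// For example, in `shl (zext i8 %x to i32), 24` all 24 extended bits are
// shifted out, so the result has ComputeNumSignBits(%x) + 24 - 24 sign
// bits, i.e. the same count as %x itself.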
4389 if (match(V: U->getOperand(i: 0), P: m_ZExt(Op: m_Value(V&: X))) &&
4390 ShAmt->uge(RHS: TyBits - X->getType()->getScalarSizeInBits())) {
4391 Tmp = ComputeNumSignBits(V: X, DemandedElts, Q, Depth: Depth + 1);
4392 Tmp += TyBits - X->getType()->getScalarSizeInBits();
4393 } else
4394 Tmp =
4395 ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4396 if (ShAmt->uge(RHS: Tmp))
4397 break; // Shifted all sign bits out.
4398 Tmp2 = ShAmt->getZExtValue();
4399 return Tmp - Tmp2;
4400 }
4401 break;
4402 }
4403 case Instruction::And:
4404 case Instruction::Or:
4405 case Instruction::Xor: // NOT is handled here.
4406 // Logical binary ops preserve the number of sign bits at worst.
4407 Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4408 if (Tmp != 1) {
4409 Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
4410 FirstAnswer = std::min(a: Tmp, b: Tmp2);
4411 // We computed what we know about the sign bits as our first
4412 // answer. Now proceed to the generic code that uses
4413 // computeKnownBits, and pick whichever answer is better.
4414 }
4415 break;
4416
4417 case Instruction::Select: {
4418 // If we have a clamp pattern, we know that the number of sign bits will
4419 // be the minimum of the clamp min/max range.
4420 const Value *X;
4421 const APInt *CLow, *CHigh;
4422 if (isSignedMinMaxClamp(Select: U, In&: X, CLow, CHigh))
4423 return std::min(a: CLow->getNumSignBits(), b: CHigh->getNumSignBits());
4424
4425 Tmp = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
4426 if (Tmp == 1)
4427 break;
4428 Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 2), DemandedElts, Q, Depth: Depth + 1);
4429 return std::min(a: Tmp, b: Tmp2);
4430 }
4431
4432 case Instruction::Add:
4433 // Add can have at most one carry bit. Thus we know that the output
4434 // is, at worst, one more bit than the inputs.
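// For example, adding two i8 values that each have at least 4 sign bits
// (so each lies in [-16, 15]) gives a sum in [-32, 30], which still has
// min(4, 4) - 1 = 3 sign bits.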
4435 Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);
4436 if (Tmp == 1) break;
4437
4438 // Special case decrementing a value (ADD X, -1):
4439 if (const auto *CRHS = dyn_cast<Constant>(Val: U->getOperand(i: 1)))
4440 if (CRHS->isAllOnesValue()) {
4441 KnownBits Known(TyBits);
4442 computeKnownBits(V: U->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
4443
4444 // If the input is known to be 0 or 1, the output is 0/-1, which is
4445 // all sign bits set.
4446 if ((Known.Zero | 1).isAllOnes())
4447 return TyBits;
4448
4449 // If we are subtracting one from a non-negative number, there is no carry
4450 // out of the result.
4451 if (Known.isNonNegative())
4452 return Tmp;
4453 }
4454
4455 Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
4456 if (Tmp2 == 1)
4457 break;
4458 return std::min(a: Tmp, b: Tmp2) - 1;
4459
4460 case Instruction::Sub:
4461 Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
4462 if (Tmp2 == 1)
4463 break;
4464
4465 // Handle NEG.
4466 if (const auto *CLHS = dyn_cast<Constant>(Val: U->getOperand(i: 0)))
4467 if (CLHS->isNullValue()) {
4468 KnownBits Known(TyBits);
4469 computeKnownBits(V: U->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
4470 // If the input is known to be 0 or 1, the output is 0/-1, which is
4471 // all sign bits set.
4472 if ((Known.Zero | 1).isAllOnes())
4473 return TyBits;
4474
4475 // If the input is known to be positive (the sign bit is known clear),
4476 // the output of the NEG has the same number of sign bits as the
4477 // input.
4478 if (Known.isNonNegative())
4479 return Tmp2;
4480
4481 // Otherwise, we treat this like a SUB.
4482 }
4483
4484 // Sub can have at most one carry bit. Thus we know that the output
4485 // is, at worst, one more bit than the inputs.
4486 Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4487 if (Tmp == 1)
4488 break;
4489 return std::min(a: Tmp, b: Tmp2) - 1;
4490
4491 case Instruction::Mul: {
4492 // The output of the Mul can be at most twice the valid bits in the
4493 // inputs.
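// For example, multiplying i32 operands with 28 and 30 sign bits (5 and 3
// valid bits) needs at most 5 + 3 = 8 valid bits, leaving at least
// 32 - 8 + 1 = 25 sign bits in the product.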
4494 unsigned SignBitsOp0 =
4495 ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4496 if (SignBitsOp0 == 1)
4497 break;
4498 unsigned SignBitsOp1 =
4499 ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
4500 if (SignBitsOp1 == 1)
4501 break;
4502 unsigned OutValidBits =
4503 (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
4504 return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
4505 }
4506
4507 case Instruction::PHI: {
4508 const PHINode *PN = cast<PHINode>(Val: U);
4509 unsigned NumIncomingValues = PN->getNumIncomingValues();
4510 // Don't analyze large in-degree PHIs.
4511 if (NumIncomingValues > 4) break;
4512 // Unreachable blocks may have zero-operand PHI nodes.
4513 if (NumIncomingValues == 0) break;
4514
4515 // Take the minimum of all incoming values. This can't infinitely loop
4516 // because of our depth threshold.
4517 SimplifyQuery RecQ = Q.getWithoutCondContext();
4518 Tmp = TyBits;
4519 for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
4520 if (Tmp == 1) return Tmp;
4521 RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
4522 Tmp = std::min(a: Tmp, b: ComputeNumSignBits(V: PN->getIncomingValue(i),
4523 DemandedElts, Q: RecQ, Depth: Depth + 1));
4524 }
4525 return Tmp;
4526 }
4527
4528 case Instruction::Trunc: {
4529 // If the input contained enough sign bits that some remain after the
4530 // truncation, then we can make use of that. Otherwise we don't know
4531 // anything.
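// For example, truncating an i64 with 40 known sign bits to i32 drops 32
// bits, leaving 40 - 32 = 8 sign bits in the result.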
4532 Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);
4533 unsigned OperandTyBits = U->getOperand(i: 0)->getType()->getScalarSizeInBits();
4534 if (Tmp > (OperandTyBits - TyBits))
4535 return Tmp - (OperandTyBits - TyBits);
4536
4537 return 1;
4538 }
4539
4540 case Instruction::ExtractElement:
4541 // Look through extract element. At the moment we keep this simple and
4542 // skip tracking the specific element. But at least we might find
4543 // information valid for all elements of the vector (for example, if the
4544 // vector is sign-extended, shifted, etc.).
4545 return ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);
4546
4547 case Instruction::ShuffleVector: {
4548 // Collect the minimum number of sign bits that are shared by every vector
4549 // element referenced by the shuffle.
4550 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: U);
4551 if (!Shuf) {
4552 // FIXME: Add support for shufflevector constant expressions.
4553 return 1;
4554 }
4555 APInt DemandedLHS, DemandedRHS;
4556 // For undef elements, we don't know anything about the common state of
4557 // the shuffle result.
4558 if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
4559 return 1;
4560 Tmp = std::numeric_limits<unsigned>::max();
4561 if (!!DemandedLHS) {
4562 const Value *LHS = Shuf->getOperand(i_nocapture: 0);
4563 Tmp = ComputeNumSignBits(V: LHS, DemandedElts: DemandedLHS, Q, Depth: Depth + 1);
4564 }
4565 // If we don't know anything, early out and try computeKnownBits
4566 // fall-back.
4567 if (Tmp == 1)
4568 break;
4569 if (!!DemandedRHS) {
4570 const Value *RHS = Shuf->getOperand(i_nocapture: 1);
4571 Tmp2 = ComputeNumSignBits(V: RHS, DemandedElts: DemandedRHS, Q, Depth: Depth + 1);
4572 Tmp = std::min(a: Tmp, b: Tmp2);
4573 }
4574 // If we don't know anything, early out and try computeKnownBits
4575 // fall-back.
4576 if (Tmp == 1)
4577 break;
4578 assert(Tmp <= TyBits && "Failed to determine minimum sign bits");
4579 return Tmp;
4580 }
4581 case Instruction::Call: {
4582 if (const auto *II = dyn_cast<IntrinsicInst>(Val: U)) {
4583 switch (II->getIntrinsicID()) {
4584 default:
4585 break;
4586 case Intrinsic::abs:
4587 Tmp =
4588 ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
4589 if (Tmp == 1)
4590 break;
4591
4592 // Absolute value reduces number of sign bits by at most 1.
4593 return Tmp - 1;
4594 case Intrinsic::smin:
4595 case Intrinsic::smax: {
4596 const APInt *CLow, *CHigh;
4597 if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
4598 return std::min(a: CLow->getNumSignBits(), b: CHigh->getNumSignBits());
4599 }
4600 }
4601 }
4602 }
4603 }
4604 }
4605
4606 // Finally, if we can prove that the top bits of the result are 0's or 1's,
4607 // use this information.
4608
4609 // If we can examine all elements of a vector constant successfully, we're
4610 // done (we can't do any better than that). If not, keep trying.
4611 if (unsigned VecSignBits =
4612 computeNumSignBitsVectorConstant(V, DemandedElts, TyBits))
4613 return VecSignBits;
4614
4615 KnownBits Known(TyBits);
4616 computeKnownBits(V, DemandedElts, Known, Q, Depth);
4617
4618 // If we know that the sign bit is either zero or one, determine the number of
4619 // identical bits in the top of the input value.
4620 return std::max(a: FirstAnswer, b: Known.countMinSignBits());
4621}
4622
4623Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
4624 const TargetLibraryInfo *TLI) {
4625 const Function *F = CB.getCalledFunction();
4626 if (!F)
4627 return Intrinsic::not_intrinsic;
4628
4629 if (F->isIntrinsic())
4630 return F->getIntrinsicID();
4631
4632 // We are going to infer semantics of a library function based on mapping it
4633 // to an LLVM intrinsic. Check that the library function is available at
4634 // this call site and in this environment.
4635 LibFunc Func;
4636 if (F->hasLocalLinkage() || !TLI || !TLI->getLibFunc(CB, F&: Func) ||
4637 !CB.onlyReadsMemory())
4638 return Intrinsic::not_intrinsic;
4639
4640 switch (Func) {
4641 default:
4642 break;
4643 case LibFunc_sin:
4644 case LibFunc_sinf:
4645 case LibFunc_sinl:
4646 return Intrinsic::sin;
4647 case LibFunc_cos:
4648 case LibFunc_cosf:
4649 case LibFunc_cosl:
4650 return Intrinsic::cos;
4651 case LibFunc_tan:
4652 case LibFunc_tanf:
4653 case LibFunc_tanl:
4654 return Intrinsic::tan;
4655 case LibFunc_asin:
4656 case LibFunc_asinf:
4657 case LibFunc_asinl:
4658 return Intrinsic::asin;
4659 case LibFunc_acos:
4660 case LibFunc_acosf:
4661 case LibFunc_acosl:
4662 return Intrinsic::acos;
4663 case LibFunc_atan:
4664 case LibFunc_atanf:
4665 case LibFunc_atanl:
4666 return Intrinsic::atan;
4667 case LibFunc_atan2:
4668 case LibFunc_atan2f:
4669 case LibFunc_atan2l:
4670 return Intrinsic::atan2;
4671 case LibFunc_sinh:
4672 case LibFunc_sinhf:
4673 case LibFunc_sinhl:
4674 return Intrinsic::sinh;
4675 case LibFunc_cosh:
4676 case LibFunc_coshf:
4677 case LibFunc_coshl:
4678 return Intrinsic::cosh;
4679 case LibFunc_tanh:
4680 case LibFunc_tanhf:
4681 case LibFunc_tanhl:
4682 return Intrinsic::tanh;
4683 case LibFunc_exp:
4684 case LibFunc_expf:
4685 case LibFunc_expl:
4686 return Intrinsic::exp;
4687 case LibFunc_exp2:
4688 case LibFunc_exp2f:
4689 case LibFunc_exp2l:
4690 return Intrinsic::exp2;
4691 case LibFunc_exp10:
4692 case LibFunc_exp10f:
4693 case LibFunc_exp10l:
4694 return Intrinsic::exp10;
4695 case LibFunc_log:
4696 case LibFunc_logf:
4697 case LibFunc_logl:
4698 return Intrinsic::log;
4699 case LibFunc_log10:
4700 case LibFunc_log10f:
4701 case LibFunc_log10l:
4702 return Intrinsic::log10;
4703 case LibFunc_log2:
4704 case LibFunc_log2f:
4705 case LibFunc_log2l:
4706 return Intrinsic::log2;
4707 case LibFunc_fabs:
4708 case LibFunc_fabsf:
4709 case LibFunc_fabsl:
4710 return Intrinsic::fabs;
4711 case LibFunc_fmin:
4712 case LibFunc_fminf:
4713 case LibFunc_fminl:
4714 return Intrinsic::minnum;
4715 case LibFunc_fmax:
4716 case LibFunc_fmaxf:
4717 case LibFunc_fmaxl:
4718 return Intrinsic::maxnum;
4719 case LibFunc_copysign:
4720 case LibFunc_copysignf:
4721 case LibFunc_copysignl:
4722 return Intrinsic::copysign;
4723 case LibFunc_floor:
4724 case LibFunc_floorf:
4725 case LibFunc_floorl:
4726 return Intrinsic::floor;
4727 case LibFunc_ceil:
4728 case LibFunc_ceilf:
4729 case LibFunc_ceill:
4730 return Intrinsic::ceil;
4731 case LibFunc_trunc:
4732 case LibFunc_truncf:
4733 case LibFunc_truncl:
4734 return Intrinsic::trunc;
4735 case LibFunc_rint:
4736 case LibFunc_rintf:
4737 case LibFunc_rintl:
4738 return Intrinsic::rint;
4739 case LibFunc_nearbyint:
4740 case LibFunc_nearbyintf:
4741 case LibFunc_nearbyintl:
4742 return Intrinsic::nearbyint;
4743 case LibFunc_round:
4744 case LibFunc_roundf:
4745 case LibFunc_roundl:
4746 return Intrinsic::round;
4747 case LibFunc_roundeven:
4748 case LibFunc_roundevenf:
4749 case LibFunc_roundevenl:
4750 return Intrinsic::roundeven;
4751 case LibFunc_pow:
4752 case LibFunc_powf:
4753 case LibFunc_powl:
4754 return Intrinsic::pow;
4755 case LibFunc_sqrt:
4756 case LibFunc_sqrtf:
4757 case LibFunc_sqrtl:
4758 return Intrinsic::sqrt;
4759 }
4760
4761 return Intrinsic::not_intrinsic;
4762}
4763
4764/// Given an exploded icmp instruction, return true if the comparison only
4765 /// checks the sign bit. If it only checks the sign bit, set TrueIfSigned to
4766 /// indicate whether the comparison is true exactly when the input's sign bit is set.
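/// For example, `icmp ult X, 0x80000000` on i32 only checks the sign bit and
/// is true exactly when X is non-negative, so TrueIfSigned is set to false.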
4767bool llvm::isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
4768 bool &TrueIfSigned) {
4769 switch (Pred) {
4770 case ICmpInst::ICMP_SLT: // True if LHS s< 0
4771 TrueIfSigned = true;
4772 return RHS.isZero();
4773 case ICmpInst::ICMP_SLE: // True if LHS s<= -1
4774 TrueIfSigned = true;
4775 return RHS.isAllOnes();
4776 case ICmpInst::ICMP_SGT: // True if LHS s> -1
4777 TrueIfSigned = false;
4778 return RHS.isAllOnes();
4779 case ICmpInst::ICMP_SGE: // True if LHS s>= 0
4780 TrueIfSigned = false;
4781 return RHS.isZero();
4782 case ICmpInst::ICMP_UGT:
4783 // True if LHS u> RHS and RHS == sign-bit-mask - 1
4784 TrueIfSigned = true;
4785 return RHS.isMaxSignedValue();
4786 case ICmpInst::ICMP_UGE:
4787 // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4788 TrueIfSigned = true;
4789 return RHS.isMinSignedValue();
4790 case ICmpInst::ICMP_ULT:
4791 // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4792 TrueIfSigned = false;
4793 return RHS.isMinSignedValue();
4794 case ICmpInst::ICMP_ULE:
4795 // True if LHS u<= RHS and RHS == sign-bit-mask - 1
4796 TrueIfSigned = false;
4797 return RHS.isMaxSignedValue();
4798 default:
4799 return false;
4800 }
4801}
4802
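/// Refine what is known about the FP classes of \p V from a condition \p Cond
/// that is assumed to evaluate to \p CondIsTrue. This recurses through logical
/// and/or/not and understands fcmp against a constant, llvm.is.fpclass tests,
/// and sign-bit icmp checks on an element-wise bitcast of \p V.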
4803static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
4804 bool CondIsTrue,
4805 const Instruction *CxtI,
4806 KnownFPClass &KnownFromContext,
4807 unsigned Depth = 0) {
4808 Value *A, *B;
4809 if (Depth < MaxAnalysisRecursionDepth &&
4810 (CondIsTrue ? match(V: Cond, P: m_LogicalAnd(L: m_Value(V&: A), R: m_Value(V&: B)))
4811 : match(V: Cond, P: m_LogicalOr(L: m_Value(V&: A), R: m_Value(V&: B))))) {
4812 computeKnownFPClassFromCond(V, Cond: A, CondIsTrue, CxtI, KnownFromContext,
4813 Depth: Depth + 1);
4814 computeKnownFPClassFromCond(V, Cond: B, CondIsTrue, CxtI, KnownFromContext,
4815 Depth: Depth + 1);
4816 return;
4817 }
4818 if (Depth < MaxAnalysisRecursionDepth && match(V: Cond, P: m_Not(V: m_Value(V&: A)))) {
4819 computeKnownFPClassFromCond(V, Cond: A, CondIsTrue: !CondIsTrue, CxtI, KnownFromContext,
4820 Depth: Depth + 1);
4821 return;
4822 }
4823 CmpPredicate Pred;
4824 Value *LHS;
4825 uint64_t ClassVal = 0;
4826 const APFloat *CRHS;
4827 const APInt *RHS;
4828 if (match(V: Cond, P: m_FCmp(Pred, L: m_Value(V&: LHS), R: m_APFloat(Res&: CRHS)))) {
4829 auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
4830 Pred, F: *cast<Instruction>(Val: Cond)->getParent()->getParent(), LHS, ConstRHS: *CRHS,
4831 LookThroughSrc: LHS != V);
4832 if (CmpVal == V)
4833 KnownFromContext.knownNot(RuleOut: ~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
4834 } else if (match(V: Cond, P: m_Intrinsic<Intrinsic::is_fpclass>(
4835 Op0: m_Specific(V), Op1: m_ConstantInt(V&: ClassVal)))) {
4836 FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
4837 KnownFromContext.knownNot(RuleOut: CondIsTrue ? ~Mask : Mask);
4838 } else if (match(V: Cond, P: m_ICmp(Pred, L: m_ElementWiseBitCast(Op: m_Specific(V)),
4839 R: m_APInt(Res&: RHS)))) {
4840 bool TrueIfSigned;
4841 if (!isSignBitCheck(Pred, RHS: *RHS, TrueIfSigned))
4842 return;
4843 if (TrueIfSigned == CondIsTrue)
4844 KnownFromContext.signBitMustBeOne();
4845 else
4846 KnownFromContext.signBitMustBeZero();
4847 }
4848}
4849
4850static KnownFPClass computeKnownFPClassFromContext(const Value *V,
4851 const SimplifyQuery &Q) {
4852 KnownFPClass KnownFromContext;
4853
4854 if (Q.CC && Q.CC->AffectedValues.contains(Ptr: V))
4855 computeKnownFPClassFromCond(V, Cond: Q.CC->Cond, CondIsTrue: !Q.CC->Invert, CxtI: Q.CxtI,
4856 KnownFromContext);
4857
4858 if (!Q.CxtI)
4859 return KnownFromContext;
4860
4861 if (Q.DC && Q.DT) {
4862 // Handle dominating conditions.
4863 for (BranchInst *BI : Q.DC->conditionsFor(V)) {
4864 Value *Cond = BI->getCondition();
4865
4866 BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
4867 if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
4868 computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/true, CxtI: Q.CxtI,
4869 KnownFromContext);
4870
4871 BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
4872 if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
4873 computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/false, CxtI: Q.CxtI,
4874 KnownFromContext);
4875 }
4876 }
4877
4878 if (!Q.AC)
4879 return KnownFromContext;
4880
4881 // Try to restrict the floating-point classes based on information from
4882 // assumptions.
4883 for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
4884 if (!AssumeVH)
4885 continue;
4886 CallInst *I = cast<CallInst>(Val&: AssumeVH);
4887
4888 assert(I->getFunction() == Q.CxtI->getParent()->getParent() &&
4889 "Got assumption for the wrong function!");
4890 assert(I->getIntrinsicID() == Intrinsic::assume &&
4891 "must be an assume intrinsic");
4892
4893 if (!isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
4894 continue;
4895
4896 computeKnownFPClassFromCond(V, Cond: I->getArgOperand(i: 0),
4897 /*CondIsTrue=*/true, CxtI: Q.CxtI, KnownFromContext);
4898 }
4899
4900 return KnownFromContext;
4901}
4902
4903void llvm::adjustKnownFPClassForSelectArm(KnownFPClass &Known, Value *Cond,
4904 Value *Arm, bool Invert,
4905 const SimplifyQuery &SQ,
4906 unsigned Depth) {
4907
4908 KnownFPClass KnownSrc;
4909 computeKnownFPClassFromCond(V: Arm, Cond,
4910 /*CondIsTrue=*/!Invert, CxtI: SQ.CxtI, KnownFromContext&: KnownSrc,
4911 Depth: Depth + 1);
4912 KnownSrc = KnownSrc.unionWith(RHS: Known);
4913 if (KnownSrc.isUnknown())
4914 return;
4915
4916 if (isGuaranteedNotToBeUndef(V: Arm, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT, Depth: Depth + 1))
4917 Known = KnownSrc;
4918}
4919
4920void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
4921 FPClassTest InterestedClasses, KnownFPClass &Known,
4922 const SimplifyQuery &Q, unsigned Depth);
4923
4924static void computeKnownFPClass(const Value *V, KnownFPClass &Known,
4925 FPClassTest InterestedClasses,
4926 const SimplifyQuery &Q, unsigned Depth) {
4927 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
4928 APInt DemandedElts =
4929 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
4930 computeKnownFPClass(V, DemandedElts, InterestedClasses, Known, Q, Depth);
4931}
4932
4933static void computeKnownFPClassForFPTrunc(const Operator *Op,
4934 const APInt &DemandedElts,
4935 FPClassTest InterestedClasses,
4936 KnownFPClass &Known,
4937 const SimplifyQuery &Q,
4938 unsigned Depth) {
4939 if ((InterestedClasses &
4940 (KnownFPClass::OrderedLessThanZeroMask | fcNan)) == fcNone)
4941 return;
4942
4943 KnownFPClass KnownSrc;
4944 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
4945 Known&: KnownSrc, Q, Depth: Depth + 1);
4946 Known = KnownFPClass::fptrunc(KnownSrc);
4947}
4948
4949static constexpr KnownFPClass::MinMaxKind getMinMaxKind(Intrinsic::ID IID) {
4950 switch (IID) {
4951 case Intrinsic::minimum:
4952 return KnownFPClass::MinMaxKind::minimum;
4953 case Intrinsic::maximum:
4954 return KnownFPClass::MinMaxKind::maximum;
4955 case Intrinsic::minimumnum:
4956 return KnownFPClass::MinMaxKind::minimumnum;
4957 case Intrinsic::maximumnum:
4958 return KnownFPClass::MinMaxKind::maximumnum;
4959 case Intrinsic::minnum:
4960 return KnownFPClass::MinMaxKind::minnum;
4961 case Intrinsic::maxnum:
4962 return KnownFPClass::MinMaxKind::maxnum;
4963 default:
4964 llvm_unreachable("not a floating-point min-max intrinsic");
4965 }
4966}
4967
4968void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
4969 FPClassTest InterestedClasses, KnownFPClass &Known,
4970 const SimplifyQuery &Q, unsigned Depth) {
4971 assert(Known.isUnknown() && "should not be called with known information");
4972
4973 if (!DemandedElts) {
4974 // No demanded elts, better to assume we don't know anything.
4975 Known.resetAll();
4976 return;
4977 }
4978
4979 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
4980
4981 if (auto *CFP = dyn_cast<ConstantFP>(Val: V)) {
4982 Known = KnownFPClass(CFP->getValueAPF());
4983 return;
4984 }
4985
4986 if (isa<ConstantAggregateZero>(Val: V)) {
4987 Known.KnownFPClasses = fcPosZero;
4988 Known.SignBit = false;
4989 return;
4990 }
4991
4992 if (isa<PoisonValue>(Val: V)) {
4993 Known.KnownFPClasses = fcNone;
4994 Known.SignBit = false;
4995 return;
4996 }
4997
4998 // Try to handle fixed width vector constants
4999 auto *VFVTy = dyn_cast<FixedVectorType>(Val: V->getType());
5000 const Constant *CV = dyn_cast<Constant>(Val: V);
5001 if (VFVTy && CV) {
5002 Known.KnownFPClasses = fcNone;
5003 bool SignBitAllZero = true;
5004 bool SignBitAllOne = true;
5005
5006 // For vectors, accumulate the classes of all demanded elements and track whether the sign bits all agree.
5007 unsigned NumElts = VFVTy->getNumElements();
5008 for (unsigned i = 0; i != NumElts; ++i) {
5009 if (!DemandedElts[i])
5010 continue;
5011
5012 Constant *Elt = CV->getAggregateElement(Elt: i);
5013 if (!Elt) {
5014 Known = KnownFPClass();
5015 return;
5016 }
5017 if (isa<PoisonValue>(Val: Elt))
5018 continue;
5019 auto *CElt = dyn_cast<ConstantFP>(Val: Elt);
5020 if (!CElt) {
5021 Known = KnownFPClass();
5022 return;
5023 }
5024
5025 const APFloat &C = CElt->getValueAPF();
5026 Known.KnownFPClasses |= C.classify();
5027 if (C.isNegative())
5028 SignBitAllZero = false;
5029 else
5030 SignBitAllOne = false;
5031 }
5032 if (SignBitAllOne != SignBitAllZero)
5033 Known.SignBit = SignBitAllOne;
5034 return;
5035 }
5036
5037 FPClassTest KnownNotFromFlags = fcNone;
5038 if (const auto *CB = dyn_cast<CallBase>(Val: V))
5039 KnownNotFromFlags |= CB->getRetNoFPClass();
5040 else if (const auto *Arg = dyn_cast<Argument>(Val: V))
5041 KnownNotFromFlags |= Arg->getNoFPClass();
5042
5043 const Operator *Op = dyn_cast<Operator>(Val: V);
5044 if (const FPMathOperator *FPOp = dyn_cast_or_null<FPMathOperator>(Val: Op)) {
5045 if (FPOp->hasNoNaNs())
5046 KnownNotFromFlags |= fcNan;
5047 if (FPOp->hasNoInfs())
5048 KnownNotFromFlags |= fcInf;
5049 }
5050
5051 KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
5052 KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
5053
5054 // We no longer need to find out about these bits from inputs if we can
5055 // assume this from flags/attributes.
5056 InterestedClasses &= ~KnownNotFromFlags;
5057
5058 llvm::scope_exit ClearClassesFromFlags([=, &Known] {
5059 Known.knownNot(RuleOut: KnownNotFromFlags);
5060 if (!Known.SignBit && AssumedClasses.SignBit) {
5061 if (*AssumedClasses.SignBit)
5062 Known.signBitMustBeOne();
5063 else
5064 Known.signBitMustBeZero();
5065 }
5066 });
5067
5068 if (!Op)
5069 return;
5070
5071 // All recursive calls that increase depth must come after this.
5072 if (Depth == MaxAnalysisRecursionDepth)
5073 return;
5074
5075 const unsigned Opc = Op->getOpcode();
5076 switch (Opc) {
5077 case Instruction::FNeg: {
5078 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
5079 Known, Q, Depth: Depth + 1);
5080 Known.fneg();
5081 break;
5082 }
5083 case Instruction::Select: {
5084 auto ComputeForArm = [&](Value *Arm, bool Invert) {
5085 KnownFPClass Res;
5086 computeKnownFPClass(V: Arm, DemandedElts, InterestedClasses, Known&: Res, Q,
5087 Depth: Depth + 1);
5088 adjustKnownFPClassForSelectArm(Known&: Res, Cond: Op->getOperand(i: 0), Arm, Invert, SQ: Q,
5089 Depth);
5090 return Res;
5091 };
5092 // Only known if known in both the LHS and RHS.
5093 Known =
5094 ComputeForArm(Op->getOperand(i: 1), /*Invert=*/false)
5095 .intersectWith(RHS: ComputeForArm(Op->getOperand(i: 2), /*Invert=*/true));
5096 break;
5097 }
5098 case Instruction::Load: {
5099 const MDNode *NoFPClass =
5100 cast<LoadInst>(Val: Op)->getMetadata(KindID: LLVMContext::MD_nofpclass);
5101 if (!NoFPClass)
5102 break;
5103
5104 ConstantInt *MaskVal =
5105 mdconst::extract<ConstantInt>(MD: NoFPClass->getOperand(I: 0));
5106 Known.knownNot(RuleOut: static_cast<FPClassTest>(MaskVal->getZExtValue()));
5107 break;
5108 }
5109 case Instruction::Call: {
5110 const CallInst *II = cast<CallInst>(Val: Op);
5111 const Intrinsic::ID IID = II->getIntrinsicID();
5112 switch (IID) {
5113 case Intrinsic::fabs: {
5114 if ((InterestedClasses & (fcNan | fcPositive)) != fcNone) {
5115 // If we only care about the sign bit we don't need to inspect the
5116 // operand.
5117 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5118 InterestedClasses, Known, Q, Depth: Depth + 1);
5119 }
5120
5121 Known.fabs();
5122 break;
5123 }
5124 case Intrinsic::copysign: {
5125 KnownFPClass KnownSign;
5126
5127 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5128 Known, Q, Depth: Depth + 1);
5129 computeKnownFPClass(V: II->getArgOperand(i: 1), DemandedElts, InterestedClasses,
5130 Known&: KnownSign, Q, Depth: Depth + 1);
5131 Known.copysign(Sign: KnownSign);
5132 break;
5133 }
5134 case Intrinsic::fma:
5135 case Intrinsic::fmuladd: {
5136 if ((InterestedClasses & fcNegative) == fcNone)
5137 break;
5138
5139 // FIXME: This should check isGuaranteedNotToBeUndef
5140 if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1)) {
5141 KnownFPClass KnownSrc, KnownAddend;
5142 computeKnownFPClass(V: II->getArgOperand(i: 2), DemandedElts,
5143 InterestedClasses, Known&: KnownAddend, Q, Depth: Depth + 1);
5144 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5145 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5146
5147 const Function *F = II->getFunction();
5148 const fltSemantics &FltSem =
5149 II->getType()->getScalarType()->getFltSemantics();
5150 DenormalMode Mode =
5151 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5152
5153 if (KnownNotFromFlags & fcNan) {
5154 KnownSrc.knownNot(RuleOut: fcNan);
5155 KnownAddend.knownNot(RuleOut: fcNan);
5156 }
5157
5158 if (KnownNotFromFlags & fcInf) {
5159 KnownSrc.knownNot(RuleOut: fcInf);
5160 KnownAddend.knownNot(RuleOut: fcInf);
5161 }
5162
5163 Known = KnownFPClass::fma_square(Squared: KnownSrc, Addend: KnownAddend, Mode);
5164 break;
5165 }
5166
5167 KnownFPClass KnownSrc[3];
5168 for (int I = 0; I != 3; ++I) {
5169 computeKnownFPClass(V: II->getArgOperand(i: I), DemandedElts,
5170 InterestedClasses, Known&: KnownSrc[I], Q, Depth: Depth + 1);
5171 if (KnownSrc[I].isUnknown())
5172 return;
5173
5174 if (KnownNotFromFlags & fcNan)
5175 KnownSrc[I].knownNot(RuleOut: fcNan);
5176 if (KnownNotFromFlags & fcInf)
5177 KnownSrc[I].knownNot(RuleOut: fcInf);
5178 }
5179
5180 const Function *F = II->getFunction();
5181 const fltSemantics &FltSem =
5182 II->getType()->getScalarType()->getFltSemantics();
5183 DenormalMode Mode =
5184 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5185 Known = KnownFPClass::fma(LHS: KnownSrc[0], RHS: KnownSrc[1], Addend: KnownSrc[2], Mode);
5186 break;
5187 }
5188 case Intrinsic::sqrt:
5189 case Intrinsic::experimental_constrained_sqrt: {
5190 KnownFPClass KnownSrc;
5191 FPClassTest InterestedSrcs = InterestedClasses;
5192 if (InterestedClasses & fcNan)
5193 InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5194
5195 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5196 Known&: KnownSrc, Q, Depth: Depth + 1);
5197
5198 DenormalMode Mode = DenormalMode::getDynamic();
5199
5200 bool HasNSZ = Q.IIQ.hasNoSignedZeros(Op: II);
5201 if (!HasNSZ) {
5202 const Function *F = II->getFunction();
5203 const fltSemantics &FltSem =
5204 II->getType()->getScalarType()->getFltSemantics();
5205 Mode = F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5206 }
5207
5208 Known = KnownFPClass::sqrt(Src: KnownSrc, Mode);
5209 if (HasNSZ)
5210 Known.knownNot(RuleOut: fcNegZero);
5211
5212 break;
5213 }
5214 case Intrinsic::sin:
5215 case Intrinsic::cos: {
5216 // Return NaN on infinite inputs.
5217 KnownFPClass KnownSrc;
5218 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5219 Known&: KnownSrc, Q, Depth: Depth + 1);
5220 Known = IID == Intrinsic::sin ? KnownFPClass::sin(Src: KnownSrc)
5221 : KnownFPClass::cos(Src: KnownSrc);
5222 break;
5223 }
5224 case Intrinsic::maxnum:
5225 case Intrinsic::minnum:
5226 case Intrinsic::minimum:
5227 case Intrinsic::maximum:
5228 case Intrinsic::minimumnum:
5229 case Intrinsic::maximumnum: {
5230 KnownFPClass KnownLHS, KnownRHS;
5231 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5232 Known&: KnownLHS, Q, Depth: Depth + 1);
5233 computeKnownFPClass(V: II->getArgOperand(i: 1), DemandedElts, InterestedClasses,
5234 Known&: KnownRHS, Q, Depth: Depth + 1);
5235
5236 const Function *F = II->getFunction();
5237
5238 DenormalMode Mode =
5239 F ? F->getDenormalMode(
5240 FPType: II->getType()->getScalarType()->getFltSemantics())
5241 : DenormalMode::getDynamic();
5242
5243 Known = KnownFPClass::minMaxLike(LHS: KnownLHS, RHS: KnownRHS, Kind: getMinMaxKind(IID),
5244 DenormMode: Mode);
5245 break;
5246 }
5247 case Intrinsic::canonicalize: {
5248 KnownFPClass KnownSrc;
5249 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5250 Known&: KnownSrc, Q, Depth: Depth + 1);
5251
5252 const Function *F = II->getFunction();
5253 DenormalMode DenormMode =
5254 F ? F->getDenormalMode(
5255 FPType: II->getType()->getScalarType()->getFltSemantics())
5256 : DenormalMode::getDynamic();
5257 Known = KnownFPClass::canonicalize(Src: KnownSrc, DenormMode);
5258 break;
5259 }
5260 case Intrinsic::vector_reduce_fmax:
5261 case Intrinsic::vector_reduce_fmin:
5262 case Intrinsic::vector_reduce_fmaximum:
5263 case Intrinsic::vector_reduce_fminimum: {
5264 // reduce min/max will choose an element from one of the vector elements,
5265 // so we can infer any class information that is common to all elements.
5266 Known = computeKnownFPClass(V: II->getArgOperand(i: 0), FMF: II->getFastMathFlags(),
5267 InterestedClasses, SQ: Q, Depth: Depth + 1);
5268 // Can only propagate sign if output is never NaN.
5269 if (!Known.isKnownNeverNaN())
5270 Known.SignBit.reset();
5271 break;
5272 }
5273 // reverse preserves all characteristics of the input vector's elements.
5274 case Intrinsic::vector_reverse:
5275 Known = computeKnownFPClass(
5276 V: II->getArgOperand(i: 0), DemandedElts: DemandedElts.reverseBits(),
5277 FMF: II->getFastMathFlags(), InterestedClasses, SQ: Q, Depth: Depth + 1);
5278 break;
5279 case Intrinsic::trunc:
5280 case Intrinsic::floor:
5281 case Intrinsic::ceil:
5282 case Intrinsic::rint:
5283 case Intrinsic::nearbyint:
5284 case Intrinsic::round:
5285 case Intrinsic::roundeven: {
5286 KnownFPClass KnownSrc;
5287 FPClassTest InterestedSrcs = InterestedClasses;
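// fcPosFinite/fcNegFinite are multi-bit masks (zero | subnormal | normal);
// if any one of those classes is interesting, widen the query to the whole
// finite mask, presumably because rounding can move a value between them.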
5288 if (InterestedSrcs & fcPosFinite)
5289 InterestedSrcs |= fcPosFinite;
5290 if (InterestedSrcs & fcNegFinite)
5291 InterestedSrcs |= fcNegFinite;
5292 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5293 Known&: KnownSrc, Q, Depth: Depth + 1);
5294
5295 Known = KnownFPClass::roundToIntegral(
5296 Src: KnownSrc, IsTrunc: IID == Intrinsic::trunc,
5297 IsMultiUnitFPType: V->getType()->getScalarType()->isMultiUnitFPType());
5298 break;
5299 }
5300 case Intrinsic::exp:
5301 case Intrinsic::exp2:
5302 case Intrinsic::exp10:
5303 case Intrinsic::amdgcn_exp2: {
5304 KnownFPClass KnownSrc;
5305 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5306 Known&: KnownSrc, Q, Depth: Depth + 1);
5307
5308 Known = KnownFPClass::exp(Src: KnownSrc);
5309
5310 Type *EltTy = II->getType()->getScalarType();
5311 if (IID == Intrinsic::amdgcn_exp2 && EltTy->isFloatTy())
5312 Known.knownNot(RuleOut: fcSubnormal);
5313
5314 break;
5315 }
5316 case Intrinsic::fptrunc_round: {
5317 computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known,
5318 Q, Depth);
5319 break;
5320 }
5321 case Intrinsic::log:
5322 case Intrinsic::log10:
5323 case Intrinsic::log2:
5324 case Intrinsic::experimental_constrained_log:
5325 case Intrinsic::experimental_constrained_log10:
5326 case Intrinsic::experimental_constrained_log2:
5327 case Intrinsic::amdgcn_log: {
5328 Type *EltTy = II->getType()->getScalarType();
5329
5330 // log(+inf) -> +inf
5331 // log([+-]0.0) -> -inf
5332 // log(-inf) -> nan
5333 // log(-x) -> nan
5334 if ((InterestedClasses & (fcNan | fcInf)) != fcNone) {
5335 FPClassTest InterestedSrcs = InterestedClasses;
5336 if ((InterestedClasses & fcNegInf) != fcNone)
5337 InterestedSrcs |= fcZero | fcSubnormal;
5338 if ((InterestedClasses & fcNan) != fcNone)
5339 InterestedSrcs |= fcNan | fcNegative;
5340
5341 KnownFPClass KnownSrc;
5342 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5343 Known&: KnownSrc, Q, Depth: Depth + 1);
5344
5345 const Function *F = II->getFunction();
5346 DenormalMode Mode = F ? F->getDenormalMode(FPType: EltTy->getFltSemantics())
5347 : DenormalMode::getDynamic();
5348 Known = KnownFPClass::log(Src: KnownSrc, Mode);
5349 }
5350
5351 break;
5352 }
5353 case Intrinsic::powi: {
5354 if ((InterestedClasses & fcNegative) == fcNone)
5355 break;
5356
5357 const Value *Exp = II->getArgOperand(i: 1);
5358 Type *ExpTy = Exp->getType();
5359 unsigned BitWidth = ExpTy->getScalarType()->getIntegerBitWidth();
5360 KnownBits ExponentKnownBits(BitWidth);
5361 computeKnownBits(V: Exp, DemandedElts: isa<VectorType>(Val: ExpTy) ? DemandedElts : APInt(1, 1),
5362 Known&: ExponentKnownBits, Q, Depth: Depth + 1);
5363
5364 KnownFPClass KnownSrc;
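// A known-even exponent means powi can never produce a negative result, so
// the sign of the base only matters when the exponent may be odd.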
5365 if (!ExponentKnownBits.isEven()) {
5366 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: fcNegative,
5367 Known&: KnownSrc, Q, Depth: Depth + 1);
5368 }
5369
5370 Known = KnownFPClass::powi(Src: KnownSrc, N: ExponentKnownBits);
5371 break;
5372 }
5373 case Intrinsic::ldexp: {
5374 KnownFPClass KnownSrc;
5375 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5376 Known&: KnownSrc, Q, Depth: Depth + 1);
5377 // Can refine inf/zero handling based on the exponent operand.
5378 const FPClassTest ExpInfoMask = fcZero | fcSubnormal | fcInf;
5379
5380 KnownBits ExpBits;
5381 if ((KnownSrc.KnownFPClasses & ExpInfoMask) != fcNone) {
5382 const Value *ExpArg = II->getArgOperand(i: 1);
5383 ExpBits = computeKnownBits(V: ExpArg, DemandedElts, Q, Depth: Depth + 1);
5384 }
5385
5386 const fltSemantics &Flt =
5387 II->getType()->getScalarType()->getFltSemantics();
5388
5389 const Function *F = II->getFunction();
5390 DenormalMode Mode =
5391 F ? F->getDenormalMode(FPType: Flt) : DenormalMode::getDynamic();
5392
5393 Known = KnownFPClass::ldexp(Src: KnownSrc, N: ExpBits, Flt, Mode);
5394 break;
5395 }
5396 case Intrinsic::arithmetic_fence: {
5397 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5398 Known, Q, Depth: Depth + 1);
5399 break;
5400 }
5401 case Intrinsic::experimental_constrained_sitofp:
5402 case Intrinsic::experimental_constrained_uitofp:
5403 // Cannot produce nan
5404 Known.knownNot(RuleOut: fcNan);
5405
5406 // sitofp and uitofp turn into +0.0 for zero.
5407 Known.knownNot(RuleOut: fcNegZero);
5408
5409 // Integers cannot be subnormal
5410 Known.knownNot(RuleOut: fcSubnormal);
5411
5412 if (IID == Intrinsic::experimental_constrained_uitofp)
5413 Known.signBitMustBeZero();
5414
5415 // TODO: Copy inf handling from instructions
5416 break;
5417 case Intrinsic::amdgcn_rcp: {
5418 KnownFPClass KnownSrc;
5419 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5420 Known&: KnownSrc, Q, Depth: Depth + 1);
5421
5422 Known.propagateNaN(Src: KnownSrc);
5423
5424 Type *EltTy = II->getType()->getScalarType();
5425
5426 // f32 denormal always flushed.
5427 if (EltTy->isFloatTy()) {
5428 Known.knownNot(RuleOut: fcSubnormal);
5429 KnownSrc.knownNot(RuleOut: fcSubnormal);
5430 }
5431
5432 if (KnownSrc.isKnownNever(Mask: fcNegative))
5433 Known.knownNot(RuleOut: fcNegative);
5434 if (KnownSrc.isKnownNever(Mask: fcPositive))
5435 Known.knownNot(RuleOut: fcPositive);
5436
5437 if (const Function *F = II->getFunction()) {
5438 DenormalMode Mode = F->getDenormalMode(FPType: EltTy->getFltSemantics());
5439 if (KnownSrc.isKnownNeverLogicalPosZero(Mode))
5440 Known.knownNot(RuleOut: fcPosInf);
5441 if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
5442 Known.knownNot(RuleOut: fcNegInf);
5443 }
5444
5445 break;
5446 }
5447 case Intrinsic::amdgcn_rsq: {
5448 KnownFPClass KnownSrc;
5449 // The only negative value that can be returned is -inf for -0 inputs.
5450 Known.knownNot(RuleOut: fcNegZero | fcNegSubnormal | fcNegNormal);
5451
5452 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5453 Known&: KnownSrc, Q, Depth: Depth + 1);
5454
5455 // Negative -> nan
5456 if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
5457 Known.knownNot(RuleOut: fcNan);
5458 else if (KnownSrc.isKnownNever(Mask: fcSNan))
5459 Known.knownNot(RuleOut: fcSNan);
5460
5461 // +inf -> +0
5462 if (KnownSrc.isKnownNeverPosInfinity())
5463 Known.knownNot(RuleOut: fcPosZero);
5464
5465 Type *EltTy = II->getType()->getScalarType();
5466
5467 // f32 denormal always flushed.
5468 if (EltTy->isFloatTy())
5469 Known.knownNot(RuleOut: fcPosSubnormal);
5470
5471 if (const Function *F = II->getFunction()) {
5472 DenormalMode Mode = F->getDenormalMode(FPType: EltTy->getFltSemantics());
5473
5474 // -0 -> -inf
5475 if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
5476 Known.knownNot(RuleOut: fcNegInf);
5477
5478 // +0 -> +inf
5479 if (KnownSrc.isKnownNeverLogicalPosZero(Mode))
5480 Known.knownNot(RuleOut: fcPosInf);
5481 }
5482
5483 break;
5484 }
5485 default:
5486 break;
5487 }
5488
5489 break;
5490 }
5491 case Instruction::FAdd:
5492 case Instruction::FSub: {
5493 KnownFPClass KnownLHS, KnownRHS;
5494 bool WantNegative =
5495 Op->getOpcode() == Instruction::FAdd &&
5496 (InterestedClasses & KnownFPClass::OrderedLessThanZeroMask) != fcNone;
5497 bool WantNaN = (InterestedClasses & fcNan) != fcNone;
5498 bool WantNegZero = (InterestedClasses & fcNegZero) != fcNone;
5499
5500 if (!WantNaN && !WantNegative && !WantNegZero)
5501 break;
5502
5503 FPClassTest InterestedSrcs = InterestedClasses;
5504 if (WantNegative)
5505 InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5506 if (InterestedClasses & fcNan)
5507 InterestedSrcs |= fcInf;
5508 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts, InterestedClasses: InterestedSrcs,
5509 Known&: KnownRHS, Q, Depth: Depth + 1);
5510
5511 // Special case fadd x, x, which is the canonical form of fmul x, 2.
5512 bool Self = Op->getOperand(i: 0) == Op->getOperand(i: 1) &&
5513 isGuaranteedNotToBeUndef(V: Op->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
5514 Depth: Depth + 1);
5515 if (Self)
5516 KnownLHS = KnownRHS;
5517
5518 if ((WantNaN && KnownRHS.isKnownNeverNaN()) ||
5519 (WantNegative && KnownRHS.cannotBeOrderedLessThanZero()) ||
5520 WantNegZero || Opc == Instruction::FSub) {
5521
5522 // FIXME: Context function should always be passed in separately
5523 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5524 const fltSemantics &FltSem =
5525 Op->getType()->getScalarType()->getFltSemantics();
5526 DenormalMode Mode =
5527 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5528
5529 if (Self && Opc == Instruction::FAdd) {
5530 Known = KnownFPClass::fadd_self(Src: KnownLHS, Mode);
5531 } else {
5532 // The RHS is canonically the cheaper operand to analyze (constants are
5533 // canonicalized to the RHS). Skip inspecting the LHS if there's no point.
5534
5535 if (!Self) {
5536 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5537 Known&: KnownLHS, Q, Depth: Depth + 1);
5538 }
5539
5540 Known = Opc == Instruction::FAdd
5541 ? KnownFPClass::fadd(LHS: KnownLHS, RHS: KnownRHS, Mode)
5542 : KnownFPClass::fsub(LHS: KnownLHS, RHS: KnownRHS, Mode);
5543 }
5544 }
5545
5546 break;
5547 }
5548 case Instruction::FMul: {
5549 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5550 DenormalMode Mode =
5551 F ? F->getDenormalMode(
5552 FPType: Op->getType()->getScalarType()->getFltSemantics())
5553 : DenormalMode::getDynamic();
5554
5555 // X * X is always non-negative or a NaN.
5556 // FIXME: Should check isGuaranteedNotToBeUndef
5557 if (Op->getOperand(i: 0) == Op->getOperand(i: 1)) {
5558 KnownFPClass KnownSrc;
5559 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownSrc,
5560 Q, Depth: Depth + 1);
5561 Known = KnownFPClass::square(Src: KnownSrc, Mode);
5562 break;
5563 }
5564
5565 KnownFPClass KnownLHS, KnownRHS;
5566
5567 bool CannotBeSubnormal = false;
5568 const APFloat *CRHS;
5569 if (match(V: Op->getOperand(i: 1), P: m_APFloat(Res&: CRHS))) {
5570 // Match denormal scaling pattern, similar to the case in ldexp. If the
5571 // constant's exponent is sufficiently large, the result cannot be
5572 // subnormal.
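// For example, with IEEE single precision (MantissaBits = 23), a constant of
// magnitude >= 2^23 scales even the smallest subnormal (2^-149) up to at
// least 2^-126, the smallest normal value, so the product is never subnormal.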
5573
5574 // TODO: Should do general ConstantFPRange analysis.
5575 const fltSemantics &Flt =
5576 Op->getType()->getScalarType()->getFltSemantics();
5577 unsigned Precision = APFloat::semanticsPrecision(Flt);
5578 const int MantissaBits = Precision - 1;
5579
5580 int MinKnownExponent = ilogb(Arg: *CRHS);
5581 if (MinKnownExponent >= MantissaBits)
5582 CannotBeSubnormal = true;
5583
5584 KnownRHS = KnownFPClass(*CRHS);
5585 } else {
5586 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownRHS,
5587 Q, Depth: Depth + 1);
5588 }
5589
5590 // TODO: Improve accuracy for the unfused FMA pattern. We can prove an
5591 // additional not-nan if the addend is known not to be negative infinity,
5592 // provided the multiply is known not to be infinity.
5593
5594 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownLHS,
5595 Q, Depth: Depth + 1);
5596
5597 Known = KnownFPClass::fmul(LHS: KnownLHS, RHS: KnownRHS, Mode);
5598 if (CannotBeSubnormal)
5599 Known.knownNot(RuleOut: fcSubnormal);
5600 break;
5601 }
5602 case Instruction::FDiv:
5603 case Instruction::FRem: {
5604 const bool WantNan = (InterestedClasses & fcNan) != fcNone;
5605
5606 if (Op->getOperand(i: 0) == Op->getOperand(i: 1) &&
5607 isGuaranteedNotToBeUndef(V: Op->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT)) {
5608 if (Op->getOpcode() == Instruction::FDiv) {
5609 // X / X is always exactly 1.0 or a NaN.
5610 Known.KnownFPClasses = fcNan | fcPosNormal;
5611 } else {
5612 // X % X is always exactly [+-]0.0 or a NaN.
5613 Known.KnownFPClasses = fcNan | fcZero;
5614 }
5615
5616 if (!WantNan)
5617 break;
5618
5619 KnownFPClass KnownSrc;
5620 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts,
5621 InterestedClasses: fcNan | fcInf | fcZero | fcSubnormal, Known&: KnownSrc, Q,
5622 Depth: Depth + 1);
5623 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5624 const fltSemantics &FltSem =
5625 Op->getType()->getScalarType()->getFltSemantics();
5626
5627 DenormalMode Mode =
5628 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5629
5630 Known = Op->getOpcode() == Instruction::FDiv
5631 ? KnownFPClass::fdiv_self(Src: KnownSrc, Mode)
5632 : KnownFPClass::frem_self(Src: KnownSrc, Mode);
5633 break;
5634 }
5635
5636 const bool WantNegative = (InterestedClasses & fcNegative) != fcNone;
5637 const bool WantPositive =
5638 Opc == Instruction::FRem && (InterestedClasses & fcPositive) != fcNone;
5639 if (!WantNan && !WantNegative && !WantPositive)
5640 break;
5641
5642 KnownFPClass KnownLHS, KnownRHS;
5643
5644 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts,
5645 InterestedClasses: fcNan | fcInf | fcZero | fcNegative, Known&: KnownRHS, Q,
5646 Depth: Depth + 1);
5647
5648 bool KnowSomethingUseful = KnownRHS.isKnownNeverNaN() ||
5649 KnownRHS.isKnownNever(Mask: fcNegative) ||
5650 KnownRHS.isKnownNever(Mask: fcPositive);
5651
5652 if (KnowSomethingUseful || WantPositive) {
5653 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownLHS,
5654 Q, Depth: Depth + 1);
5655 }
5656
5657 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5658 const fltSemantics &FltSem =
5659 Op->getType()->getScalarType()->getFltSemantics();
5660
5661 if (Op->getOpcode() == Instruction::FDiv) {
5662 DenormalMode Mode =
5663 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5664 Known = KnownFPClass::fdiv(LHS: KnownLHS, RHS: KnownRHS, Mode);
5665 } else {
5666 // Inf REM x and x REM 0 produce NaN.
5667 if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5668 KnownLHS.isKnownNeverInfinity() && F &&
5669 KnownRHS.isKnownNeverLogicalZero(Mode: F->getDenormalMode(FPType: FltSem))) {
5670 Known.knownNot(RuleOut: fcNan);
5671 }
5672
5673 // The sign for frem is the same as the first operand.
5674 if (KnownLHS.cannotBeOrderedLessThanZero())
5675 Known.knownNot(RuleOut: KnownFPClass::OrderedLessThanZeroMask);
5676 if (KnownLHS.cannotBeOrderedGreaterThanZero())
5677 Known.knownNot(RuleOut: KnownFPClass::OrderedGreaterThanZeroMask);
5678
5679 // See if we can be more aggressive about the sign of 0.
5680 if (KnownLHS.isKnownNever(Mask: fcNegative))
5681 Known.knownNot(RuleOut: fcNegative);
5682 if (KnownLHS.isKnownNever(Mask: fcPositive))
5683 Known.knownNot(RuleOut: fcPositive);
5684 }
5685
5686 break;
5687 }
5688 case Instruction::FPExt: {
5689 KnownFPClass KnownSrc;
5690 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
5691 Known&: KnownSrc, Q, Depth: Depth + 1);
5692
5693 const fltSemantics &DstTy =
5694 Op->getType()->getScalarType()->getFltSemantics();
5695 const fltSemantics &SrcTy =
5696 Op->getOperand(i: 0)->getType()->getScalarType()->getFltSemantics();
5697
5698 Known = KnownFPClass::fpext(KnownSrc, DstTy, SrcTy);
5699 break;
5700 }
5701 case Instruction::FPTrunc: {
5702 computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known, Q,
5703 Depth);
5704 break;
5705 }
5706 case Instruction::SIToFP:
5707 case Instruction::UIToFP: {
5708 // Cannot produce nan
5709 Known.knownNot(RuleOut: fcNan);
5710
5711 // Integers cannot be subnormal
5712 Known.knownNot(RuleOut: fcSubnormal);
5713
5714 // sitofp and uitofp turn into +0.0 for zero.
5715 Known.knownNot(RuleOut: fcNegZero);
5716 if (Op->getOpcode() == Instruction::UIToFP)
5717 Known.signBitMustBeZero();
5718
5719 if (InterestedClasses & fcInf) {
5720 // Get width of largest magnitude integer (remove a bit if signed).
5721 // This still works for a signed minimum value because the largest FP
5722 // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx).
5723 int IntSize = Op->getOperand(i: 0)->getType()->getScalarSizeInBits();
5724 if (Op->getOpcode() == Instruction::SIToFP)
5725 --IntSize;
5726
5727 // If the exponent of the largest finite FP value can hold the largest
5728 // integer, the result of the cast must be finite.
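// For example, `sitofp i32 %x to float`: IntSize becomes 31 and
// ilogb(FLT_MAX) is 127 >= 31, so the result can never be infinite.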
5729 Type *FPTy = Op->getType()->getScalarType();
5730 if (ilogb(Arg: APFloat::getLargest(Sem: FPTy->getFltSemantics())) >= IntSize)
5731 Known.knownNot(RuleOut: fcInf);
5732 }
5733
5734 break;
5735 }
5736 case Instruction::ExtractElement: {
5737 // Look through extract element. If the index is non-constant or
5738 // out-of-range demand all elements, otherwise just the extracted element.
5739 const Value *Vec = Op->getOperand(i: 0);
5740
5741 APInt DemandedVecElts;
5742 if (auto *VecTy = dyn_cast<FixedVectorType>(Val: Vec->getType())) {
5743 unsigned NumElts = VecTy->getNumElements();
5744 DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
5745 auto *CIdx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: 1));
5746 if (CIdx && CIdx->getValue().ult(RHS: NumElts))
5747 DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
5748 } else {
5749 DemandedVecElts = APInt(1, 1);
5750 }
5751
5752 return computeKnownFPClass(V: Vec, DemandedElts: DemandedVecElts, InterestedClasses, Known,
5753 Q, Depth: Depth + 1);
5754 }
5755 case Instruction::InsertElement: {
5756 if (isa<ScalableVectorType>(Val: Op->getType()))
5757 return;
5758
5759 const Value *Vec = Op->getOperand(i: 0);
5760 const Value *Elt = Op->getOperand(i: 1);
5761 auto *CIdx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: 2));
5762 unsigned NumElts = DemandedElts.getBitWidth();
5763 APInt DemandedVecElts = DemandedElts;
5764 bool NeedsElt = true;
5765 // If we know the index we are inserting to, clear it from Vec check.
5766 if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
5767 DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
5768 NeedsElt = DemandedElts[CIdx->getZExtValue()];
5769 }
5770
5771 // Do we demand the inserted element?
5772 if (NeedsElt) {
5773 computeKnownFPClass(V: Elt, Known, InterestedClasses, Q, Depth: Depth + 1);
5774 // If we don't know any bits, early out.
5775 if (Known.isUnknown())
5776 break;
5777 } else {
5778 Known.KnownFPClasses = fcNone;
5779 }
5780
5781 // Do we need any more elements from Vec?
5782 if (!DemandedVecElts.isZero()) {
5783 KnownFPClass Known2;
5784 computeKnownFPClass(V: Vec, DemandedElts: DemandedVecElts, InterestedClasses, Known&: Known2, Q,
5785 Depth: Depth + 1);
5786 Known |= Known2;
5787 }
5788
5789 break;
5790 }
5791 case Instruction::ShuffleVector: {
5792 // Handle vector splat idiom
5793 if (Value *Splat = getSplatValue(V)) {
5794 computeKnownFPClass(V: Splat, Known, InterestedClasses, Q, Depth: Depth + 1);
5795 break;
5796 }
5797
5798 // For undef elements, we don't know anything about the common state of
5799 // the shuffle result.
5800 APInt DemandedLHS, DemandedRHS;
5801 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: Op);
5802 if (!Shuf || !getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
5803 return;
5804
5805 if (!!DemandedLHS) {
5806 const Value *LHS = Shuf->getOperand(i_nocapture: 0);
5807 computeKnownFPClass(V: LHS, DemandedElts: DemandedLHS, InterestedClasses, Known, Q,
5808 Depth: Depth + 1);
5809
5810 // If we don't know any bits, early out.
5811 if (Known.isUnknown())
5812 break;
5813 } else {
5814 Known.KnownFPClasses = fcNone;
5815 }
5816
5817 if (!!DemandedRHS) {
5818 KnownFPClass Known2;
5819 const Value *RHS = Shuf->getOperand(i_nocapture: 1);
5820 computeKnownFPClass(V: RHS, DemandedElts: DemandedRHS, InterestedClasses, Known&: Known2, Q,
5821 Depth: Depth + 1);
5822 Known |= Known2;
5823 }
5824
5825 break;
5826 }
5827 case Instruction::ExtractValue: {
5828 const ExtractValueInst *Extract = cast<ExtractValueInst>(Val: Op);
5829 ArrayRef<unsigned> Indices = Extract->getIndices();
5830 const Value *Src = Extract->getAggregateOperand();
5831 if (isa<StructType>(Val: Src->getType()) && Indices.size() == 1 &&
5832 Indices[0] == 0) {
5833 if (const auto *II = dyn_cast<IntrinsicInst>(Val: Src)) {
5834 switch (II->getIntrinsicID()) {
5835 case Intrinsic::frexp: {
5836 Known.knownNot(RuleOut: fcSubnormal);
5837
5838 KnownFPClass KnownSrc;
5839 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5840 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5841
5842 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5843 const fltSemantics &FltSem =
5844 Op->getType()->getScalarType()->getFltSemantics();
5845
5846 DenormalMode Mode =
5847 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5848 Known = KnownFPClass::frexp_mant(Src: KnownSrc, Mode);
5849 return;
5850 }
5851 default:
5852 break;
5853 }
5854 }
5855 }
5856
5857 computeKnownFPClass(V: Src, DemandedElts, InterestedClasses, Known, Q,
5858 Depth: Depth + 1);
5859 break;
5860 }
5861 case Instruction::PHI: {
5862 const PHINode *P = cast<PHINode>(Val: Op);
5863 // Unreachable blocks may have zero-operand PHI nodes.
5864 if (P->getNumIncomingValues() == 0)
5865 break;
5866
5867 // Otherwise take the union of the known FP classes of the incoming values,
5868 // taking conservative care to avoid excessive recursion.
5869 const unsigned PhiRecursionLimit = MaxAnalysisRecursionDepth - 2;
5870
5871 if (Depth < PhiRecursionLimit) {
5872 // Skip if every incoming value is undef or a reference to this PHI itself.
5873 if (isa_and_nonnull<UndefValue>(Val: P->hasConstantValue()))
5874 break;
5875
5876 bool First = true;
5877
5878 for (const Use &U : P->operands()) {
5879 Value *IncValue;
5880 Instruction *CxtI;
5881 breakSelfRecursivePHI(U: &U, PHI: P, ValOut&: IncValue, CtxIOut&: CxtI);
5882 // Skip direct self references.
5883 if (IncValue == P)
5884 continue;
5885
5886 KnownFPClass KnownSrc;
5887 // Recurse, but cap the recursion to two levels, because we don't want
5888 // to waste time spinning around in loops. We need at least depth 2 to
5889 // detect known sign bits.
5890 computeKnownFPClass(V: IncValue, DemandedElts, InterestedClasses, Known&: KnownSrc,
5891 Q: Q.getWithoutCondContext().getWithInstruction(I: CxtI),
5892 Depth: PhiRecursionLimit);
5893
5894 if (First) {
5895 Known = KnownSrc;
5896 First = false;
5897 } else {
5898 Known |= KnownSrc;
5899 }
5900
5901 if (Known.KnownFPClasses == fcAllFlags)
5902 break;
5903 }
5904 }
5905
5906 break;
5907 }
5908 case Instruction::BitCast: {
5909 const Value *Src;
5910 if (!match(V: Op, P: m_ElementWiseBitCast(Op: m_Value(V&: Src))) ||
5911 !Src->getType()->isIntOrIntVectorTy())
5912 break;
5913
5914 const Type *Ty = Op->getType()->getScalarType();
5915 KnownBits Bits(Ty->getScalarSizeInBits());
5916 computeKnownBits(V: Src, DemandedElts, Known&: Bits, Q, Depth: Depth + 1);
5917
5918 // Transfer information from the sign bit.
5919 if (Bits.isNonNegative())
5920 Known.signBitMustBeZero();
5921 else if (Bits.isNegative())
5922 Known.signBitMustBeOne();
5923
5924 if (Ty->isIEEELikeFPTy()) {
5925 // An IEEE float is a NaN when all exponent bits are 1 and at least one
5926 // fraction bit is 1. This means:
5927 //   - If we assume unknown bits are 0 and the value is NaN, it will
5928 //     always be NaN
5929 //   - If we assume unknown bits are 1 and the value is not NaN, it can
5930 //     never be NaN
5931 // Note: These properties do not hold for the x86_fp80 format.
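// For example, in IEEE binary32 a value is a NaN iff bits 30..23 (the
// exponent) are all 1 and bits 22..0 (the fraction) are not all 0.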
5932 if (APFloat(Ty->getFltSemantics(), Bits.One).isNaN())
5933 Known.KnownFPClasses = fcNan;
5934 else if (!APFloat(Ty->getFltSemantics(), ~Bits.Zero).isNaN())
5935 Known.knownNot(RuleOut: fcNan);
5936
5937 // Build KnownBits representing Inf and check if it must be equal or
5938 // unequal to this value.
5939 auto InfKB = KnownBits::makeConstant(
5940 C: APFloat::getInf(Sem: Ty->getFltSemantics()).bitcastToAPInt());
5941 InfKB.Zero.clearSignBit();
5942 if (const auto InfResult = KnownBits::eq(LHS: Bits, RHS: InfKB)) {
5943 assert(!InfResult.value());
5944 Known.knownNot(RuleOut: fcInf);
5945 } else if (Bits == InfKB) {
5946 Known.KnownFPClasses = fcInf;
5947 }
5948
5949 // Build KnownBits representing Zero and check if it must be equal or
5950 // unequal to this value.
5951 auto ZeroKB = KnownBits::makeConstant(
5952 C: APFloat::getZero(Sem: Ty->getFltSemantics()).bitcastToAPInt());
5953 ZeroKB.Zero.clearSignBit();
5954 if (const auto ZeroResult = KnownBits::eq(LHS: Bits, RHS: ZeroKB)) {
5955 assert(!ZeroResult.value());
5956 Known.knownNot(RuleOut: fcZero);
5957 } else if (Bits == ZeroKB) {
5958 Known.KnownFPClasses = fcZero;
5959 }
5960 }
5961
5962 break;
5963 }
5964 default:
5965 break;
5966 }
5967}
5968
5969KnownFPClass llvm::computeKnownFPClass(const Value *V,
5970 const APInt &DemandedElts,
5971 FPClassTest InterestedClasses,
5972 const SimplifyQuery &SQ,
5973 unsigned Depth) {
5974 KnownFPClass KnownClasses;
5975 ::computeKnownFPClass(V, DemandedElts, InterestedClasses, Known&: KnownClasses, Q: SQ,
5976 Depth);
5977 return KnownClasses;
5978}
5979
5980KnownFPClass llvm::computeKnownFPClass(const Value *V,
5981 FPClassTest InterestedClasses,
5982 const SimplifyQuery &SQ,
5983 unsigned Depth) {
5984 KnownFPClass Known;
5985 ::computeKnownFPClass(V, Known, InterestedClasses, Q: SQ, Depth);
5986 return Known;
5987}
5988
5989KnownFPClass llvm::computeKnownFPClass(
5990 const Value *V, const DataLayout &DL, FPClassTest InterestedClasses,
5991 const TargetLibraryInfo *TLI, AssumptionCache *AC, const Instruction *CxtI,
5992 const DominatorTree *DT, bool UseInstrInfo, unsigned Depth) {
5993 return computeKnownFPClass(V, InterestedClasses,
5994 SQ: SimplifyQuery(DL, TLI, DT, AC, CxtI, UseInstrInfo),
5995 Depth);
5996}
5997
5998KnownFPClass
5999llvm::computeKnownFPClass(const Value *V, const APInt &DemandedElts,
6000 FastMathFlags FMF, FPClassTest InterestedClasses,
6001 const SimplifyQuery &SQ, unsigned Depth) {
6002 if (FMF.noNaNs())
6003 InterestedClasses &= ~fcNan;
6004 if (FMF.noInfs())
6005 InterestedClasses &= ~fcInf;
6006
6007 KnownFPClass Result =
6008 computeKnownFPClass(V, DemandedElts, InterestedClasses, SQ, Depth);
6009
6010 if (FMF.noNaNs())
6011 Result.KnownFPClasses &= ~fcNan;
6012 if (FMF.noInfs())
6013 Result.KnownFPClasses &= ~fcInf;
6014 return Result;
6015}
6016
6017KnownFPClass llvm::computeKnownFPClass(const Value *V, FastMathFlags FMF,
6018 FPClassTest InterestedClasses,
6019 const SimplifyQuery &SQ,
6020 unsigned Depth) {
6021 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
6022 APInt DemandedElts =
6023 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
6024 return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, SQ,
6025 Depth);
6026}
6027
6028bool llvm::cannotBeNegativeZero(const Value *V, const SimplifyQuery &SQ,
6029 unsigned Depth) {
6030 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcNegZero, SQ, Depth);
6031 return Known.isKnownNeverNegZero();
6032}
6033
6034bool llvm::cannotBeOrderedLessThanZero(const Value *V, const SimplifyQuery &SQ,
6035 unsigned Depth) {
6036 KnownFPClass Known =
6037 computeKnownFPClass(V, InterestedClasses: KnownFPClass::OrderedLessThanZeroMask, SQ, Depth);
6038 return Known.cannotBeOrderedLessThanZero();
6039}
6040
6041bool llvm::isKnownNeverInfinity(const Value *V, const SimplifyQuery &SQ,
6042 unsigned Depth) {
6043 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcInf, SQ, Depth);
6044 return Known.isKnownNeverInfinity();
6045}
6046
6047/// Return true if the floating-point value can never contain a NaN or infinity.
6048bool llvm::isKnownNeverInfOrNaN(const Value *V, const SimplifyQuery &SQ,
6049 unsigned Depth) {
6050 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcInf | fcNan, SQ, Depth);
6051 return Known.isKnownNeverNaN() && Known.isKnownNeverInfinity();
6052}
6053
6054/// Return true if the floating-point scalar value is not a NaN or if the
6055/// floating-point vector value has no NaN elements. Return false if a value
6056/// could ever be NaN.
6057bool llvm::isKnownNeverNaN(const Value *V, const SimplifyQuery &SQ,
6058 unsigned Depth) {
6059 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcNan, SQ, Depth);
6060 return Known.isKnownNeverNaN();
6061}
6062
6063/// Return false if we can prove that the specified FP value's sign bit is 0.
6064/// Return true if we can prove that the specified FP value's sign bit is 1.
6065/// Otherwise return std::nullopt.
6066std::optional<bool> llvm::computeKnownFPSignBit(const Value *V,
6067 const SimplifyQuery &SQ,
6068 unsigned Depth) {
6069 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcAllFlags, SQ, Depth);
6070 return Known.SignBit;
6071}
6072
6073bool llvm::canIgnoreSignBitOfZero(const Use &U) {
6074 auto *User = cast<Instruction>(Val: U.getUser());
6075 if (auto *FPOp = dyn_cast<FPMathOperator>(Val: User)) {
6076 if (FPOp->hasNoSignedZeros())
6077 return true;
6078 }
6079
6080 switch (User->getOpcode()) {
6081 case Instruction::FPToSI:
6082 case Instruction::FPToUI:
6083 return true;
6084 case Instruction::FCmp:
6085 // fcmp treats both positive and negative zero as equal.
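// For example, "fcmp oeq float 0.0, -0.0" evaluates to true.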
6086 return true;
6087 case Instruction::Call:
6088 if (auto *II = dyn_cast<IntrinsicInst>(Val: User)) {
6089 switch (II->getIntrinsicID()) {
6090 case Intrinsic::fabs:
6091 return true;
6092 case Intrinsic::copysign:
6093 return U.getOperandNo() == 0;
6094 case Intrinsic::is_fpclass:
6095 case Intrinsic::vp_is_fpclass: {
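// The sign of a zero input only matters if the class test can distinguish
// +0.0 from -0.0, i.e. if exactly one of fcPosZero/fcNegZero is tested.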
6096 auto Test =
6097 static_cast<FPClassTest>(
6098 cast<ConstantInt>(Val: II->getArgOperand(i: 1))->getZExtValue()) &
6099 FPClassTest::fcZero;
6100 return Test == FPClassTest::fcZero || Test == FPClassTest::fcNone;
6101 }
6102 default:
6103 return false;
6104 }
6105 }
6106 return false;
6107 default:
6108 return false;
6109 }
6110}
6111
6112bool llvm::canIgnoreSignBitOfNaN(const Use &U) {
6113 auto *User = cast<Instruction>(Val: U.getUser());
6114 if (auto *FPOp = dyn_cast<FPMathOperator>(Val: User)) {
6115 if (FPOp->hasNoNaNs())
6116 return true;
6117 }
6118
6119 switch (User->getOpcode()) {
6120 case Instruction::FPToSI:
6121 case Instruction::FPToUI:
6122 return true;
6123 // Proper FP math operations ignore the sign bit of NaN.
6124 case Instruction::FAdd:
6125 case Instruction::FSub:
6126 case Instruction::FMul:
6127 case Instruction::FDiv:
6128 case Instruction::FRem:
6129 case Instruction::FPTrunc:
6130 case Instruction::FPExt:
6131 case Instruction::FCmp:
6132 return true;
6133 // Bitwise FP operations should preserve the sign bit of NaN.
6134 case Instruction::FNeg:
6135 case Instruction::Select:
6136 case Instruction::PHI:
6137 return false;
6138 case Instruction::Ret:
6139 return User->getFunction()->getAttributes().getRetNoFPClass() &
6140 FPClassTest::fcNan;
6141 case Instruction::Call:
6142 case Instruction::Invoke: {
6143 if (auto *II = dyn_cast<IntrinsicInst>(Val: User)) {
6144 switch (II->getIntrinsicID()) {
6145 case Intrinsic::fabs:
6146 return true;
6147 case Intrinsic::copysign:
6148 return U.getOperandNo() == 0;
6149 // Other proper FP math intrinsics ignore the sign bit of NaN.
6150 case Intrinsic::maxnum:
6151 case Intrinsic::minnum:
6152 case Intrinsic::maximum:
6153 case Intrinsic::minimum:
6154 case Intrinsic::maximumnum:
6155 case Intrinsic::minimumnum:
6156 case Intrinsic::canonicalize:
6157 case Intrinsic::fma:
6158 case Intrinsic::fmuladd:
6159 case Intrinsic::sqrt:
6160 case Intrinsic::pow:
6161 case Intrinsic::powi:
6162 case Intrinsic::fptoui_sat:
6163 case Intrinsic::fptosi_sat:
6164 case Intrinsic::is_fpclass:
6165 case Intrinsic::vp_is_fpclass:
6166 return true;
6167 default:
6168 return false;
6169 }
6170 }
6171
6172 FPClassTest NoFPClass =
6173 cast<CallBase>(Val: User)->getParamNoFPClass(i: U.getOperandNo());
6174 return NoFPClass & FPClassTest::fcNan;
6175 }
6176 default:
6177 return false;
6178 }
6179}
6180
6181bool llvm::isKnownIntegral(const Value *V, const SimplifyQuery &SQ,
6182 FastMathFlags FMF) {
6183 if (isa<PoisonValue>(Val: V))
6184 return true;
6185 if (isa<UndefValue>(Val: V))
6186 return false;
6187
6188 if (match(V, P: m_CheckedFp(CheckFn: [](const APFloat &Val) { return Val.isInteger(); })))
6189 return true;
6190
6191 const Instruction *I = dyn_cast<Instruction>(Val: V);
6192 if (!I)
6193 return false;
6194
6195 switch (I->getOpcode()) {
6196 case Instruction::SIToFP:
6197 case Instruction::UIToFP:
6198 // TODO: Could check nofpclass(inf) on incoming argument
6199 if (FMF.noInfs())
6200 return true;
6201
6202 // Need to check that the integer width cannot produce an infinity, which
6203 // computeKnownFPClass already knows how to do.
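// For example, "uitofp i1024 %x to half" can round up to +inf, while
// "uitofp i16 %x to float" never can.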
6204 return isKnownNeverInfinity(V: I, SQ);
6205 case Instruction::Call: {
6206 const CallInst *CI = cast<CallInst>(Val: I);
6207 switch (CI->getIntrinsicID()) {
6208 case Intrinsic::trunc:
6209 case Intrinsic::floor:
6210 case Intrinsic::ceil:
6211 case Intrinsic::rint:
6212 case Intrinsic::nearbyint:
6213 case Intrinsic::round:
6214 case Intrinsic::roundeven:
6215 return (FMF.noInfs() && FMF.noNaNs()) || isKnownNeverInfOrNaN(V: I, SQ);
6216 default:
6217 break;
6218 }
6219
6220 break;
6221 }
6222 default:
6223 break;
6224 }
6225
6226 return false;
6227}
6228
6229Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {
6230
6231 // All byte-wide stores are splatable, even of arbitrary variables.
6232 if (V->getType()->isIntegerTy(Bitwidth: 8))
6233 return V;
6234
6235 LLVMContext &Ctx = V->getContext();
6236
6237 // Undef is a don't-care value.
6238 auto *UndefInt8 = UndefValue::get(T: Type::getInt8Ty(C&: Ctx));
6239 if (isa<UndefValue>(Val: V))
6240 return UndefInt8;
6241
6242 // Return poison for zero-sized type.
6243 if (DL.getTypeStoreSize(Ty: V->getType()).isZero())
6244 return PoisonValue::get(T: Type::getInt8Ty(C&: Ctx));
6245
6246 Constant *C = dyn_cast<Constant>(Val: V);
6247 if (!C) {
6248 // Conceptually, we could handle things like:
6249 // %a = zext i8 %X to i16
6250 // %b = shl i16 %a, 8
6251 // %c = or i16 %a, %b
6252 // but until there is an example that actually needs this, it doesn't seem
6253 // worth worrying about.
6254 return nullptr;
6255 }
6256
6257 // Handle 'null' constants (ConstantAggregateZero etc.).
6258 if (C->isNullValue())
6259 return Constant::getNullValue(Ty: Type::getInt8Ty(C&: Ctx));
6260
6261 // Constant floating-point values can be handled as integer values if the
6262 // corresponding integer value is "byteable". An important case is 0.0.
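// For example, float 0.0 (bit pattern 0x00000000) maps to the byte 0x00,
// while float 1.0 (0x3F800000) has no repeating byte and yields nullptr.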
6263 if (ConstantFP *CFP = dyn_cast<ConstantFP>(Val: C)) {
6264 Type *Ty = nullptr;
6265 if (CFP->getType()->isHalfTy())
6266 Ty = Type::getInt16Ty(C&: Ctx);
6267 else if (CFP->getType()->isFloatTy())
6268 Ty = Type::getInt32Ty(C&: Ctx);
6269 else if (CFP->getType()->isDoubleTy())
6270 Ty = Type::getInt64Ty(C&: Ctx);
6271 // Don't handle long double formats, which have strange constraints.
6272 return Ty ? isBytewiseValue(V: ConstantExpr::getBitCast(C: CFP, Ty), DL)
6273 : nullptr;
6274 }
6275
6276 // We can handle constant integers whose width is a multiple of 8 bits.
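// For example, i32 0xABABABAB yields the byte 0xAB, while i32 0x01020304 is
// not a splat of a single byte and yields nullptr.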
6277 if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: C)) {
6278 if (CI->getBitWidth() % 8 == 0) {
6279 assert(CI->getBitWidth() > 8 && "8 bits should be handled above!");
6280 if (!CI->getValue().isSplat(SplatSizeInBits: 8))
6281 return nullptr;
6282 return ConstantInt::get(Context&: Ctx, V: CI->getValue().trunc(width: 8));
6283 }
6284 }
6285
6286 if (auto *CE = dyn_cast<ConstantExpr>(Val: C)) {
6287 if (CE->getOpcode() == Instruction::IntToPtr) {
6288 if (auto *PtrTy = dyn_cast<PointerType>(Val: CE->getType())) {
6289 unsigned BitWidth = DL.getPointerSizeInBits(AS: PtrTy->getAddressSpace());
6290 if (Constant *Op = ConstantFoldIntegerCast(
6291 C: CE->getOperand(i_nocapture: 0), DestTy: Type::getIntNTy(C&: Ctx, N: BitWidth), IsSigned: false, DL))
6292 return isBytewiseValue(V: Op, DL);
6293 }
6294 }
6295 }
6296
6297 auto Merge = [&](Value *LHS, Value *RHS) -> Value * {
6298 if (LHS == RHS)
6299 return LHS;
6300 if (!LHS || !RHS)
6301 return nullptr;
6302 if (LHS == UndefInt8)
6303 return RHS;
6304 if (RHS == UndefInt8)
6305 return LHS;
6306 return nullptr;
6307 };
6308
6309 if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(Val: C)) {
6310 Value *Val = UndefInt8;
6311 for (uint64_t I = 0, E = CA->getNumElements(); I != E; ++I)
6312 if (!(Val = Merge(Val, isBytewiseValue(V: CA->getElementAsConstant(i: I), DL))))
6313 return nullptr;
6314 return Val;
6315 }
6316
6317 if (isa<ConstantAggregate>(Val: C)) {
6318 Value *Val = UndefInt8;
6319 for (Value *Op : C->operands())
6320 if (!(Val = Merge(Val, isBytewiseValue(V: Op, DL))))
6321 return nullptr;
6322 return Val;
6323 }
6324
6325 // Don't try to handle the handful of other constants.
6326 return nullptr;
6327}
6328
6329// This is the recursive version of BuildSubAggregate. It takes a few different
6330// arguments. Idxs is the index within the nested struct From that we are
6331// looking at now (which is of type IndexedType). IdxSkip is the number of
6332// indices from Idxs that should be left out when inserting into the resulting
6333 // struct. To is the result struct built so far; new insertvalue instructions
6334 // build on it.
6335static Value *BuildSubAggregate(Value *From, Value *To, Type *IndexedType,
6336 SmallVectorImpl<unsigned> &Idxs,
6337 unsigned IdxSkip,
6338 BasicBlock::iterator InsertBefore) {
6339 StructType *STy = dyn_cast<StructType>(Val: IndexedType);
6340 if (STy) {
6341 // Save the original To argument so we can modify it
6342 Value *OrigTo = To;
6343 // General case, the type indexed by Idxs is a struct
6344 for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
6345 // Process each struct element recursively
6346 Idxs.push_back(Elt: i);
6347 Value *PrevTo = To;
6348 To = BuildSubAggregate(From, To, IndexedType: STy->getElementType(N: i), Idxs, IdxSkip,
6349 InsertBefore);
6350 Idxs.pop_back();
6351 if (!To) {
6352 // Couldn't find any inserted value for this index? Cleanup
6353 while (PrevTo != OrigTo) {
6354 InsertValueInst* Del = cast<InsertValueInst>(Val: PrevTo);
6355 PrevTo = Del->getAggregateOperand();
6356 Del->eraseFromParent();
6357 }
6358 // Stop processing elements
6359 break;
6360 }
6361 }
6362 // If we successfully found a value for each of our subaggregates
6363 if (To)
6364 return To;
6365 }
6366 // Base case, the type indexed by Idxs is not a struct, or not all of
6367 // the struct's elements had a value that was inserted directly. In the latter
6368 // case, perhaps we can't determine each of the subelements individually, but
6369 // we might be able to find the complete struct somewhere.
6370
6371 // Find the value that is at that particular spot
6372 Value *V = FindInsertedValue(V: From, idx_range: Idxs);
6373
6374 if (!V)
6375 return nullptr;
6376
6377 // Insert the value in the new (sub) aggregate
6378 return InsertValueInst::Create(Agg: To, Val: V, Idxs: ArrayRef(Idxs).slice(N: IdxSkip), NameStr: "tmp",
6379 InsertBefore);
6380}
6381
6382// This helper takes a nested struct and extracts a part of it (which is again a
6383// struct) into a new value. For example, given the struct:
6384// { a, { b, { c, d }, e } }
6385// and the indices "1, 1" this returns
6386// { c, d }.
6387//
6388// It does this by inserting an insertvalue for each element in the resulting
6389// struct, as opposed to just inserting a single struct. This will only work if
6390 // each of the elements of the substruct is known (i.e., inserted into From by an
6391// insertvalue instruction somewhere).
6392//
6393// All inserted insertvalue instructions are inserted before InsertBefore
6394static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range,
6395 BasicBlock::iterator InsertBefore) {
6396 Type *IndexedType = ExtractValueInst::getIndexedType(Agg: From->getType(),
6397 Idxs: idx_range);
6398 Value *To = PoisonValue::get(T: IndexedType);
6399 SmallVector<unsigned, 10> Idxs(idx_range);
6400 unsigned IdxSkip = Idxs.size();
6401
6402 return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore);
6403}
6404
6405/// Given an aggregate and a sequence of indices, see if the scalar value
6406/// indexed is already around as a register, for example if it was inserted
6407/// directly into the aggregate.
6408///
6409 /// If InsertBefore is provided, this function will duplicate (modified)
6410/// insertvalues when a part of a nested struct is extracted.
6411Value *
6412llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range,
6413 std::optional<BasicBlock::iterator> InsertBefore) {
6414 // Nothing to index? Just return V then (this is useful at the end of our
6415 // recursion).
6416 if (idx_range.empty())
6417 return V;
6418 // We have indices, so V should have an indexable type.
6419 assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) &&
6420 "Not looking at a struct or array?");
6421 assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) &&
6422 "Invalid indices for type?");
6423
6424 if (Constant *C = dyn_cast<Constant>(Val: V)) {
6425 C = C->getAggregateElement(Elt: idx_range[0]);
6426 if (!C) return nullptr;
6427 return FindInsertedValue(V: C, idx_range: idx_range.slice(N: 1), InsertBefore);
6428 }
6429
6430 if (InsertValueInst *I = dyn_cast<InsertValueInst>(Val: V)) {
6431 // Loop over the indices of the insertvalue instruction in parallel with
6432 // the requested indices.
6433 const unsigned *req_idx = idx_range.begin();
6434 for (const unsigned *i = I->idx_begin(), *e = I->idx_end();
6435 i != e; ++i, ++req_idx) {
6436 if (req_idx == idx_range.end()) {
6437 // We can't handle this without inserting insertvalues
6438 if (!InsertBefore)
6439 return nullptr;
6440
6441 // The requested index identifies a part of a nested aggregate. Handle
6442 // this specially. For example,
6443 // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0
6444 // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1
6445 // %C = extractvalue {i32, { i32, i32 } } %B, 1
6446 // This can be changed into
6447 // %A = insertvalue {i32, i32 } undef, i32 10, 0
6448 // %C = insertvalue {i32, i32 } %A, i32 11, 1
6449 // which allows the unused 0,0 element from the nested struct to be
6450 // removed.
6451 return BuildSubAggregate(From: V, idx_range: ArrayRef(idx_range.begin(), req_idx),
6452 InsertBefore: *InsertBefore);
6453 }
6454
6455 // This insertvalue inserts something other than what we are looking for.
6456 // See if the (aggregate) value inserted into has the value we are
6457 // looking for, then.
6458 if (*req_idx != *i)
6459 return FindInsertedValue(V: I->getAggregateOperand(), idx_range,
6460 InsertBefore);
6461 }
6462 // If we end up here, the indices of the insertvalue match with those
6463 // requested (though possibly only partially). Now we recursively look at
6464 // the inserted value, passing any remaining indices.
6465 return FindInsertedValue(V: I->getInsertedValueOperand(),
6466 idx_range: ArrayRef(req_idx, idx_range.end()), InsertBefore);
6467 }
6468
6469 if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(Val: V)) {
6470 // If we're extracting a value from an aggregate that was extracted from
6471 // something else, we can extract from that something else directly instead.
6472 // However, we will need to chain I's indices with the requested indices.
6473
6474 // Calculate the number of indices required
6475 unsigned size = I->getNumIndices() + idx_range.size();
6476 // Allocate some space to put the new indices in
6477 SmallVector<unsigned, 5> Idxs;
6478 Idxs.reserve(N: size);
6479 // Add indices from the extract value instruction
6480 Idxs.append(in_start: I->idx_begin(), in_end: I->idx_end());
6481
6482 // Add requested indices
6483 Idxs.append(in_start: idx_range.begin(), in_end: idx_range.end());
6484
6485 assert(Idxs.size() == size
6486 && "Number of indices added not correct?");
6487
6488 return FindInsertedValue(V: I->getAggregateOperand(), idx_range: Idxs, InsertBefore);
6489 }
6490 // Otherwise, we don't know (e.g. we may be extracting from a function
6491 // return value or a load instruction).
6492 return nullptr;
6493}
6494
6495 // If V refers to an initialized global constant, set Slice either to
6496 // its initializer if the size of its elements equals ElementSize, or,
6497 // for ElementSize == 8, to its representation as an array of unsigned
6498 // char. Return true on success.
6499 // Offset is in units of ElementSize-sized elements.
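// For example, given "@g = constant [5 x i8] c"abcd\00"" and a V that points
// two bytes into @g (with ElementSize == 8), the resulting Slice references
// @g's initializer with Offset == 2 and Length == 3.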
6500bool llvm::getConstantDataArrayInfo(const Value *V,
6501 ConstantDataArraySlice &Slice,
6502 unsigned ElementSize, uint64_t Offset) {
6503 assert(V && "V should not be null.");
6504 assert((ElementSize % 8) == 0 &&
6505 "ElementSize expected to be a multiple of the size of a byte.");
6506 unsigned ElementSizeInBytes = ElementSize / 8;
6507
6508 // Drill down into the pointer expression V, ignoring any intervening
6509 // casts, and determine the identity of the object it references along
6510 // with the cumulative byte offset into it.
6511 const GlobalVariable *GV =
6512 dyn_cast<GlobalVariable>(Val: getUnderlyingObject(V));
6513 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
6514 // Fail if V is not based on a constant global object.
6515 return false;
6516
6517 const DataLayout &DL = GV->getDataLayout();
6518 APInt Off(DL.getIndexTypeSizeInBits(Ty: V->getType()), 0);
6519
6520 if (GV != V->stripAndAccumulateConstantOffsets(DL, Offset&: Off,
6521 /*AllowNonInbounds*/ true))
6522 // Fail if a constant offset could not be determined.
6523 return false;
6524
6525 uint64_t StartIdx = Off.getLimitedValue();
6526 if (StartIdx == UINT64_MAX)
6527 // Fail if the constant offset is excessive.
6528 return false;
6529
6530 // Off/StartIdx is in the unit of bytes. So we need to convert to number of
6531 // elements. Simply bail out if that isn't possible.
6532 if ((StartIdx % ElementSizeInBytes) != 0)
6533 return false;
6534
6535 Offset += StartIdx / ElementSizeInBytes;
6536 ConstantDataArray *Array = nullptr;
6537 ArrayType *ArrayTy = nullptr;
6538
6539 if (GV->getInitializer()->isNullValue()) {
6540 Type *GVTy = GV->getValueType();
6541 uint64_t SizeInBytes = DL.getTypeStoreSize(Ty: GVTy).getFixedValue();
6542 uint64_t Length = SizeInBytes / ElementSizeInBytes;
6543
6544 Slice.Array = nullptr;
6545 Slice.Offset = 0;
6546 // Return an empty Slice for undersized constants to let callers
6547 // transform even undefined library calls into simpler, well-defined
6548 // expressions. This is preferable to making the calls although it
6549 // prevents sanitizers from detecting such calls.
6550 Slice.Length = Length < Offset ? 0 : Length - Offset;
6551 return true;
6552 }
6553
6554 auto *Init = const_cast<Constant *>(GV->getInitializer());
6555 if (auto *ArrayInit = dyn_cast<ConstantDataArray>(Val: Init)) {
6556 Type *InitElTy = ArrayInit->getElementType();
6557 if (InitElTy->isIntegerTy(Bitwidth: ElementSize)) {
6558 // If Init is an initializer for an array of the expected type
6559 // and size, use it as is.
6560 Array = ArrayInit;
6561 ArrayTy = ArrayInit->getType();
6562 }
6563 }
6564
6565 if (!Array) {
6566 if (ElementSize != 8)
6567 // TODO: Handle conversions to larger integral types.
6568 return false;
6569
6570 // Otherwise extract the portion of the initializer starting
6571 // at Offset as an array of bytes, and reset Offset.
6572 Init = ReadByteArrayFromGlobal(GV, Offset);
6573 if (!Init)
6574 return false;
6575
6576 Offset = 0;
6577 Array = dyn_cast<ConstantDataArray>(Val: Init);
6578 ArrayTy = dyn_cast<ArrayType>(Val: Init->getType());
6579 }
6580
6581 uint64_t NumElts = ArrayTy->getArrayNumElements();
6582 if (Offset > NumElts)
6583 return false;
6584
6585 Slice.Array = Array;
6586 Slice.Offset = Offset;
6587 Slice.Length = NumElts - Offset;
6588 return true;
6589}
6590
6591/// Extract bytes from the initializer of the constant array V, which need
6592/// not be a nul-terminated string. On success, store the bytes in Str and
6593/// return true. When TrimAtNul is set, Str will contain only the bytes up
6594/// to but not including the first nul. Return false on failure.
6595bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
6596 bool TrimAtNul) {
6597 ConstantDataArraySlice Slice;
6598 if (!getConstantDataArrayInfo(V, Slice, ElementSize: 8))
6599 return false;
6600
6601 if (Slice.Array == nullptr) {
6602 if (TrimAtNul) {
6603 // Return a nul-terminated string even for an empty Slice. This is
6604 // safe because all existing SimplifyLibcalls callers require string
6605 // arguments and the behavior of the functions they fold is undefined
6606 // otherwise. Folding the calls this way is preferable to making
6607 // the undefined library calls, even though it prevents sanitizers
6608 // from reporting such calls.
6609 Str = StringRef();
6610 return true;
6611 }
6612 if (Slice.Length == 1) {
6613 Str = StringRef("", 1);
6614 return true;
6615 }
6616 // We cannot instantiate a StringRef as we do not have an appropriate string
6617 // of 0s at hand.
6618 return false;
6619 }
6620
6621 // Start out with the entire array in the StringRef.
6622 Str = Slice.Array->getAsString();
6623 // Skip over 'offset' bytes.
6624 Str = Str.substr(Start: Slice.Offset);
6625
6626 if (TrimAtNul) {
6627 // Trim off the \0 and anything after it. If the array is not nul
6628 // terminated, we just return the rest of the string. The client may know
6629 // some other way that the string is length-bound.
6630 Str = Str.substr(Start: 0, N: Str.find(C: '\0'));
6631 }
6632 return true;
6633}
6634
6635// These next two are very similar to the above, but also look through PHI
6636// nodes.
6637// TODO: See if we can integrate these two together.
6638
6639/// If we can compute the length of the string pointed to by
6640/// the specified pointer, return 'len+1'. If we can't, return 0.
6641static uint64_t GetStringLengthH(const Value *V,
6642 SmallPtrSetImpl<const PHINode*> &PHIs,
6643 unsigned CharSize) {
6644 // Look through noop bitcast instructions.
6645 V = V->stripPointerCasts();
6646
6647 // If this is a PHI node, there are two cases: either we have already seen it
6648 // or we haven't.
6649 if (const PHINode *PN = dyn_cast<PHINode>(Val: V)) {
6650 if (!PHIs.insert(Ptr: PN).second)
6651 return ~0ULL; // already in the set.
6652
6653 // If it was new, see if all the input strings are the same length.
6654 uint64_t LenSoFar = ~0ULL;
6655 for (Value *IncValue : PN->incoming_values()) {
6656 uint64_t Len = GetStringLengthH(V: IncValue, PHIs, CharSize);
6657 if (Len == 0) return 0; // Unknown length -> unknown.
6658
6659 if (Len == ~0ULL) continue;
6660
6661 if (Len != LenSoFar && LenSoFar != ~0ULL)
6662 return 0; // Disagree -> unknown.
6663 LenSoFar = Len;
6664 }
6665
6666 // Success, all agree.
6667 return LenSoFar;
6668 }
6669
6670 // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y)
6671 if (const SelectInst *SI = dyn_cast<SelectInst>(Val: V)) {
6672 uint64_t Len1 = GetStringLengthH(V: SI->getTrueValue(), PHIs, CharSize);
6673 if (Len1 == 0) return 0;
6674 uint64_t Len2 = GetStringLengthH(V: SI->getFalseValue(), PHIs, CharSize);
6675 if (Len2 == 0) return 0;
6676 if (Len1 == ~0ULL) return Len2;
6677 if (Len2 == ~0ULL) return Len1;
6678 if (Len1 != Len2) return 0;
6679 return Len1;
6680 }
6681
6682 // Otherwise, see if we can read the string.
6683 ConstantDataArraySlice Slice;
6684 if (!getConstantDataArrayInfo(V, Slice, ElementSize: CharSize))
6685 return 0;
6686
6687 if (Slice.Array == nullptr)
6688 // Zeroinitializer (including an empty one).
6689 return 1;
6690
6691 // Search for the first nul character. Return a conservative result even
6692 // when there is no nul. This is safe since otherwise the string function
6693 // being folded (such as strlen) has undefined behavior, and folding is
6694 // preferable to making the undefined library call.
6695 unsigned NullIndex = 0;
6696 for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
6697 if (Slice.Array->getElementAsInteger(i: Slice.Offset + NullIndex) == 0)
6698 break;
6699 }
6700
6701 return NullIndex + 1;
6702}
6703
6704/// If we can compute the length of the string pointed to by
6705/// the specified pointer, return 'len+1'. If we can't, return 0.
6706uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
6707 if (!V->getType()->isPointerTy())
6708 return 0;
6709
6710 SmallPtrSet<const PHINode*, 32> PHIs;
6711 uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
6712 // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so
6713 // return 1, the length of an empty string (just the nul terminator).
6714 return Len == ~0ULL ? 1 : Len;
6715}
6716
6717const Value *
6718llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call,
6719 bool MustPreserveNullness) {
6720 assert(Call &&
6721 "getArgumentAliasingToReturnedPointer only works on nonnull calls");
6722 if (const Value *RV = Call->getReturnedArgOperand())
6723 return RV;
6724 // This can be used only as an aliasing property.
6725 if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6726 Call, MustPreserveNullness))
6727 return Call->getArgOperand(i: 0);
6728 return nullptr;
6729}
6730
6731bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6732 const CallBase *Call, bool MustPreserveNullness) {
6733 switch (Call->getIntrinsicID()) {
6734 case Intrinsic::launder_invariant_group:
6735 case Intrinsic::strip_invariant_group:
6736 case Intrinsic::aarch64_irg:
6737 case Intrinsic::aarch64_tagp:
6738 // The amdgcn_make_buffer_rsrc function does not alter the address of the
6739 // input pointer (and thus preserves null-ness for the purposes of escape
6740 // analysis, which is where the MustPreserveNullness flag comes into play).
6741 // However, it will not necessarily map ptr addrspace(N) null to ptr
6742 // addrspace(8) null, aka the "null descriptor", which has "all loads return
6743 // 0, all stores are dropped" semantics. Given the context of this intrinsic
6744 // list, no one should be relying on such a strict interpretation of
6745 // MustPreserveNullness (and, at time of writing, they are not), but we
6746 // document this fact out of an abundance of caution.
6747 case Intrinsic::amdgcn_make_buffer_rsrc:
6748 return true;
6749 case Intrinsic::ptrmask:
6750 return !MustPreserveNullness;
6751 case Intrinsic::threadlocal_address:
6752 // The underlying variable changes with thread ID. The Thread ID may change
6753 // at coroutine suspend points.
6754 return !Call->getParent()->getParent()->isPresplitCoroutine();
6755 default:
6756 return false;
6757 }
6758}
6759
6760/// \p PN defines a loop-variant pointer to an object. Check if the
6761/// previous iteration of the loop was referring to the same object as \p PN.
6762static bool isSameUnderlyingObjectInLoop(const PHINode *PN,
6763 const LoopInfo *LI) {
6764 // Find the loop-defined value.
6765 Loop *L = LI->getLoopFor(BB: PN->getParent());
6766 if (PN->getNumIncomingValues() != 2)
6767 return true;
6768
6769 // Find the value from previous iteration.
6770 auto *PrevValue = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0));
6771 if (!PrevValue || LI->getLoopFor(BB: PrevValue->getParent()) != L)
6772 PrevValue = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1));
6773 if (!PrevValue || LI->getLoopFor(BB: PrevValue->getParent()) != L)
6774 return true;
6775
6776 // If a new pointer is loaded in the loop, the pointer references a different
6777 // object in every iteration. E.g.:
6778 // for (i)
6779 // int *p = a[i];
6780 // ...
6781 if (auto *Load = dyn_cast<LoadInst>(Val: PrevValue))
6782 if (!L->isLoopInvariant(V: Load->getPointerOperand()))
6783 return false;
6784 return true;
6785}
6786
6787const Value *llvm::getUnderlyingObject(const Value *V, unsigned MaxLookup) {
6788 for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
6789 if (auto *GEP = dyn_cast<GEPOperator>(Val: V)) {
6790 const Value *PtrOp = GEP->getPointerOperand();
6791 if (!PtrOp->getType()->isPointerTy()) // Only handle scalar pointer base.
6792 return V;
6793 V = PtrOp;
6794 } else if (Operator::getOpcode(V) == Instruction::BitCast ||
6795 Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
6796 Value *NewV = cast<Operator>(Val: V)->getOperand(i: 0);
6797 if (!NewV->getType()->isPointerTy())
6798 return V;
6799 V = NewV;
6800 } else if (auto *GA = dyn_cast<GlobalAlias>(Val: V)) {
6801 if (GA->isInterposable())
6802 return V;
6803 V = GA->getAliasee();
6804 } else {
6805 if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
6806 // Look through single-arg phi nodes created by LCSSA.
6807 if (PHI->getNumIncomingValues() == 1) {
6808 V = PHI->getIncomingValue(i: 0);
6809 continue;
6810 }
6811 } else if (auto *Call = dyn_cast<CallBase>(Val: V)) {
6812 // CaptureTracking can know about special capturing properties of some
6813 // intrinsics like launder.invariant.group, that can't be expressed with
6814 // the attributes, but have properties like returning aliasing pointer.
6815 // Because some analyses may assume that a nocapture pointer is not
6816 // returned from such special intrinsics (the function would otherwise have
6817 // to be marked with the 'returned' attribute), it is crucial to use this
6818 // helper so that we stay in sync with CaptureTracking. Not using it may
6819 // cause weird miscompilations where two aliasing pointers are assumed to
6820 // be noalias.
6821 if (auto *RP = getArgumentAliasingToReturnedPointer(Call, MustPreserveNullness: false)) {
6822 V = RP;
6823 continue;
6824 }
6825 }
6826
6827 return V;
6828 }
6829 assert(V->getType()->isPointerTy() && "Unexpected operand type!");
6830 }
6831 return V;
6832}
6833
6834void llvm::getUnderlyingObjects(const Value *V,
6835 SmallVectorImpl<const Value *> &Objects,
6836 const LoopInfo *LI, unsigned MaxLookup) {
6837 SmallPtrSet<const Value *, 4> Visited;
6838 SmallVector<const Value *, 4> Worklist;
6839 Worklist.push_back(Elt: V);
6840 do {
6841 const Value *P = Worklist.pop_back_val();
6842 P = getUnderlyingObject(V: P, MaxLookup);
6843
6844 if (!Visited.insert(Ptr: P).second)
6845 continue;
6846
6847 if (auto *SI = dyn_cast<SelectInst>(Val: P)) {
6848 Worklist.push_back(Elt: SI->getTrueValue());
6849 Worklist.push_back(Elt: SI->getFalseValue());
6850 continue;
6851 }
6852
6853 if (auto *PN = dyn_cast<PHINode>(Val: P)) {
6854 // If this PHI changes the underlying object in every iteration of the
6855 // loop, don't look through it. Consider:
6856 // int **A;
6857 // for (i) {
6858 // Prev = Curr; // Prev = PHI (Prev_0, Curr)
6859 // Curr = A[i];
6860 // *Prev, *Curr;
6861 //
6862 // Prev is tracking Curr one iteration behind so they refer to different
6863 // underlying objects.
6864 if (!LI || !LI->isLoopHeader(BB: PN->getParent()) ||
6865 isSameUnderlyingObjectInLoop(PN, LI))
6866 append_range(C&: Worklist, R: PN->incoming_values());
6867 else
6868 Objects.push_back(Elt: P);
6869 continue;
6870 }
6871
6872 Objects.push_back(Elt: P);
6873 } while (!Worklist.empty());
6874}
6875
6876const Value *llvm::getUnderlyingObjectAggressive(const Value *V) {
6877 const unsigned MaxVisited = 8;
6878
6879 SmallPtrSet<const Value *, 8> Visited;
6880 SmallVector<const Value *, 8> Worklist;
6881 Worklist.push_back(Elt: V);
6882 const Value *Object = nullptr;
6883 // Used as a fallback if we can't find a common underlying object through
6884 // recursion.
6885 bool First = true;
6886 const Value *FirstObject = getUnderlyingObject(V);
6887 do {
6888 const Value *P = Worklist.pop_back_val();
6889 P = First ? FirstObject : getUnderlyingObject(V: P);
6890 First = false;
6891
6892 if (!Visited.insert(Ptr: P).second)
6893 continue;
6894
6895 if (Visited.size() == MaxVisited)
6896 return FirstObject;
6897
6898 if (auto *SI = dyn_cast<SelectInst>(Val: P)) {
6899 Worklist.push_back(Elt: SI->getTrueValue());
6900 Worklist.push_back(Elt: SI->getFalseValue());
6901 continue;
6902 }
6903
6904 if (auto *PN = dyn_cast<PHINode>(Val: P)) {
6905 append_range(C&: Worklist, R: PN->incoming_values());
6906 continue;
6907 }
6908
6909 if (!Object)
6910 Object = P;
6911 else if (Object != P)
6912 return FirstObject;
6913 } while (!Worklist.empty());
6914
6915 return Object ? Object : FirstObject;
6916}
6917
6918/// This is the function that does the work of looking through basic
6919/// ptrtoint+arithmetic+inttoptr sequences.
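/// For example, given
///   %i = ptrtoint ptr @g to i64
///   %j = add i64 %i, 16
/// walking %j returns @g.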
6920static const Value *getUnderlyingObjectFromInt(const Value *V) {
6921 do {
6922 if (const Operator *U = dyn_cast<Operator>(Val: V)) {
6923 // If we find a ptrtoint, we can transfer control back to the
6924 // regular getUnderlyingObjectFromInt.
6925 if (U->getOpcode() == Instruction::PtrToInt)
6926 return U->getOperand(i: 0);
6927 // If we find an add of a constant, a multiplied value, or a phi, it's
6928 // likely that the other operand will lead us to the base
6929 // object. We don't have to worry about the case where the
6930 // object address is somehow being computed by the multiply,
6931 // because our callers only care when the result is an
6932 // identifiable object.
6933 if (U->getOpcode() != Instruction::Add ||
6934 (!isa<ConstantInt>(Val: U->getOperand(i: 1)) &&
6935 Operator::getOpcode(V: U->getOperand(i: 1)) != Instruction::Mul &&
6936 !isa<PHINode>(Val: U->getOperand(i: 1))))
6937 return V;
6938 V = U->getOperand(i: 0);
6939 } else {
6940 return V;
6941 }
6942 assert(V->getType()->isIntegerTy() && "Unexpected operand type!");
6943 } while (true);
6944}
6945
6946/// This is a wrapper around getUnderlyingObjects and adds support for basic
6947/// ptrtoint+arithmetic+inttoptr sequences.
6948 /// It returns false if an unidentified object is found by getUnderlyingObjects.
6949bool llvm::getUnderlyingObjectsForCodeGen(const Value *V,
6950 SmallVectorImpl<Value *> &Objects) {
6951 SmallPtrSet<const Value *, 16> Visited;
6952 SmallVector<const Value *, 4> Working(1, V);
6953 do {
6954 V = Working.pop_back_val();
6955
6956 SmallVector<const Value *, 4> Objs;
6957 getUnderlyingObjects(V, Objects&: Objs);
6958
6959 for (const Value *V : Objs) {
6960 if (!Visited.insert(Ptr: V).second)
6961 continue;
6962 if (Operator::getOpcode(V) == Instruction::IntToPtr) {
6963 const Value *O =
6964 getUnderlyingObjectFromInt(V: cast<User>(Val: V)->getOperand(i: 0));
6965 if (O->getType()->isPointerTy()) {
6966 Working.push_back(Elt: O);
6967 continue;
6968 }
6969 }
6970 // If getUnderlyingObjects fails to find an identifiable object,
6971 // getUnderlyingObjectsForCodeGen also fails for safety.
6972 if (!isIdentifiedObject(V)) {
6973 Objects.clear();
6974 return false;
6975 }
6976 Objects.push_back(Elt: const_cast<Value *>(V));
6977 }
6978 } while (!Working.empty());
6979 return true;
6980}
6981
6982AllocaInst *llvm::findAllocaForValue(Value *V, bool OffsetZero) {
6983 AllocaInst *Result = nullptr;
6984 SmallPtrSet<Value *, 4> Visited;
6985 SmallVector<Value *, 4> Worklist;
6986
6987 auto AddWork = [&](Value *V) {
6988 if (Visited.insert(Ptr: V).second)
6989 Worklist.push_back(Elt: V);
6990 };
6991
6992 AddWork(V);
6993 do {
6994 V = Worklist.pop_back_val();
6995 assert(Visited.count(V));
6996
6997 if (AllocaInst *AI = dyn_cast<AllocaInst>(Val: V)) {
6998 if (Result && Result != AI)
6999 return nullptr;
7000 Result = AI;
7001 } else if (CastInst *CI = dyn_cast<CastInst>(Val: V)) {
7002 AddWork(CI->getOperand(i_nocapture: 0));
7003 } else if (PHINode *PN = dyn_cast<PHINode>(Val: V)) {
7004 for (Value *IncValue : PN->incoming_values())
7005 AddWork(IncValue);
7006 } else if (auto *SI = dyn_cast<SelectInst>(Val: V)) {
7007 AddWork(SI->getTrueValue());
7008 AddWork(SI->getFalseValue());
7009 } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: V)) {
7010 if (OffsetZero && !GEP->hasAllZeroIndices())
7011 return nullptr;
7012 AddWork(GEP->getPointerOperand());
7013 } else if (CallBase *CB = dyn_cast<CallBase>(Val: V)) {
7014 Value *Returned = CB->getReturnedArgOperand();
7015 if (Returned)
7016 AddWork(Returned);
7017 else
7018 return nullptr;
7019 } else {
7020 return nullptr;
7021 }
7022 } while (!Worklist.empty());
7023
7024 return Result;
7025}
7026
7027static bool onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7028 const Value *V, bool AllowLifetime, bool AllowDroppable) {
7029 for (const User *U : V->users()) {
7030 const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U);
7031 if (!II)
7032 return false;
7033
7034 if (AllowLifetime && II->isLifetimeStartOrEnd())
7035 continue;
7036
7037 if (AllowDroppable && II->isDroppable())
7038 continue;
7039
7040 return false;
7041 }
7042 return true;
7043}
7044
7045bool llvm::onlyUsedByLifetimeMarkers(const Value *V) {
7046 return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7047 V, /* AllowLifetime */ true, /* AllowDroppable */ false);
7048}
7049bool llvm::onlyUsedByLifetimeMarkersOrDroppableInsts(const Value *V) {
7050 return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7051 V, /* AllowLifetime */ true, /* AllowDroppable */ true);
7052}
7053
7054bool llvm::isNotCrossLaneOperation(const Instruction *I) {
7055 if (auto *II = dyn_cast<IntrinsicInst>(Val: I))
7056 return isTriviallyVectorizable(ID: II->getIntrinsicID());
7057 auto *Shuffle = dyn_cast<ShuffleVectorInst>(Val: I);
7058 return (!Shuffle || Shuffle->isSelect()) &&
7059 !isa<CallBase, BitCastInst, ExtractElementInst>(Val: I);
7060}
7061
7062bool llvm::isSafeToSpeculativelyExecute(
7063 const Instruction *Inst, const Instruction *CtxI, AssumptionCache *AC,
7064 const DominatorTree *DT, const TargetLibraryInfo *TLI, bool UseVariableInfo,
7065 bool IgnoreUBImplyingAttrs) {
7066 return isSafeToSpeculativelyExecuteWithOpcode(Opcode: Inst->getOpcode(), Inst, CtxI,
7067 AC, DT, TLI, UseVariableInfo,
7068 IgnoreUBImplyingAttrs);
7069}
7070
7071bool llvm::isSafeToSpeculativelyExecuteWithOpcode(
7072 unsigned Opcode, const Instruction *Inst, const Instruction *CtxI,
7073 AssumptionCache *AC, const DominatorTree *DT, const TargetLibraryInfo *TLI,
7074 bool UseVariableInfo, bool IgnoreUBImplyingAttrs) {
7075#ifndef NDEBUG
7076 if (Inst->getOpcode() != Opcode) {
7077 // Check that the operands are actually compatible with the Opcode override.
7078 auto hasEqualReturnAndLeadingOperandTypes =
7079 [](const Instruction *Inst, unsigned NumLeadingOperands) {
7080 if (Inst->getNumOperands() < NumLeadingOperands)
7081 return false;
7082 const Type *ExpectedType = Inst->getType();
7083 for (unsigned ItOp = 0; ItOp < NumLeadingOperands; ++ItOp)
7084 if (Inst->getOperand(ItOp)->getType() != ExpectedType)
7085 return false;
7086 return true;
7087 };
7088 assert(!Instruction::isBinaryOp(Opcode) ||
7089 hasEqualReturnAndLeadingOperandTypes(Inst, 2));
7090 assert(!Instruction::isUnaryOp(Opcode) ||
7091 hasEqualReturnAndLeadingOperandTypes(Inst, 1));
7092 }
7093#endif
7094
7095 switch (Opcode) {
7096 default:
7097 return true;
7098 case Instruction::UDiv:
7099 case Instruction::URem: {
7100 // x / y is undefined if y == 0.
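// For example, "udiv i32 %x, 7" is safe to speculate, while
// "udiv i32 %x, %y" is not, since %y might be zero.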
7101 const APInt *V;
7102 if (match(V: Inst->getOperand(i: 1), P: m_APInt(Res&: V)))
7103 return *V != 0;
7104 return false;
7105 }
7106 case Instruction::SDiv:
7107 case Instruction::SRem: {
7108 // x / y is undefined if y == 0 or x == INT_MIN and y == -1
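// For example, "sdiv i32 %x, -1" is only considered safe when the numerator
// is a constant other than INT_MIN, since INT_MIN / -1 overflows.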
7109 const APInt *Numerator, *Denominator;
7110 if (!match(V: Inst->getOperand(i: 1), P: m_APInt(Res&: Denominator)))
7111 return false;
7112 // We cannot hoist this division if the denominator is 0.
7113 if (*Denominator == 0)
7114 return false;
7115 // It's safe to hoist if the denominator is not 0 or -1.
7116 if (!Denominator->isAllOnes())
7117 return true;
7118 // At this point we know that the denominator is -1. It is safe to hoist as
7119 // long as we know that the numerator is not INT_MIN.
7120 if (match(V: Inst->getOperand(i: 0), P: m_APInt(Res&: Numerator)))
7121 return !Numerator->isMinSignedValue();
7122 // The numerator *might* be MinSignedValue.
7123 return false;
7124 }
7125 case Instruction::Load: {
7126 if (!UseVariableInfo)
7127 return false;
7128
7129 const LoadInst *LI = dyn_cast<LoadInst>(Val: Inst);
7130 if (!LI)
7131 return false;
7132 if (mustSuppressSpeculation(LI: *LI))
7133 return false;
7134 const DataLayout &DL = LI->getDataLayout();
7135 return isDereferenceableAndAlignedPointer(V: LI->getPointerOperand(),
7136 Ty: LI->getType(), Alignment: LI->getAlign(), DL,
7137 CtxI, AC, DT, TLI);
7138 }
7139 case Instruction::Call: {
7140 auto *CI = dyn_cast<const CallInst>(Val: Inst);
7141 if (!CI)
7142 return false;
7143 const Function *Callee = CI->getCalledFunction();
7144
7145 // The called function could have undefined behavior or side-effects, even
7146 // if marked readnone nounwind.
7147 if (!Callee || !Callee->isSpeculatable())
7148 return false;
7149 // Since the operands may be changed after hoisting, undefined behavior may
7150 // be triggered by some UB-implying attributes.
7151 return IgnoreUBImplyingAttrs || !CI->hasUBImplyingAttrs();
7152 }
7153 case Instruction::VAArg:
7154 case Instruction::Alloca:
7155 case Instruction::Invoke:
7156 case Instruction::CallBr:
7157 case Instruction::PHI:
7158 case Instruction::Store:
7159 case Instruction::Ret:
7160 case Instruction::Br:
7161 case Instruction::IndirectBr:
7162 case Instruction::Switch:
7163 case Instruction::Unreachable:
7164 case Instruction::Fence:
7165 case Instruction::AtomicRMW:
7166 case Instruction::AtomicCmpXchg:
7167 case Instruction::LandingPad:
7168 case Instruction::Resume:
7169 case Instruction::CatchSwitch:
7170 case Instruction::CatchPad:
7171 case Instruction::CatchRet:
7172 case Instruction::CleanupPad:
7173 case Instruction::CleanupRet:
7174 return false; // Misc instructions which have effects
7175 }
7176}
7177
7178bool llvm::mayHaveNonDefUseDependency(const Instruction &I) {
7179 if (I.mayReadOrWriteMemory())
7180 // Memory dependency possible
7181 return true;
7182 if (!isSafeToSpeculativelyExecute(Inst: &I))
7183 // Can't move above a maythrow call or infinite loop. Or if an
7184 // inalloca alloca, above a stacksave call.
7185 return true;
7186 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
7187 // 1) Can't reorder two inf-loop calls, even if readonly
7188 // 2) Also can't reorder an inf-loop call below an instruction which isn't
7189 // safe to speculative execute. (Inverse of above)
7190 return true;
7191 return false;
7192}
7193
7194/// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
7195static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
7196 switch (OR) {
7197 case ConstantRange::OverflowResult::MayOverflow:
7198 return OverflowResult::MayOverflow;
7199 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7200 return OverflowResult::AlwaysOverflowsLow;
7201 case ConstantRange::OverflowResult::AlwaysOverflowsHigh:
7202 return OverflowResult::AlwaysOverflowsHigh;
7203 case ConstantRange::OverflowResult::NeverOverflows:
7204 return OverflowResult::NeverOverflows;
7205 }
7206 llvm_unreachable("Unknown OverflowResult");
7207}
7208
7209/// Combine constant ranges from computeConstantRange() and computeKnownBits().
7210ConstantRange
7211llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
7212 bool ForSigned,
7213 const SimplifyQuery &SQ) {
7214 ConstantRange CR1 =
7215 ConstantRange::fromKnownBits(Known: V.getKnownBits(Q: SQ), IsSigned: ForSigned);
7216 ConstantRange CR2 = computeConstantRange(V, ForSigned, UseInstrInfo: SQ.IIQ.UseInstrInfo);
7217 ConstantRange::PreferredRangeType RangeType =
7218 ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
7219 return CR1.intersectWith(CR: CR2, Type: RangeType);
7220}
7221
7222OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
7223 const Value *RHS,
7224 const SimplifyQuery &SQ,
7225 bool IsNSW) {
7226 ConstantRange LHSRange =
7227 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
7228 ConstantRange RHSRange =
7229 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
7230
7231 // mul nsw of two non-negative numbers is also nuw.
7232 if (IsNSW && LHSRange.isAllNonNegative() && RHSRange.isAllNonNegative())
7233 return OverflowResult::NeverOverflows;
7234
7235 return mapOverflowResult(OR: LHSRange.unsignedMulMayOverflow(Other: RHSRange));
7236}
7237
7238OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
7239 const Value *RHS,
7240 const SimplifyQuery &SQ) {
7241 // Multiplying n * m significant bits yields a result of n + m significant
7242 // bits. If the total number of significant bits does not exceed the
7243 // result bit width (minus 1), there is no overflow.
7244 // This means if we have enough leading sign bits in the operands
7245 // we can guarantee that the result does not overflow.
7246 // Ref: "Hacker's Delight" by Henry Warren
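// For example, multiplying two i32 values that each have at least 17 sign
// bits (i.e. each fits in an i16) gives SignBits >= 34 > 33, so the product
// can never overflow i32.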
7247 unsigned BitWidth = LHS->getType()->getScalarSizeInBits();
7248
7249 // Note that underestimating the number of sign bits gives a more
7250 // conservative answer.
7251 unsigned SignBits =
7252 ::ComputeNumSignBits(V: LHS, Q: SQ) + ::ComputeNumSignBits(V: RHS, Q: SQ);
7253
7254 // First handle the easy case: if we have enough sign bits there's
7255 // definitely no overflow.
7256 if (SignBits > BitWidth + 1)
7257 return OverflowResult::NeverOverflows;
7258
7259 // There are two ambiguous cases where there can be no overflow:
7260 // SignBits == BitWidth + 1 and
7261 // SignBits == BitWidth
7262 // The second case is difficult to check, therefore we only handle the
7263 // first case.
7264 if (SignBits == BitWidth + 1) {
7265 // It overflows only when both arguments are negative and the true
7266 // product is exactly the minimum negative number.
7267 // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
7268 // For simplicity we just check if at least one side is not negative.
7269 KnownBits LHSKnown = computeKnownBits(V: LHS, Q: SQ);
7270 KnownBits RHSKnown = computeKnownBits(V: RHS, Q: SQ);
7271 if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
7272 return OverflowResult::NeverOverflows;
7273 }
7274 return OverflowResult::MayOverflow;
7275}
7276
7277OverflowResult
7278llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
7279 const WithCache<const Value *> &RHS,
7280 const SimplifyQuery &SQ) {
7281 ConstantRange LHSRange =
7282 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
7283 ConstantRange RHSRange =
7284 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
7285 return mapOverflowResult(OR: LHSRange.unsignedAddMayOverflow(Other: RHSRange));
7286}
7287
7288static OverflowResult
7289computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
7290 const WithCache<const Value *> &RHS,
7291 const AddOperator *Add, const SimplifyQuery &SQ) {
7292 if (Add && Add->hasNoSignedWrap()) {
7293 return OverflowResult::NeverOverflows;
7294 }
7295
7296 // If LHS and RHS each have at least two sign bits, the addition will look
7297 // like
7298 //
7299 // XX..... +
7300 // YY.....
7301 //
7302 // If the carry into the most significant position is 0, X and Y can't both
7303 // be 1 and therefore the carry out of the addition is also 0.
7304 //
7305 // If the carry into the most significant position is 1, X and Y can't both
7306 // be 0 and therefore the carry out of the addition is also 1.
7307 //
7308 // Since the carry into the most significant position is always equal to
7309 // the carry out of the addition, there is no signed overflow.
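// For example, adding two i8 values that each have at least two sign bits
// (i.e. each lies in [-64, 63]) gives a sum in [-128, 126], which always
// fits in i8.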
7310 if (::ComputeNumSignBits(V: LHS, Q: SQ) > 1 && ::ComputeNumSignBits(V: RHS, Q: SQ) > 1)
7311 return OverflowResult::NeverOverflows;
7312
7313 ConstantRange LHSRange =
7314 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/true, SQ);
7315 ConstantRange RHSRange =
7316 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/true, SQ);
7317 OverflowResult OR =
7318 mapOverflowResult(OR: LHSRange.signedAddMayOverflow(Other: RHSRange));
7319 if (OR != OverflowResult::MayOverflow)
7320 return OR;
7321
7322 // The remaining code needs Add to be available. Return early if it is not.
7323 if (!Add)
7324 return OverflowResult::MayOverflow;
7325
7326 // If the sign of Add is the same as at least one of the operands, this add
7327 // CANNOT overflow. If this can be determined from the known bits of the
7328 // operands the above signedAddMayOverflow() check will have already done so.
7329 // The only other way to improve on the known bits is from an assumption, so
7330 // call computeKnownBitsFromContext() directly.
7331 bool LHSOrRHSKnownNonNegative =
7332 (LHSRange.isAllNonNegative() || RHSRange.isAllNonNegative());
7333 bool LHSOrRHSKnownNegative =
7334 (LHSRange.isAllNegative() || RHSRange.isAllNegative());
7335 if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
7336 KnownBits AddKnown(LHSRange.getBitWidth());
7337 computeKnownBitsFromContext(V: Add, Known&: AddKnown, Q: SQ);
7338 if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
7339 (AddKnown.isNegative() && LHSOrRHSKnownNegative))
7340 return OverflowResult::NeverOverflows;
7341 }
7342
7343 return OverflowResult::MayOverflow;
7344}
7345
7346OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
7347 const Value *RHS,
7348 const SimplifyQuery &SQ) {
7349 // X - (X % ?)
7350 // The remainder of a value can't have greater magnitude than itself,
7351 // so the subtraction can't overflow.
7352
7353 // X - (X -nuw ?)
7354 // In the minimal case, this would simplify to "?", so there's no subtract
7355 // at all. But if this analysis is used to peek through casts, for example,
7356 // then determining no-overflow may allow other transforms.
7357
7358 // TODO: There are other patterns like this.
7359 // See simplifyICmpWithBinOpOnLHS() for candidates.
7360 if (match(V: RHS, P: m_URem(L: m_Specific(V: LHS), R: m_Value())) ||
7361 match(V: RHS, P: m_NUWSub(L: m_Specific(V: LHS), R: m_Value())))
7362 if (isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
7363 return OverflowResult::NeverOverflows;
7364
7365 if (auto C = isImpliedByDomCondition(Pred: CmpInst::ICMP_UGE, LHS, RHS, ContextI: SQ.CxtI,
7366 DL: SQ.DL)) {
7367 if (*C)
7368 return OverflowResult::NeverOverflows;
7369 return OverflowResult::AlwaysOverflowsLow;
7370 }
7371
7372 ConstantRange LHSRange =
7373 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
7374 ConstantRange RHSRange =
7375 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
7376 return mapOverflowResult(OR: LHSRange.unsignedSubMayOverflow(Other: RHSRange));
7377}
7378
7379OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
7380 const Value *RHS,
7381 const SimplifyQuery &SQ) {
7382 // X - (X % ?)
7383 // The remainder of a value can't have greater magnitude than itself,
7384 // so the subtraction can't overflow.
7385
7386 // X - (X -nsw ?)
7387 // In the minimal case, this would simplify to "?", so there's no subtract
7388 // at all. But if this analysis is used to peek through casts, for example,
7389 // then determining no-overflow may allow other transforms.
7390 if (match(V: RHS, P: m_SRem(L: m_Specific(V: LHS), R: m_Value())) ||
7391 match(V: RHS, P: m_NSWSub(L: m_Specific(V: LHS), R: m_Value())))
7392 if (isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
7393 return OverflowResult::NeverOverflows;
7394
7395 // If LHS and RHS each have at least two sign bits, the subtraction
7396 // cannot overflow.
7397 if (::ComputeNumSignBits(V: LHS, Q: SQ) > 1 && ::ComputeNumSignBits(V: RHS, Q: SQ) > 1)
7398 return OverflowResult::NeverOverflows;
7399
7400 ConstantRange LHSRange =
7401 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/true, SQ);
7402 ConstantRange RHSRange =
7403 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/true, SQ);
7404 return mapOverflowResult(OR: LHSRange.signedSubMayOverflow(Other: RHSRange));
7405}
7406
7407bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
7408 const DominatorTree &DT) {
7409 SmallVector<const BranchInst *, 2> GuardingBranches;
7410 SmallVector<const ExtractValueInst *, 2> Results;
7411
7412 for (const User *U : WO->users()) {
7413 if (const auto *EVI = dyn_cast<ExtractValueInst>(Val: U)) {
7414 assert(EVI->getNumIndices() == 1 && "Obvious from WO's type");
7415
7416 if (EVI->getIndices()[0] == 0)
7417 Results.push_back(Elt: EVI);
7418 else {
7419 assert(EVI->getIndices()[0] == 1 && "Obvious from WO's type");
7420
7421 for (const auto *U : EVI->users())
7422 if (const auto *B = dyn_cast<BranchInst>(Val: U)) {
7423 assert(B->isConditional() && "How else is it using an i1?");
7424 GuardingBranches.push_back(Elt: B);
7425 }
7426 }
7427 } else {
7428 // We are using the aggregate directly in a way we don't want to analyze
7429 // here (storing it to a global, say).
7430 return false;
7431 }
7432 }
7433
7434 auto AllUsesGuardedByBranch = [&](const BranchInst *BI) {
7435 BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(i: 1));
7436 if (!NoWrapEdge.isSingleEdge())
7437 return false;
7438
7439 // Check if all users of the add are provably no-wrap.
7440 for (const auto *Result : Results) {
7441 // If the extractvalue itself is not executed on overflow, then we don't
7442 // need to check each use separately, since domination is transitive.
7443 if (DT.dominates(BBE: NoWrapEdge, BB: Result->getParent()))
7444 continue;
7445
7446 for (const auto &RU : Result->uses())
7447 if (!DT.dominates(BBE: NoWrapEdge, U: RU))
7448 return false;
7449 }
7450
7451 return true;
7452 };
7453
7454 return llvm::any_of(Range&: GuardingBranches, P: AllUsesGuardedByBranch);
7455}
7456
7457/// Return true if the shift amount is a constant known to be less than the bit width.
7458static bool shiftAmountKnownInRange(const Value *ShiftAmount) {
7459 auto *C = dyn_cast<Constant>(Val: ShiftAmount);
7460 if (!C)
7461 return false;
7462
7463 // Shifts return poison if the shift amount is larger than or equal to the bitwidth.
7464 SmallVector<const Constant *, 4> ShiftAmounts;
7465 if (auto *FVTy = dyn_cast<FixedVectorType>(Val: C->getType())) {
7466 unsigned NumElts = FVTy->getNumElements();
7467 for (unsigned i = 0; i < NumElts; ++i)
7468 ShiftAmounts.push_back(Elt: C->getAggregateElement(Elt: i));
7469 } else if (isa<ScalableVectorType>(Val: C->getType()))
7470 return false; // Can't tell, just return false to be safe
7471 else
7472 ShiftAmounts.push_back(Elt: C);
7473
7474 bool Safe = llvm::all_of(Range&: ShiftAmounts, P: [](const Constant *C) {
7475 auto *CI = dyn_cast_or_null<ConstantInt>(Val: C);
7476 return CI && CI->getValue().ult(RHS: C->getType()->getIntegerBitWidth());
7477 });
7478
7479 return Safe;
7480}
7481
7482enum class UndefPoisonKind {
7483 PoisonOnly = (1 << 0),
7484 UndefOnly = (1 << 1),
7485 UndefOrPoison = PoisonOnly | UndefOnly,
7486};
7487
7488static bool includesPoison(UndefPoisonKind Kind) {
7489 return (unsigned(Kind) & unsigned(UndefPoisonKind::PoisonOnly)) != 0;
7490}
7491
7492static bool includesUndef(UndefPoisonKind Kind) {
7493 return (unsigned(Kind) & unsigned(UndefPoisonKind::UndefOnly)) != 0;
7494}
7495
7496static bool canCreateUndefOrPoison(const Operator *Op, UndefPoisonKind Kind,
7497 bool ConsiderFlagsAndMetadata) {
7498
7499 if (ConsiderFlagsAndMetadata && includesPoison(Kind) &&
7500 Op->hasPoisonGeneratingAnnotations())
7501 return true;
7502
7503 unsigned Opcode = Op->getOpcode();
7504
7505 // Check whether opcode is a poison/undef-generating operation
7506 switch (Opcode) {
7507 case Instruction::Shl:
7508 case Instruction::AShr:
7509 case Instruction::LShr:
7510 return includesPoison(Kind) && !shiftAmountKnownInRange(ShiftAmount: Op->getOperand(i: 1));
7511 case Instruction::FPToSI:
7512 case Instruction::FPToUI:
7513 // fptosi/ui yields poison if the resulting value does not fit in the
7514 // destination type.
7515 return true;
7516 case Instruction::Call:
7517 if (auto *II = dyn_cast<IntrinsicInst>(Val: Op)) {
7518 switch (II->getIntrinsicID()) {
7519 // TODO: Add more intrinsics.
7520 case Intrinsic::ctlz:
7521 case Intrinsic::cttz:
7522 case Intrinsic::abs:
7523 // We're not considering flags so it is safe to just return false.
7524 return false;
7525 case Intrinsic::sshl_sat:
7526 case Intrinsic::ushl_sat:
7527 if (!includesPoison(Kind) ||
7528 shiftAmountKnownInRange(ShiftAmount: II->getArgOperand(i: 1)))
7529 return false;
7530 break;
7531 }
7532 }
7533 [[fallthrough]];
7534 case Instruction::CallBr:
7535 case Instruction::Invoke: {
7536 const auto *CB = cast<CallBase>(Val: Op);
7537 return !CB->hasRetAttr(Kind: Attribute::NoUndef) &&
7538 !CB->hasFnAttr(Kind: Attribute::NoCreateUndefOrPoison);
7539 }
7540 case Instruction::InsertElement:
7541 case Instruction::ExtractElement: {
7542 // If the index exceeds the length of the vector, the result is poison.
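    // (e.g., extractelement <4 x i32> %v, i64 7 is poison.)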
7543 auto *VTy = cast<VectorType>(Val: Op->getOperand(i: 0)->getType());
7544 unsigned IdxOp = Op->getOpcode() == Instruction::InsertElement ? 2 : 1;
7545 auto *Idx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: IdxOp));
7546 if (includesPoison(Kind))
7547 return !Idx ||
7548 Idx->getValue().uge(RHS: VTy->getElementCount().getKnownMinValue());
7549 return false;
7550 }
7551 case Instruction::ShuffleVector: {
7552 ArrayRef<int> Mask = isa<ConstantExpr>(Val: Op)
7553 ? cast<ConstantExpr>(Val: Op)->getShuffleMask()
7554 : cast<ShuffleVectorInst>(Val: Op)->getShuffleMask();
7555 return includesPoison(Kind) && is_contained(Range&: Mask, Element: PoisonMaskElem);
7556 }
7557 case Instruction::FNeg:
7558 case Instruction::PHI:
7559 case Instruction::Select:
7560 case Instruction::ExtractValue:
7561 case Instruction::InsertValue:
7562 case Instruction::Freeze:
7563 case Instruction::ICmp:
7564 case Instruction::FCmp:
7565 case Instruction::GetElementPtr:
7566 return false;
7567 case Instruction::AddrSpaceCast:
7568 return true;
7569 default: {
7570 const auto *CE = dyn_cast<ConstantExpr>(Val: Op);
7571 if (isa<CastInst>(Val: Op) || (CE && CE->isCast()))
7572 return false;
7573 else if (Instruction::isBinaryOp(Opcode))
7574 return false;
7575 // Be conservative and return true.
7576 return true;
7577 }
7578 }
7579}
7580
7581bool llvm::canCreateUndefOrPoison(const Operator *Op,
7582 bool ConsiderFlagsAndMetadata) {
7583 return ::canCreateUndefOrPoison(Op, Kind: UndefPoisonKind::UndefOrPoison,
7584 ConsiderFlagsAndMetadata);
7585}
7586
7587bool llvm::canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata) {
7588 return ::canCreateUndefOrPoison(Op, Kind: UndefPoisonKind::PoisonOnly,
7589 ConsiderFlagsAndMetadata);
7590}
7591
7592static bool directlyImpliesPoison(const Value *ValAssumedPoison, const Value *V,
7593 unsigned Depth) {
7594 if (ValAssumedPoison == V)
7595 return true;
7596
7597 const unsigned MaxDepth = 2;
7598 if (Depth >= MaxDepth)
7599 return false;
7600
7601 if (const auto *I = dyn_cast<Instruction>(Val: V)) {
7602 if (any_of(Range: I->operands(), P: [=](const Use &Op) {
7603 return propagatesPoison(PoisonOp: Op) &&
7604 directlyImpliesPoison(ValAssumedPoison, V: Op, Depth: Depth + 1);
7605 }))
7606 return true;
7607
7608 // V = extractvalue V0, idx
7609 // V2 = extractvalue V0, idx2
7610 // V0's elements are all poison or all non-poison (e.g., add_with_overflow).
7611 const WithOverflowInst *II;
7612 if (match(V: I, P: m_ExtractValue(V: m_WithOverflowInst(I&: II))) &&
7613 (match(V: ValAssumedPoison, P: m_ExtractValue(V: m_Specific(V: II))) ||
7614 llvm::is_contained(Range: II->args(), Element: ValAssumedPoison)))
7615 return true;
7616 }
7617 return false;
7618}
7619
7620static bool impliesPoison(const Value *ValAssumedPoison, const Value *V,
7621 unsigned Depth) {
7622 if (isGuaranteedNotToBePoison(V: ValAssumedPoison))
7623 return true;
7624
7625 if (directlyImpliesPoison(ValAssumedPoison, V, /* Depth */ 0))
7626 return true;
7627
7628 const unsigned MaxDepth = 2;
7629 if (Depth >= MaxDepth)
7630 return false;
7631
7632 const auto *I = dyn_cast<Instruction>(Val: ValAssumedPoison);
7633 if (I && !canCreatePoison(Op: cast<Operator>(Val: I))) {
7634 return all_of(Range: I->operands(), P: [=](const Value *Op) {
7635 return impliesPoison(ValAssumedPoison: Op, V, Depth: Depth + 1);
7636 });
7637 }
7638 return false;
7639}
7640
7641bool llvm::impliesPoison(const Value *ValAssumedPoison, const Value *V) {
7642 return ::impliesPoison(ValAssumedPoison, V, /* Depth */ 0);
7643}
7644
7645static bool programUndefinedIfUndefOrPoison(const Value *V, bool PoisonOnly);
7646
7647static bool isGuaranteedNotToBeUndefOrPoison(
7648 const Value *V, AssumptionCache *AC, const Instruction *CtxI,
7649 const DominatorTree *DT, unsigned Depth, UndefPoisonKind Kind) {
7650 if (Depth >= MaxAnalysisRecursionDepth)
7651 return false;
7652
7653 if (isa<MetadataAsValue>(Val: V))
7654 return false;
7655
7656 if (const auto *A = dyn_cast<Argument>(Val: V)) {
7657 if (A->hasAttribute(Kind: Attribute::NoUndef) ||
7658 A->hasAttribute(Kind: Attribute::Dereferenceable) ||
7659 A->hasAttribute(Kind: Attribute::DereferenceableOrNull))
7660 return true;
7661 }
7662
7663 if (auto *C = dyn_cast<Constant>(Val: V)) {
7664 if (isa<PoisonValue>(Val: C))
7665 return !includesPoison(Kind);
7666
7667 if (isa<UndefValue>(Val: C))
7668 return !includesUndef(Kind);
7669
7670 if (isa<ConstantInt>(Val: C) || isa<GlobalVariable>(Val: C) || isa<ConstantFP>(Val: C) ||
7671 isa<ConstantPointerNull>(Val: C) || isa<Function>(Val: C))
7672 return true;
7673
7674 if (C->getType()->isVectorTy()) {
7675 if (isa<ConstantExpr>(Val: C)) {
7676 // Scalable vectors can use a ConstantExpr to build a splat.
7677 if (Constant *SplatC = C->getSplatValue())
7678 if (isa<ConstantInt>(Val: SplatC) || isa<ConstantFP>(Val: SplatC))
7679 return true;
7680 } else {
7681 if (includesUndef(Kind) && C->containsUndefElement())
7682 return false;
7683 if (includesPoison(Kind) && C->containsPoisonElement())
7684 return false;
7685 return !C->containsConstantExpression();
7686 }
7687 }
7688 }
7689
7690 // Strip cast operations from a pointer value.
7691 // Note that stripPointerCastsSameRepresentation can strip off getelementptr
7692 // inbounds with zero offset. To guarantee that the result isn't poison, the
7693 // stripped pointer is checked: it has to point into an allocated object or
7694 // be null, so that `inbounds` getelementptrs with a zero offset cannot
7695 // produce poison.
7696 // It can also strip off addrspacecasts that do not change the bit
7697 // representation; we consider such an addrspacecast equivalent to a no-op.
7698 auto *StrippedV = V->stripPointerCastsSameRepresentation();
7699 if (isa<AllocaInst>(Val: StrippedV) || isa<GlobalVariable>(Val: StrippedV) ||
7700 isa<Function>(Val: StrippedV) || isa<ConstantPointerNull>(Val: StrippedV))
7701 return true;
7702
7703 auto OpCheck = [&](const Value *V) {
7704 return isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth: Depth + 1, Kind);
7705 };
7706
7707 if (auto *Opr = dyn_cast<Operator>(Val: V)) {
7708 // If the value is a freeze instruction, then it can never
7709 // be undef or poison.
7710 if (isa<FreezeInst>(Val: V))
7711 return true;
7712
7713 if (const auto *CB = dyn_cast<CallBase>(Val: V)) {
7714 if (CB->hasRetAttr(Kind: Attribute::NoUndef) ||
7715 CB->hasRetAttr(Kind: Attribute::Dereferenceable) ||
7716 CB->hasRetAttr(Kind: Attribute::DereferenceableOrNull))
7717 return true;
7718 }
7719
7720 if (!::canCreateUndefOrPoison(Op: Opr, Kind,
7721 /*ConsiderFlagsAndMetadata=*/true)) {
7722 if (const auto *PN = dyn_cast<PHINode>(Val: V)) {
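        // A PHI is well defined if every incoming value (ignoring any
        // self-reference) is known to be well defined at the end of its
        // incoming block.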
7723 unsigned Num = PN->getNumIncomingValues();
7724 bool IsWellDefined = true;
7725 for (unsigned i = 0; i < Num; ++i) {
7726 if (PN == PN->getIncomingValue(i))
7727 continue;
7728 auto *TI = PN->getIncomingBlock(i)->getTerminator();
7729 if (!isGuaranteedNotToBeUndefOrPoison(V: PN->getIncomingValue(i), AC, CtxI: TI,
7730 DT, Depth: Depth + 1, Kind)) {
7731 IsWellDefined = false;
7732 break;
7733 }
7734 }
7735 if (IsWellDefined)
7736 return true;
7737 } else if (auto *Splat = isa<ShuffleVectorInst>(Val: Opr) ? getSplatValue(V: Opr)
7738 : nullptr) {
7739 // For splats we only need to check the value being splatted.
7740 if (OpCheck(Splat))
7741 return true;
7742 } else if (all_of(Range: Opr->operands(), P: OpCheck))
7743 return true;
7744 }
7745 }
7746
7747 if (auto *I = dyn_cast<LoadInst>(Val: V))
7748 if (I->hasMetadata(KindID: LLVMContext::MD_noundef) ||
7749 I->hasMetadata(KindID: LLVMContext::MD_dereferenceable) ||
7750 I->hasMetadata(KindID: LLVMContext::MD_dereferenceable_or_null))
7751 return true;
7752
7753 if (programUndefinedIfUndefOrPoison(V, PoisonOnly: !includesUndef(Kind)))
7754 return true;
7755
7756 // CxtI may be null or a cloned instruction.
7757 if (!CtxI || !CtxI->getParent() || !DT)
7758 return false;
7759
7760 auto *DNode = DT->getNode(BB: CtxI->getParent());
7761 if (!DNode)
7762 // Unreachable block
7763 return false;
7764
7765 // If V is used as a branch condition before reaching CtxI, V cannot be
7766 // undef or poison.
7767 // br V, BB1, BB2
7768 // BB1:
7769 // CtxI ; V cannot be undef or poison here
7770 auto *Dominator = DNode->getIDom();
7771 // This check is purely for compile time reasons: we can skip the IDom walk
7772 // if what we are checking for includes undef and the value is not an integer.
7773 if (!includesUndef(Kind) || V->getType()->isIntegerTy())
7774 while (Dominator) {
7775 auto *TI = Dominator->getBlock()->getTerminator();
7776
7777 Value *Cond = nullptr;
7778 if (auto BI = dyn_cast_or_null<BranchInst>(Val: TI)) {
7779 if (BI->isConditional())
7780 Cond = BI->getCondition();
7781 } else if (auto SI = dyn_cast_or_null<SwitchInst>(Val: TI)) {
7782 Cond = SI->getCondition();
7783 }
7784
7785 if (Cond) {
7786 if (Cond == V)
7787 return true;
7788 else if (!includesUndef(Kind) && isa<Operator>(Val: Cond)) {
7789 // For poison, we can analyze further
7790 auto *Opr = cast<Operator>(Val: Cond);
7791 if (any_of(Range: Opr->operands(), P: [V](const Use &U) {
7792 return V == U && propagatesPoison(PoisonOp: U);
7793 }))
7794 return true;
7795 }
7796 }
7797
7798 Dominator = Dominator->getIDom();
7799 }
7800
7801 if (AC && getKnowledgeValidInContext(V, AttrKinds: {Attribute::NoUndef}, AC&: *AC, CtxI, DT))
7802 return true;
7803
7804 return false;
7805}
7806
7807bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V, AssumptionCache *AC,
7808 const Instruction *CtxI,
7809 const DominatorTree *DT,
7810 unsigned Depth) {
7811 return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7812 Kind: UndefPoisonKind::UndefOrPoison);
7813}
7814
7815bool llvm::isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC,
7816 const Instruction *CtxI,
7817 const DominatorTree *DT, unsigned Depth) {
7818 return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7819 Kind: UndefPoisonKind::PoisonOnly);
7820}
7821
7822bool llvm::isGuaranteedNotToBeUndef(const Value *V, AssumptionCache *AC,
7823 const Instruction *CtxI,
7824 const DominatorTree *DT, unsigned Depth) {
7825 return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
7826 Kind: UndefPoisonKind::UndefOnly);
7827}
7828
7829/// Return true if undefined behavior would provably be executed on the path to
7830/// OnPathTo if Root produced a poison result. Note that this doesn't say
7831/// anything about whether OnPathTo is actually executed or whether Root is
7832/// actually poison. This can be used to assess whether a new use of Root can
7833/// be added at a location which is control equivalent with OnPathTo (such as
7834/// immediately before it) without introducing UB which didn't previously
7835/// exist. Note that a false result conveys no information.
7836bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
7837 Instruction *OnPathTo,
7838 DominatorTree *DT) {
7839 // The basic approach is to assume Root is poison, propagate poison forward
7840 // through all users we can easily track, and then check whether any of those
7841 // users provably trigger UB and must execute before our exiting block might
7842 // exit.
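  // Illustrative example: if Root is %x and an instruction such as
  // "udiv i32 1, %x" (immediate UB when %x is poison) dominates OnPathTo,
  // then a new use of %x at OnPathTo cannot introduce UB that did not
  // already exist.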
7843
7844 // The set of all recursive users we've visited (which are assumed to all be
7845 // poison because of said visit)
7846 SmallPtrSet<const Value *, 16> KnownPoison;
7847 SmallVector<const Instruction*, 16> Worklist;
7848 Worklist.push_back(Elt: Root);
7849 while (!Worklist.empty()) {
7850 const Instruction *I = Worklist.pop_back_val();
7851
7852 // If we know this must trigger UB on a path leading to our target.
7853 if (mustTriggerUB(I, KnownPoison) && DT->dominates(Def: I, User: OnPathTo))
7854 return true;
7855
7856 // If we can't analyze propagation through this instruction, just skip it
7857 // and transitive users. Safe as false is a conservative result.
7858 if (I != Root && !any_of(Range: I->operands(), P: [&KnownPoison](const Use &U) {
7859 return KnownPoison.contains(Ptr: U) && propagatesPoison(PoisonOp: U);
7860 }))
7861 continue;
7862
7863 if (KnownPoison.insert(Ptr: I).second)
7864 for (const User *User : I->users())
7865 Worklist.push_back(Elt: cast<Instruction>(Val: User));
7866 }
7867
7868 // Might be non-UB, or might have a path we couldn't prove must execute on
7869 // the way to the exiting bb.
7870 return false;
7871}
7872
7873OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
7874 const SimplifyQuery &SQ) {
7875 return ::computeOverflowForSignedAdd(LHS: Add->getOperand(i_nocapture: 0), RHS: Add->getOperand(i_nocapture: 1),
7876 Add, SQ);
7877}
7878
7879OverflowResult
7880llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
7881 const WithCache<const Value *> &RHS,
7882 const SimplifyQuery &SQ) {
7883 return ::computeOverflowForSignedAdd(LHS, RHS, Add: nullptr, SQ);
7884}
7885
7886bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
7887 // Note: An atomic operation isn't guaranteed to return in a reasonable amount
7888 // of time because it's possible for another thread to interfere with it for an
7889 // arbitrary length of time, but programs aren't allowed to rely on that.
7890
7891 // If there is no successor, then execution can't transfer to it.
7892 if (isa<ReturnInst>(Val: I))
7893 return false;
7894 if (isa<UnreachableInst>(Val: I))
7895 return false;
7896
7897 // Note: Do not add new checks here; instead, change Instruction::mayThrow or
7898 // Instruction::willReturn.
7899 //
7900 // FIXME: Move this check into Instruction::willReturn.
7901 if (isa<CatchPadInst>(Val: I)) {
7902 switch (classifyEHPersonality(Pers: I->getFunction()->getPersonalityFn())) {
7903 default:
7904 // A catchpad may invoke exception object constructors and such, which
7905 // in some languages can be arbitrary code, so be conservative by default.
7906 return false;
7907 case EHPersonality::CoreCLR:
7908 // For CoreCLR, it just involves a type test.
7909 return true;
7910 }
7911 }
7912
7913 // An instruction that returns without throwing must transfer control flow
7914 // to a successor.
7915 return !I->mayThrow() && I->willReturn();
7916}
7917
7918bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) {
7919 // TODO: This is slightly conservative for invoke instructions since exiting
7920 // via an exception *is* normal control for them.
7921 for (const Instruction &I : *BB)
7922 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
7923 return false;
7924 return true;
7925}
7926
7927bool llvm::isGuaranteedToTransferExecutionToSuccessor(
7928 BasicBlock::const_iterator Begin, BasicBlock::const_iterator End,
7929 unsigned ScanLimit) {
7930 return isGuaranteedToTransferExecutionToSuccessor(Range: make_range(x: Begin, y: End),
7931 ScanLimit);
7932}
7933
7934bool llvm::isGuaranteedToTransferExecutionToSuccessor(
7935 iterator_range<BasicBlock::const_iterator> Range, unsigned ScanLimit) {
7936 assert(ScanLimit && "scan limit must be non-zero");
7937 for (const Instruction &I : Range) {
7938 if (--ScanLimit == 0)
7939 return false;
7940 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
7941 return false;
7942 }
7943 return true;
7944}
7945
7946bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I,
7947 const Loop *L) {
7948 // The loop header is guaranteed to be executed for every iteration.
7949 //
7950 // FIXME: Relax this constraint to cover all basic blocks that are
7951 // guaranteed to be executed at every iteration.
7952 if (I->getParent() != L->getHeader()) return false;
7953
7954 for (const Instruction &LI : *L->getHeader()) {
7955 if (&LI == I) return true;
7956 if (!isGuaranteedToTransferExecutionToSuccessor(I: &LI)) return false;
7957 }
7958 llvm_unreachable("Instruction not contained in its own parent basic block.");
7959}
7960
7961bool llvm::intrinsicPropagatesPoison(Intrinsic::ID IID) {
7962 switch (IID) {
7963 // TODO: Add more intrinsics.
7964 case Intrinsic::sadd_with_overflow:
7965 case Intrinsic::ssub_with_overflow:
7966 case Intrinsic::smul_with_overflow:
7967 case Intrinsic::uadd_with_overflow:
7968 case Intrinsic::usub_with_overflow:
7969 case Intrinsic::umul_with_overflow:
7970 // If an input is a vector containing a poison element, the corresponding
7971 // lanes of both output vectors (the calculated results and the overflow
7972 // bits) are poison.
7973 return true;
7974 case Intrinsic::ctpop:
7975 case Intrinsic::ctlz:
7976 case Intrinsic::cttz:
7977 case Intrinsic::abs:
7978 case Intrinsic::smax:
7979 case Intrinsic::smin:
7980 case Intrinsic::umax:
7981 case Intrinsic::umin:
7982 case Intrinsic::scmp:
7983 case Intrinsic::is_fpclass:
7984 case Intrinsic::ptrmask:
7985 case Intrinsic::ucmp:
7986 case Intrinsic::bitreverse:
7987 case Intrinsic::bswap:
7988 case Intrinsic::sadd_sat:
7989 case Intrinsic::ssub_sat:
7990 case Intrinsic::sshl_sat:
7991 case Intrinsic::uadd_sat:
7992 case Intrinsic::usub_sat:
7993 case Intrinsic::ushl_sat:
7994 case Intrinsic::smul_fix:
7995 case Intrinsic::smul_fix_sat:
7996 case Intrinsic::umul_fix:
7997 case Intrinsic::umul_fix_sat:
7998 case Intrinsic::pow:
7999 case Intrinsic::powi:
8000 case Intrinsic::sin:
8001 case Intrinsic::sinh:
8002 case Intrinsic::cos:
8003 case Intrinsic::cosh:
8004 case Intrinsic::sincos:
8005 case Intrinsic::sincospi:
8006 case Intrinsic::tan:
8007 case Intrinsic::tanh:
8008 case Intrinsic::asin:
8009 case Intrinsic::acos:
8010 case Intrinsic::atan:
8011 case Intrinsic::atan2:
8012 case Intrinsic::canonicalize:
8013 case Intrinsic::sqrt:
8014 case Intrinsic::exp:
8015 case Intrinsic::exp2:
8016 case Intrinsic::exp10:
8017 case Intrinsic::log:
8018 case Intrinsic::log2:
8019 case Intrinsic::log10:
8020 case Intrinsic::modf:
8021 case Intrinsic::floor:
8022 case Intrinsic::ceil:
8023 case Intrinsic::trunc:
8024 case Intrinsic::rint:
8025 case Intrinsic::nearbyint:
8026 case Intrinsic::round:
8027 case Intrinsic::roundeven:
8028 case Intrinsic::lrint:
8029 case Intrinsic::llrint:
8030 case Intrinsic::fshl:
8031 case Intrinsic::fshr:
8032 return true;
8033 default:
8034 return false;
8035 }
8036}
8037
8038bool llvm::propagatesPoison(const Use &PoisonOp) {
8039 const Operator *I = cast<Operator>(Val: PoisonOp.getUser());
8040 switch (I->getOpcode()) {
8041 case Instruction::Freeze:
8042 case Instruction::PHI:
8043 case Instruction::Invoke:
8044 return false;
8045 case Instruction::Select:
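    // Only the condition operand is guaranteed to propagate poison: a poison
    // true/false arm need not be selected.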
8046 return PoisonOp.getOperandNo() == 0;
8047 case Instruction::Call:
8048 if (auto *II = dyn_cast<IntrinsicInst>(Val: I))
8049 return intrinsicPropagatesPoison(IID: II->getIntrinsicID());
8050 return false;
8051 case Instruction::ICmp:
8052 case Instruction::FCmp:
8053 case Instruction::GetElementPtr:
8054 return true;
8055 default:
8056 if (isa<BinaryOperator>(Val: I) || isa<UnaryOperator>(Val: I) || isa<CastInst>(Val: I))
8057 return true;
8058
8059 // Be conservative and return false.
8060 return false;
8061 }
8062}
8063
8064/// Enumerates all operands of \p I that are guaranteed to not be undef or
8065/// poison. If the callback \p Handle returns true, stop processing and return
8066/// true. Otherwise, return false.
8067template <typename CallableT>
8068static bool handleGuaranteedWellDefinedOps(const Instruction *I,
8069 const CallableT &Handle) {
8070 switch (I->getOpcode()) {
8071 case Instruction::Store:
8072 if (Handle(cast<StoreInst>(Val: I)->getPointerOperand()))
8073 return true;
8074 break;
8075
8076 case Instruction::Load:
8077 if (Handle(cast<LoadInst>(Val: I)->getPointerOperand()))
8078 return true;
8079 break;
8080
8081 // Since the dereferenceable attribute implies noundef, atomic operations
8082 // also implicitly have noundef pointers.
8083 case Instruction::AtomicCmpXchg:
8084 if (Handle(cast<AtomicCmpXchgInst>(Val: I)->getPointerOperand()))
8085 return true;
8086 break;
8087
8088 case Instruction::AtomicRMW:
8089 if (Handle(cast<AtomicRMWInst>(Val: I)->getPointerOperand()))
8090 return true;
8091 break;
8092
8093 case Instruction::Call:
8094 case Instruction::Invoke: {
8095 const CallBase *CB = cast<CallBase>(Val: I);
8096 if (CB->isIndirectCall() && Handle(CB->getCalledOperand()))
8097 return true;
8098 for (unsigned i = 0; i < CB->arg_size(); ++i)
8099 if ((CB->paramHasAttr(ArgNo: i, Kind: Attribute::NoUndef) ||
8100 CB->paramHasAttr(ArgNo: i, Kind: Attribute::Dereferenceable) ||
8101 CB->paramHasAttr(ArgNo: i, Kind: Attribute::DereferenceableOrNull)) &&
8102 Handle(CB->getArgOperand(i)))
8103 return true;
8104 break;
8105 }
8106 case Instruction::Ret:
8107 if (I->getFunction()->hasRetAttribute(Kind: Attribute::NoUndef) &&
8108 Handle(I->getOperand(i: 0)))
8109 return true;
8110 break;
8111 case Instruction::Switch:
8112 if (Handle(cast<SwitchInst>(Val: I)->getCondition()))
8113 return true;
8114 break;
8115 case Instruction::Br: {
8116 auto *BR = cast<BranchInst>(Val: I);
8117 if (BR->isConditional() && Handle(BR->getCondition()))
8118 return true;
8119 break;
8120 }
8121 default:
8122 break;
8123 }
8124
8125 return false;
8126}
8127
8128/// Enumerates all operands of \p I that are guaranteed to not be poison.
8129template <typename CallableT>
8130static bool handleGuaranteedNonPoisonOps(const Instruction *I,
8131 const CallableT &Handle) {
8132 if (handleGuaranteedWellDefinedOps(I, Handle))
8133 return true;
8134 switch (I->getOpcode()) {
8135 // Divisors of these operations are allowed to be partially undef.
8136 case Instruction::UDiv:
8137 case Instruction::SDiv:
8138 case Instruction::URem:
8139 case Instruction::SRem:
8140 return Handle(I->getOperand(i: 1));
8141 default:
8142 return false;
8143 }
8144}
8145
8146bool llvm::mustTriggerUB(const Instruction *I,
8147 const SmallPtrSetImpl<const Value *> &KnownPoison) {
8148 return handleGuaranteedNonPoisonOps(
8149 I, Handle: [&](const Value *V) { return KnownPoison.count(Ptr: V); });
8150}
8151
8152static bool programUndefinedIfUndefOrPoison(const Value *V,
8153 bool PoisonOnly) {
8154 // We currently only look for uses of values within the same basic
8155 // block, as that makes it easier to guarantee that the uses will be
8156 // executed given that V is executed.
8157 //
8158 // FIXME: Expand this to consider uses beyond the same basic block. To do
8159 // this, look out for the distinction between post-dominance and strong
8160 // post-dominance.
8161 const BasicBlock *BB = nullptr;
8162 BasicBlock::const_iterator Begin;
8163 if (const auto *Inst = dyn_cast<Instruction>(Val: V)) {
8164 BB = Inst->getParent();
8165 Begin = Inst->getIterator();
8166 Begin++;
8167 } else if (const auto *Arg = dyn_cast<Argument>(Val: V)) {
8168 if (Arg->getParent()->isDeclaration())
8169 return false;
8170 BB = &Arg->getParent()->getEntryBlock();
8171 Begin = BB->begin();
8172 } else {
8173 return false;
8174 }
8175
8176 // Limit number of instructions we look at, to avoid scanning through large
8177 // blocks. The current limit is chosen arbitrarily.
8178 unsigned ScanLimit = 32;
8179 BasicBlock::const_iterator End = BB->end();
8180
8181 if (!PoisonOnly) {
8182 // Since undef does not propagate eagerly, be conservative & just check
8183 // whether a value is directly passed to an instruction that must take
8184 // well-defined operands.
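    // (For example, V used as the pointer operand of a load/store or passed
    // as a noundef call argument would make the program UB if V were undef.)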
8185
8186 for (const auto &I : make_range(x: Begin, y: End)) {
8187 if (--ScanLimit == 0)
8188 break;
8189
8190 if (handleGuaranteedWellDefinedOps(I: &I, Handle: [V](const Value *WellDefinedOp) {
8191 return WellDefinedOp == V;
8192 }))
8193 return true;
8194
8195 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
8196 break;
8197 }
8198 return false;
8199 }
8200
8201 // Set of instructions that we have proved will yield poison if V
8202 // does.
8203 SmallPtrSet<const Value *, 16> YieldsPoison;
8204 SmallPtrSet<const BasicBlock *, 4> Visited;
8205
8206 YieldsPoison.insert(Ptr: V);
8207 Visited.insert(Ptr: BB);
8208
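  // Walk forward through single-successor blocks, growing the set of values
  // that are known to yield poison whenever V does, and check whether any
  // visited instruction must trigger UB on such a poison value.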
8209 while (true) {
8210 for (const auto &I : make_range(x: Begin, y: End)) {
8211 if (--ScanLimit == 0)
8212 return false;
8213 if (mustTriggerUB(I: &I, KnownPoison: YieldsPoison))
8214 return true;
8215 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
8216 return false;
8217
8218 // If an operand is poison and propagates it, mark I as yielding poison.
8219 for (const Use &Op : I.operands()) {
8220 if (YieldsPoison.count(Ptr: Op) && propagatesPoison(PoisonOp: Op)) {
8221 YieldsPoison.insert(Ptr: &I);
8222 break;
8223 }
8224 }
8225
8226 // Special handling for select, which returns poison if its operand 0 is
8227 // poison (handled in the loop above) *or* if both its true/false operands
8228 // are poison (handled here).
8229 if (I.getOpcode() == Instruction::Select &&
8230 YieldsPoison.count(Ptr: I.getOperand(i: 1)) &&
8231 YieldsPoison.count(Ptr: I.getOperand(i: 2))) {
8232 YieldsPoison.insert(Ptr: &I);
8233 }
8234 }
8235
8236 BB = BB->getSingleSuccessor();
8237 if (!BB || !Visited.insert(Ptr: BB).second)
8238 break;
8239
8240 Begin = BB->getFirstNonPHIIt();
8241 End = BB->end();
8242 }
8243 return false;
8244}
8245
8246bool llvm::programUndefinedIfUndefOrPoison(const Instruction *Inst) {
8247 return ::programUndefinedIfUndefOrPoison(V: Inst, PoisonOnly: false);
8248}
8249
8250bool llvm::programUndefinedIfPoison(const Instruction *Inst) {
8251 return ::programUndefinedIfUndefOrPoison(V: Inst, PoisonOnly: true);
8252}
8253
8254static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) {
8255 if (FMF.noNaNs())
8256 return true;
8257
8258 if (auto *C = dyn_cast<ConstantFP>(Val: V))
8259 return !C->isNaN();
8260
8261 if (auto *C = dyn_cast<ConstantDataVector>(Val: V)) {
8262 if (!C->getElementType()->isFloatingPointTy())
8263 return false;
8264 for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8265 if (C->getElementAsAPFloat(i: I).isNaN())
8266 return false;
8267 }
8268 return true;
8269 }
8270
8271 if (isa<ConstantAggregateZero>(Val: V))
8272 return true;
8273
8274 return false;
8275}
8276
8277static bool isKnownNonZero(const Value *V) {
8278 if (auto *C = dyn_cast<ConstantFP>(Val: V))
8279 return !C->isZero();
8280
8281 if (auto *C = dyn_cast<ConstantDataVector>(Val: V)) {
8282 if (!C->getElementType()->isFloatingPointTy())
8283 return false;
8284 for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8285 if (C->getElementAsAPFloat(i: I).isZero())
8286 return false;
8287 }
8288 return true;
8289 }
8290
8291 return false;
8292}
8293
8294/// Match clamp pattern for float types, ignoring NaNs and signed zeros.
8295/// Given a non-min/max outer cmp/select from the clamp pattern, this
8296/// function recognizes whether it can be substituted with a "canonical"
8297/// min/max pattern.
8298static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred,
8299 Value *CmpLHS, Value *CmpRHS,
8300 Value *TrueVal, Value *FalseVal,
8301 Value *&LHS, Value *&RHS) {
8302 // Try to match
8303 // X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2))
8304 // X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2))
8305 // and return description of the outer Max/Min.
8306
8307 // First, check if select has inverse order:
8308 if (CmpRHS == FalseVal) {
8309 std::swap(a&: TrueVal, b&: FalseVal);
8310 Pred = CmpInst::getInversePredicate(pred: Pred);
8311 }
8312
8313 // Assume success now. If there's no match, callers should not use these anyway.
8314 LHS = TrueVal;
8315 RHS = FalseVal;
8316
8317 const APFloat *FC1;
8318 if (CmpRHS != TrueVal || !match(V: CmpRHS, P: m_APFloat(Res&: FC1)) || !FC1->isFinite())
8319 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8320
8321 const APFloat *FC2;
8322 switch (Pred) {
8323 case CmpInst::FCMP_OLT:
8324 case CmpInst::FCMP_OLE:
8325 case CmpInst::FCMP_ULT:
8326 case CmpInst::FCMP_ULE:
8327 if (match(V: FalseVal, P: m_OrdOrUnordFMin(L: m_Specific(V: CmpLHS), R: m_APFloat(Res&: FC2))) &&
8328 *FC1 < *FC2)
8329 return {.Flavor: SPF_FMAXNUM, .NaNBehavior: SPNB_RETURNS_ANY, .Ordered: false};
8330 break;
8331 case CmpInst::FCMP_OGT:
8332 case CmpInst::FCMP_OGE:
8333 case CmpInst::FCMP_UGT:
8334 case CmpInst::FCMP_UGE:
8335 if (match(V: FalseVal, P: m_OrdOrUnordFMax(L: m_Specific(V: CmpLHS), R: m_APFloat(Res&: FC2))) &&
8336 *FC1 > *FC2)
8337 return {.Flavor: SPF_FMINNUM, .NaNBehavior: SPNB_RETURNS_ANY, .Ordered: false};
8338 break;
8339 default:
8340 break;
8341 }
8342
8343 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8344}
8345
8346/// Recognize variations of:
8347/// CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v)))
8348static SelectPatternResult matchClamp(CmpInst::Predicate Pred,
8349 Value *CmpLHS, Value *CmpRHS,
8350 Value *TrueVal, Value *FalseVal) {
8351 // Swap the select operands and predicate to match the patterns below.
8352 if (CmpRHS != TrueVal) {
8353 Pred = ICmpInst::getSwappedPredicate(pred: Pred);
8354 std::swap(a&: TrueVal, b&: FalseVal);
8355 }
8356 const APInt *C1;
8357 if (CmpRHS == TrueVal && match(V: CmpRHS, P: m_APInt(Res&: C1))) {
8358 const APInt *C2;
8359 // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1)
8360 if (match(V: FalseVal, P: m_SMin(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8361 C1->slt(RHS: *C2) && Pred == CmpInst::ICMP_SLT)
8362 return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8363
8364 // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1)
8365 if (match(V: FalseVal, P: m_SMax(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8366 C1->sgt(RHS: *C2) && Pred == CmpInst::ICMP_SGT)
8367 return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8368
8369 // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1)
8370 if (match(V: FalseVal, P: m_UMin(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8371 C1->ult(RHS: *C2) && Pred == CmpInst::ICMP_ULT)
8372 return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8373
8374 // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1)
8375 if (match(V: FalseVal, P: m_UMax(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
8376 C1->ugt(RHS: *C2) && Pred == CmpInst::ICMP_UGT)
8377 return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8378 }
8379 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8380}
8381
8382/// Recognize variations of:
8383/// a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c))
8384static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred,
8385 Value *CmpLHS, Value *CmpRHS,
8386 Value *TVal, Value *FVal,
8387 unsigned Depth) {
8388 // TODO: Allow FP min/max with nnan/nsz.
8389 assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison");
8390
8391 Value *A = nullptr, *B = nullptr;
8392 SelectPatternResult L = matchSelectPattern(V: TVal, LHS&: A, RHS&: B, CastOp: nullptr, Depth: Depth + 1);
8393 if (!SelectPatternResult::isMinOrMax(SPF: L.Flavor))
8394 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8395
8396 Value *C = nullptr, *D = nullptr;
8397 SelectPatternResult R = matchSelectPattern(V: FVal, LHS&: C, RHS&: D, CastOp: nullptr, Depth: Depth + 1);
8398 if (L.Flavor != R.Flavor)
8399 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8400
8401 // We have something like: x Pred y ? min(a, b) : min(c, d).
8402 // Try to match the compare to the min/max operations of the select operands.
8403 // First, make sure we have the right compare predicate.
8404 switch (L.Flavor) {
8405 case SPF_SMIN:
8406 if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
8407 Pred = ICmpInst::getSwappedPredicate(pred: Pred);
8408 std::swap(a&: CmpLHS, b&: CmpRHS);
8409 }
8410 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
8411 break;
8412 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8413 case SPF_SMAX:
8414 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
8415 Pred = ICmpInst::getSwappedPredicate(pred: Pred);
8416 std::swap(a&: CmpLHS, b&: CmpRHS);
8417 }
8418 if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
8419 break;
8420 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8421 case SPF_UMIN:
8422 if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
8423 Pred = ICmpInst::getSwappedPredicate(pred: Pred);
8424 std::swap(a&: CmpLHS, b&: CmpRHS);
8425 }
8426 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
8427 break;
8428 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8429 case SPF_UMAX:
8430 if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
8431 Pred = ICmpInst::getSwappedPredicate(pred: Pred);
8432 std::swap(a&: CmpLHS, b&: CmpRHS);
8433 }
8434 if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
8435 break;
8436 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8437 default:
8438 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8439 }
8440
8441 // If there is a common operand in the already matched min/max and the other
8442 // min/max operands match the compare operands (either directly or inverted),
8443 // then this is min/max of the same flavor.
8444
8445 // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
8446 // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
8447 if (D == B) {
8448 if ((CmpLHS == A && CmpRHS == C) || (match(V: C, P: m_Not(V: m_Specific(V: CmpLHS))) &&
8449 match(V: A, P: m_Not(V: m_Specific(V: CmpRHS)))))
8450 return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
8451 }
8452 // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
8453 // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
8454 if (C == B) {
8455 if ((CmpLHS == A && CmpRHS == D) || (match(V: D, P: m_Not(V: m_Specific(V: CmpLHS))) &&
8456 match(V: A, P: m_Not(V: m_Specific(V: CmpRHS)))))
8457 return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
8458 }
8459 // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
8460 // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
8461 if (D == A) {
8462 if ((CmpLHS == B && CmpRHS == C) || (match(V: C, P: m_Not(V: m_Specific(V: CmpLHS))) &&
8463 match(V: B, P: m_Not(V: m_Specific(V: CmpRHS)))))
8464 return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
8465 }
8466 // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
8467 // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
8468 if (C == A) {
8469 if ((CmpLHS == B && CmpRHS == D) || (match(V: D, P: m_Not(V: m_Specific(V: CmpLHS))) &&
8470 match(V: B, P: m_Not(V: m_Specific(V: CmpRHS)))))
8471 return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
8472 }
8473
8474 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8475}
8476
8477/// If the input value is the result of a 'not' op, constant integer, or vector
8478/// splat of a constant integer, return the bitwise-not source value.
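/// e.g. (illustrative): for V = xor X, -1 this returns X, and for a constant
/// i8 5 it returns i8 -6 (the bitwise not).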
8479/// TODO: This could be extended to handle non-splat vector integer constants.
8480static Value *getNotValue(Value *V) {
8481 Value *NotV;
8482 if (match(V, P: m_Not(V: m_Value(V&: NotV))))
8483 return NotV;
8484
8485 const APInt *C;
8486 if (match(V, P: m_APInt(Res&: C)))
8487 return ConstantInt::get(Ty: V->getType(), V: ~(*C));
8488
8489 return nullptr;
8490}
8491
8492/// Match non-obvious integer minimum and maximum sequences.
8493static SelectPatternResult matchMinMax(CmpInst::Predicate Pred,
8494 Value *CmpLHS, Value *CmpRHS,
8495 Value *TrueVal, Value *FalseVal,
8496 Value *&LHS, Value *&RHS,
8497 unsigned Depth) {
8498 // Assume success. If there's no match, callers should not use these anyway.
8499 LHS = TrueVal;
8500 RHS = FalseVal;
8501
8502 SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal);
8503 if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
8504 return SPR;
8505
8506 SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TVal: TrueVal, FVal: FalseVal, Depth);
8507 if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
8508 return SPR;
8509
8510 // Look through 'not' ops to find disguised min/max.
8511 // (X > Y) ? ~X : ~Y ==> (~X < ~Y) ? ~X : ~Y ==> MIN(~X, ~Y)
8512 // (X < Y) ? ~X : ~Y ==> (~X > ~Y) ? ~X : ~Y ==> MAX(~X, ~Y)
8513 if (CmpLHS == getNotValue(V: TrueVal) && CmpRHS == getNotValue(V: FalseVal)) {
8514 switch (Pred) {
8515 case CmpInst::ICMP_SGT: return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8516 case CmpInst::ICMP_SLT: return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8517 case CmpInst::ICMP_UGT: return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8518 case CmpInst::ICMP_ULT: return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8519 default: break;
8520 }
8521 }
8522
8523 // (X > Y) ? ~Y : ~X ==> (~X < ~Y) ? ~Y : ~X ==> MAX(~Y, ~X)
8524 // (X < Y) ? ~Y : ~X ==> (~X > ~Y) ? ~Y : ~X ==> MIN(~Y, ~X)
8525 if (CmpLHS == getNotValue(V: FalseVal) && CmpRHS == getNotValue(V: TrueVal)) {
8526 switch (Pred) {
8527 case CmpInst::ICMP_SGT: return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8528 case CmpInst::ICMP_SLT: return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8529 case CmpInst::ICMP_UGT: return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8530 case CmpInst::ICMP_ULT: return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8531 default: break;
8532 }
8533 }
8534
8535 if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT)
8536 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8537
8538 const APInt *C1;
8539 if (!match(V: CmpRHS, P: m_APInt(Res&: C1)))
8540 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8541
8542 // An unsigned min/max can be written with a signed compare.
8543 const APInt *C2;
8544 if ((CmpLHS == TrueVal && match(V: FalseVal, P: m_APInt(Res&: C2))) ||
8545 (CmpLHS == FalseVal && match(V: TrueVal, P: m_APInt(Res&: C2)))) {
8546 // Is the sign bit set?
8547 // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX
8548 // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN
8549 if (Pred == CmpInst::ICMP_SLT && C1->isZero() && C2->isMaxSignedValue())
8550 return {.Flavor: CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8551
8552 // Is the sign bit clear?
8553 // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX
8554 // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN
8555 if (Pred == CmpInst::ICMP_SGT && C1->isAllOnes() && C2->isMinSignedValue())
8556 return {.Flavor: CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8557 }
8558
8559 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8560}
8561
8562bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW,
8563 bool AllowPoison) {
8564 assert(X && Y && "Invalid operand");
8565
8566 auto IsNegationOf = [&](const Value *X, const Value *Y) {
8567 if (!match(V: X, P: m_Neg(V: m_Specific(V: Y))))
8568 return false;
8569
8570 auto *BO = cast<BinaryOperator>(Val: X);
8571 if (NeedNSW && !BO->hasNoSignedWrap())
8572 return false;
8573
8574 auto *Zero = cast<Constant>(Val: BO->getOperand(i_nocapture: 0));
8575 if (!AllowPoison && !Zero->isNullValue())
8576 return false;
8577
8578 return true;
8579 };
8580
8581 // X = -Y or Y = -X
8582 if (IsNegationOf(X, Y) || IsNegationOf(Y, X))
8583 return true;
8584
8585 // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A)
8586 Value *A, *B;
8587 return (!NeedNSW && (match(V: X, P: m_Sub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
8588 match(V: Y, P: m_Sub(L: m_Specific(V: B), R: m_Specific(V: A))))) ||
8589 (NeedNSW && (match(V: X, P: m_NSWSub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
8590 match(V: Y, P: m_NSWSub(L: m_Specific(V: B), R: m_Specific(V: A)))));
8591}
8592
8593bool llvm::isKnownInversion(const Value *X, const Value *Y) {
8594 // Handle X = icmp pred A, B, Y = icmp pred A, C.
8595 Value *A, *B, *C;
8596 CmpPredicate Pred1, Pred2;
8597 if (!match(V: X, P: m_ICmp(Pred&: Pred1, L: m_Value(V&: A), R: m_Value(V&: B))) ||
8598 !match(V: Y, P: m_c_ICmp(Pred&: Pred2, L: m_Specific(V: A), R: m_Value(V&: C))))
8599 return false;
8600
8601 // They must both have samesign flag or not.
8602 if (Pred1.hasSameSign() != Pred2.hasSameSign())
8603 return false;
8604
8605 if (B == C)
8606 return Pred1 == ICmpInst::getInversePredicate(pred: Pred2);
8607
8608 // Try to infer the relationship from constant ranges.
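  // e.g. (illustrative): X = icmp ult A, 8 and Y = icmp ugt A, 7 are
  // inversions, since the ranges [0, 8) and [8, UMAX] are complements.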
8609 const APInt *RHSC1, *RHSC2;
8610 if (!match(V: B, P: m_APInt(Res&: RHSC1)) || !match(V: C, P: m_APInt(Res&: RHSC2)))
8611 return false;
8612
8613 // Sign bits of two RHSCs should match.
8614 if (Pred1.hasSameSign() && RHSC1->isNonNegative() != RHSC2->isNonNegative())
8615 return false;
8616
8617 const auto CR1 = ConstantRange::makeExactICmpRegion(Pred: Pred1, Other: *RHSC1);
8618 const auto CR2 = ConstantRange::makeExactICmpRegion(Pred: Pred2, Other: *RHSC2);
8619
8620 return CR1.inverse() == CR2;
8621}
8622
8623SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
8624 SelectPatternNaNBehavior NaNBehavior,
8625 bool Ordered) {
8626 switch (Pred) {
8627 default:
8628 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false}; // Equality.
8629 case ICmpInst::ICMP_UGT:
8630 case ICmpInst::ICMP_UGE:
8631 return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8632 case ICmpInst::ICMP_SGT:
8633 case ICmpInst::ICMP_SGE:
8634 return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8635 case ICmpInst::ICMP_ULT:
8636 case ICmpInst::ICMP_ULE:
8637 return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8638 case ICmpInst::ICMP_SLT:
8639 case ICmpInst::ICMP_SLE:
8640 return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8641 case FCmpInst::FCMP_UGT:
8642 case FCmpInst::FCMP_UGE:
8643 case FCmpInst::FCMP_OGT:
8644 case FCmpInst::FCMP_OGE:
8645 return {.Flavor: SPF_FMAXNUM, .NaNBehavior: NaNBehavior, .Ordered: Ordered};
8646 case FCmpInst::FCMP_ULT:
8647 case FCmpInst::FCMP_ULE:
8648 case FCmpInst::FCMP_OLT:
8649 case FCmpInst::FCMP_OLE:
8650 return {.Flavor: SPF_FMINNUM, .NaNBehavior: NaNBehavior, .Ordered: Ordered};
8651 }
8652}
8653
8654std::optional<std::pair<CmpPredicate, Constant *>>
8655llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
8656 assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
8657 "Only for relational integer predicates.");
8658 if (isa<UndefValue>(Val: C))
8659 return std::nullopt;
8660
8661 Type *Type = C->getType();
8662 bool IsSigned = ICmpInst::isSigned(predicate: Pred);
8663
8664 CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
8665 bool WillIncrement =
8666 UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;
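  // Illustrative: for ICMP_SGT the constant is incremented (X >s 5 <=> X >=s 6),
  // while for ICMP_SLT it is decremented (X <s 5 <=> X <=s 4).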
8667
8668 // Check if the constant operand can be safely incremented/decremented
8669 // without overflowing/underflowing.
8670 auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
8671 return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
8672 };
8673
8674 Constant *SafeReplacementConstant = nullptr;
8675 if (auto *CI = dyn_cast<ConstantInt>(Val: C)) {
8676 // Bail out if the constant can't be safely incremented/decremented.
8677 if (!ConstantIsOk(CI))
8678 return std::nullopt;
8679 } else if (auto *FVTy = dyn_cast<FixedVectorType>(Val: Type)) {
8680 unsigned NumElts = FVTy->getNumElements();
8681 for (unsigned i = 0; i != NumElts; ++i) {
8682 Constant *Elt = C->getAggregateElement(Elt: i);
8683 if (!Elt)
8684 return std::nullopt;
8685
8686 if (isa<UndefValue>(Val: Elt))
8687 continue;
8688
8689 // Bail out if we can't determine if this constant is min/max or if we
8690 // know that this constant is min/max.
8691 auto *CI = dyn_cast<ConstantInt>(Val: Elt);
8692 if (!CI || !ConstantIsOk(CI))
8693 return std::nullopt;
8694
8695 if (!SafeReplacementConstant)
8696 SafeReplacementConstant = CI;
8697 }
8698 } else if (isa<VectorType>(Val: C->getType())) {
8699 // Handle scalable splat
8700 Value *SplatC = C->getSplatValue();
8701 auto *CI = dyn_cast_or_null<ConstantInt>(Val: SplatC);
8702 // Bail out if the constant can't be safely incremented/decremented.
8703 if (!CI || !ConstantIsOk(CI))
8704 return std::nullopt;
8705 } else {
8706 // ConstantExpr?
8707 return std::nullopt;
8708 }
8709
8710 // It may not be safe to change a compare predicate in the presence of
8711 // undefined elements, so replace those elements with the first safe constant
8712 // that we found.
8713 // TODO: in case of poison, it is safe; let's replace undefs only.
8714 if (C->containsUndefOrPoisonElement()) {
8715 assert(SafeReplacementConstant && "Replacement constant not set");
8716 C = Constant::replaceUndefsWith(C, Replacement: SafeReplacementConstant);
8717 }
8718
8719 CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(pred: Pred);
8720
8721 // Increment or decrement the constant.
8722 Constant *OneOrNegOne = ConstantInt::get(Ty: Type, V: WillIncrement ? 1 : -1, IsSigned: true);
8723 Constant *NewC = ConstantExpr::getAdd(C1: C, C2: OneOrNegOne);
8724
8725 return std::make_pair(x&: NewPred, y&: NewC);
8726}
8727
8728static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
8729 FastMathFlags FMF,
8730 Value *CmpLHS, Value *CmpRHS,
8731 Value *TrueVal, Value *FalseVal,
8732 Value *&LHS, Value *&RHS,
8733 unsigned Depth) {
8734 bool HasMismatchedZeros = false;
8735 if (CmpInst::isFPPredicate(P: Pred)) {
8736 // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one
8737 // 0.0 operand, set the compare's 0.0 operands to that same value for the
8738 // purpose of identifying min/max. Disregard vector constants with undefined
8739 // elements because those cannot be back-propagated for analysis.
8740 Value *OutputZeroVal = nullptr;
8741 if (match(V: TrueVal, P: m_AnyZeroFP()) && !match(V: FalseVal, P: m_AnyZeroFP()) &&
8742 !cast<Constant>(Val: TrueVal)->containsUndefOrPoisonElement())
8743 OutputZeroVal = TrueVal;
8744 else if (match(V: FalseVal, P: m_AnyZeroFP()) && !match(V: TrueVal, P: m_AnyZeroFP()) &&
8745 !cast<Constant>(Val: FalseVal)->containsUndefOrPoisonElement())
8746 OutputZeroVal = FalseVal;
8747
8748 if (OutputZeroVal) {
8749 if (match(V: CmpLHS, P: m_AnyZeroFP()) && CmpLHS != OutputZeroVal) {
8750 HasMismatchedZeros = true;
8751 CmpLHS = OutputZeroVal;
8752 }
8753 if (match(V: CmpRHS, P: m_AnyZeroFP()) && CmpRHS != OutputZeroVal) {
8754 HasMismatchedZeros = true;
8755 CmpRHS = OutputZeroVal;
8756 }
8757 }
8758 }
8759
8760 LHS = CmpLHS;
8761 RHS = CmpRHS;
8762
8763 // Signed zero may return inconsistent results between implementations.
8764 // (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0
8765 // minNum(0.0, -0.0) // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1)
8766 // Therefore, we behave conservatively and only proceed if at least one of the
8767 // operands is known not to be zero or if we don't care about signed zero.
8768 switch (Pred) {
8769 default: break;
8770 case CmpInst::FCMP_OGT: case CmpInst::FCMP_OLT:
8771 case CmpInst::FCMP_UGT: case CmpInst::FCMP_ULT:
8772 if (!HasMismatchedZeros)
8773 break;
8774 [[fallthrough]];
8775 case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE:
8776 case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE:
8777 if (!FMF.noSignedZeros() && !isKnownNonZero(V: CmpLHS) &&
8778 !isKnownNonZero(V: CmpRHS))
8779 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8780 }
8781
8782 SelectPatternNaNBehavior NaNBehavior = SPNB_NA;
8783 bool Ordered = false;
8784
8785 // When given one NaN and one non-NaN input:
8786 // - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input.
8787 // - A simple C99 (a < b ? a : b) construction will return 'b' (as the
8788 // ordered comparison fails), which could be NaN or non-NaN.
8789 // so here we discover exactly what NaN behavior is required/accepted.
8790 if (CmpInst::isFPPredicate(P: Pred)) {
8791 bool LHSSafe = isKnownNonNaN(V: CmpLHS, FMF);
8792 bool RHSSafe = isKnownNonNaN(V: CmpRHS, FMF);
8793
8794 if (LHSSafe && RHSSafe) {
8795 // Both operands are known non-NaN.
8796 NaNBehavior = SPNB_RETURNS_ANY;
8797 Ordered = CmpInst::isOrdered(predicate: Pred);
8798 } else if (CmpInst::isOrdered(predicate: Pred)) {
8799 // An ordered comparison will return false when given a NaN, so it
8800 // returns the RHS.
8801 Ordered = true;
8802 if (LHSSafe)
8803 // LHS is non-NaN, so if RHS is NaN then NaN will be returned.
8804 NaNBehavior = SPNB_RETURNS_NAN;
8805 else if (RHSSafe)
8806 NaNBehavior = SPNB_RETURNS_OTHER;
8807 else
8808 // Completely unsafe.
8809 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8810 } else {
8811 Ordered = false;
8812 // An unordered comparison will return true when given a NaN, so it
8813 // returns the LHS.
8814 if (LHSSafe)
8815 // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned.
8816 NaNBehavior = SPNB_RETURNS_OTHER;
8817 else if (RHSSafe)
8818 NaNBehavior = SPNB_RETURNS_NAN;
8819 else
8820 // Completely unsafe.
8821 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8822 }
8823 }
8824
8825 if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
8826 std::swap(a&: CmpLHS, b&: CmpRHS);
8827 Pred = CmpInst::getSwappedPredicate(pred: Pred);
8828 if (NaNBehavior == SPNB_RETURNS_NAN)
8829 NaNBehavior = SPNB_RETURNS_OTHER;
8830 else if (NaNBehavior == SPNB_RETURNS_OTHER)
8831 NaNBehavior = SPNB_RETURNS_NAN;
8832 Ordered = !Ordered;
8833 }
8834
8835 // ([if]cmp X, Y) ? X : Y
8836 if (TrueVal == CmpLHS && FalseVal == CmpRHS)
8837 return getSelectPattern(Pred, NaNBehavior, Ordered);
8838
8839 if (isKnownNegation(X: TrueVal, Y: FalseVal)) {
8840 // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can
8841 // match against either LHS or sext(LHS).
8842 auto MaybeSExtCmpLHS =
8843 m_CombineOr(L: m_Specific(V: CmpLHS), R: m_SExt(Op: m_Specific(V: CmpLHS)));
8844 auto ZeroOrAllOnes = m_CombineOr(L: m_ZeroInt(), R: m_AllOnes());
8845 auto ZeroOrOne = m_CombineOr(L: m_ZeroInt(), R: m_One());
8846 if (match(V: TrueVal, P: MaybeSExtCmpLHS)) {
8847 // Set the return values. If the compare uses the negated value (-X >s 0),
8848 // swap the return values because the negated value is always 'RHS'.
8849 LHS = TrueVal;
8850 RHS = FalseVal;
8851 if (match(V: CmpLHS, P: m_Neg(V: m_Specific(V: FalseVal))))
8852 std::swap(a&: LHS, b&: RHS);
8853
8854 // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X)
8855 // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X)
8856 if (Pred == ICmpInst::ICMP_SGT && match(V: CmpRHS, P: ZeroOrAllOnes))
8857 return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};
8858
8859 // (X >=s 0) ? X : -X or (X >=s 1) ? X : -X --> ABS(X)
8860 if (Pred == ICmpInst::ICMP_SGE && match(V: CmpRHS, P: ZeroOrOne))
8861 return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};
8862
8863 // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X)
8864 // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X)
8865 if (Pred == ICmpInst::ICMP_SLT && match(V: CmpRHS, P: ZeroOrOne))
8866 return {.Flavor: SPF_NABS, .NaNBehavior: SPNB_NA, .Ordered: false};
} else if (match(V: FalseVal, P: MaybeSExtCmpLHS)) {
8869 // Set the return values. If the compare uses the negated value (-X >s 0),
8870 // swap the return values because the negated value is always 'RHS'.
8871 LHS = FalseVal;
8872 RHS = TrueVal;
8873 if (match(V: CmpLHS, P: m_Neg(V: m_Specific(V: TrueVal))))
8874 std::swap(a&: LHS, b&: RHS);
8875
8876 // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X)
8877 // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X)
8878 if (Pred == ICmpInst::ICMP_SGT && match(V: CmpRHS, P: ZeroOrAllOnes))
8879 return {.Flavor: SPF_NABS, .NaNBehavior: SPNB_NA, .Ordered: false};
8880
8881 // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X)
8882 // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
8883 if (Pred == ICmpInst::ICMP_SLT && match(V: CmpRHS, P: ZeroOrOne))
8884 return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};
8885 }
8886 }
8887
8888 if (CmpInst::isIntPredicate(P: Pred))
8889 return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth);
8890
// According to IEEE 754-2008 (5.3.1), minNum(0.0, -0.0) and similar
// may return either -0.0 or 0.0, so the fcmp/select pair has stricter
// semantics than minNum. Be conservative in such a case.
8894 if (NaNBehavior != SPNB_RETURNS_ANY ||
8895 (!FMF.noSignedZeros() && !isKnownNonZero(V: CmpLHS) &&
8896 !isKnownNonZero(V: CmpRHS)))
8897 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
8898
8899 return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
8900}
8901
8902static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
8903 Instruction::CastOps *CastOp) {
8904 const DataLayout &DL = CmpI->getDataLayout();
8905
8906 Constant *CastedTo = nullptr;
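// Apply the inverse cast to the constant. Whether that is legal is verified
// below by casting the result back and checking that no information was lost.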
8907 switch (*CastOp) {
8908 case Instruction::ZExt:
8909 if (CmpI->isUnsigned())
8910 CastedTo = ConstantExpr::getTrunc(C, Ty: SrcTy);
8911 break;
8912 case Instruction::SExt:
8913 if (CmpI->isSigned())
8914 CastedTo = ConstantExpr::getTrunc(C, Ty: SrcTy, OnlyIfReduced: true);
8915 break;
8916 case Instruction::Trunc:
8917 Constant *CmpConst;
8918 if (match(V: CmpI->getOperand(i_nocapture: 1), P: m_Constant(C&: CmpConst)) &&
8919 CmpConst->getType() == SrcTy) {
8920 // Here we have the following case:
8921 //
8922 // %cond = cmp iN %x, CmpConst
8923 // %tr = trunc iN %x to iK
// %narrowsel = select i1 %cond, iK %tr, iK C
8925 //
8926 // We can always move trunc after select operation:
8927 //
8928 // %cond = cmp iN %x, CmpConst
8929 // %widesel = select i1 %cond, iN %x, iN CmpConst
8930 // %tr = trunc iN %widesel to iK
8931 //
// Note that C could be extended in any way because we don't care about
// the upper bits after truncation. This can't be an abs pattern, because
// that would look like:
//
// select i1 %cond, x, -x.
//
// So only a min/max pattern can be matched. Such a match requires the
// widened C to equal CmpConst, which is why we set the widened C to
// CmpConst; the condition trunc(CmpConst) == C is checked below.
8941 CastedTo = CmpConst;
8942 } else {
8943 unsigned ExtOp = CmpI->isSigned() ? Instruction::SExt : Instruction::ZExt;
8944 CastedTo = ConstantFoldCastOperand(Opcode: ExtOp, C, DestTy: SrcTy, DL);
8945 }
8946 break;
8947 case Instruction::FPTrunc:
8948 CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPExt, C, DestTy: SrcTy, DL);
8949 break;
8950 case Instruction::FPExt:
8951 CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPTrunc, C, DestTy: SrcTy, DL);
8952 break;
8953 case Instruction::FPToUI:
8954 CastedTo = ConstantFoldCastOperand(Opcode: Instruction::UIToFP, C, DestTy: SrcTy, DL);
8955 break;
8956 case Instruction::FPToSI:
8957 CastedTo = ConstantFoldCastOperand(Opcode: Instruction::SIToFP, C, DestTy: SrcTy, DL);
8958 break;
8959 case Instruction::UIToFP:
8960 CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPToUI, C, DestTy: SrcTy, DL);
8961 break;
8962 case Instruction::SIToFP:
8963 CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPToSI, C, DestTy: SrcTy, DL);
8964 break;
8965 default:
8966 break;
8967 }
8968
8969 if (!CastedTo)
8970 return nullptr;
8971
8972 // Make sure the cast doesn't lose any information.
8973 Constant *CastedBack =
8974 ConstantFoldCastOperand(Opcode: *CastOp, C: CastedTo, DestTy: C->getType(), DL);
8975 if (CastedBack && CastedBack != C)
8976 return nullptr;
8977
8978 return CastedTo;
8979}
8980
/// Helps to match a select pattern in case of a type mismatch.
///
/// The function handles the case when the types of the true and false values
/// of a select instruction differ from the type of the cmp instruction
/// operands because of a cast instruction. It checks whether it is legal to
/// move the cast after the "select". If so, it returns the new second value of
/// the "select" (under the assumption that the cast has been moved):
/// 1. As the operand of the cast instruction when both values of the "select"
///    are the same cast instruction.
/// 2. As the restored constant (by applying the reverse cast operation) when
///    the first value of the "select" is a cast operation and the second value
///    is a constant. This is implemented in lookThroughCastConst().
/// 3. As the compare operand when one value of the "select" is a cast
///    instruction and the other is not; the select and compare operands are
///    then integers of different widths.
/// NOTE: We return only the new second value because the first value can be
/// accessed as the operand of the cast instruction.
8997static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
8998 Instruction::CastOps *CastOp) {
8999 auto *Cast1 = dyn_cast<CastInst>(Val: V1);
9000 if (!Cast1)
9001 return nullptr;
9002
9003 *CastOp = Cast1->getOpcode();
9004 Type *SrcTy = Cast1->getSrcTy();
9005 if (auto *Cast2 = dyn_cast<CastInst>(Val: V2)) {
9006 // If V1 and V2 are both the same cast from the same type, look through V1.
9007 if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
9008 return Cast2->getOperand(i_nocapture: 0);
9009 return nullptr;
9010 }
9011
9012 auto *C = dyn_cast<Constant>(Val: V2);
9013 if (C)
9014 return lookThroughCastConst(CmpI, SrcTy, C, CastOp);
9015
9016 Value *CastedTo = nullptr;
9017 if (*CastOp == Instruction::Trunc) {
9018 if (match(V: CmpI->getOperand(i_nocapture: 1), P: m_ZExtOrSExt(Op: m_Specific(V: V2)))) {
9019 // Here we have the following case:
9020 // %y_ext = sext iK %y to iN
9021 // %cond = cmp iN %x, %y_ext
9022 // %tr = trunc iN %x to iK
9023 // %narrowsel = select i1 %cond, iK %tr, iK %y
9024 //
9025 // We can always move trunc after select operation:
9026 // %y_ext = sext iK %y to iN
9027 // %cond = cmp iN %x, %y_ext
9028 // %widesel = select i1 %cond, iN %x, iN %y_ext
9029 // %tr = trunc iN %widesel to iK
9030 assert(V2->getType() == Cast1->getType() &&
9031 "V2 and Cast1 should be the same type.");
9032 CastedTo = CmpI->getOperand(i_nocapture: 1);
9033 }
9034 }
9035
9036 return CastedTo;
9037}

SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
9039 Instruction::CastOps *CastOp,
9040 unsigned Depth) {
9041 if (Depth >= MaxAnalysisRecursionDepth)
9042 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
9043
9044 SelectInst *SI = dyn_cast<SelectInst>(Val: V);
9045 if (!SI) return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
9046
9047 CmpInst *CmpI = dyn_cast<CmpInst>(Val: SI->getCondition());
9048 if (!CmpI) return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
9049
9050 Value *TrueVal = SI->getTrueValue();
9051 Value *FalseVal = SI->getFalseValue();
9052
9053 return llvm::matchDecomposedSelectPattern(
9054 CmpI, TrueVal, FalseVal, LHS, RHS,
9055 FMF: isa<FPMathOperator>(Val: SI) ? SI->getFastMathFlags() : FastMathFlags(),
9056 CastOp, Depth);
9057}
9058
9059SelectPatternResult llvm::matchDecomposedSelectPattern(
9060 CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
9061 FastMathFlags FMF, Instruction::CastOps *CastOp, unsigned Depth) {
9062 CmpInst::Predicate Pred = CmpI->getPredicate();
9063 Value *CmpLHS = CmpI->getOperand(i_nocapture: 0);
9064 Value *CmpRHS = CmpI->getOperand(i_nocapture: 1);
9065 if (isa<FPMathOperator>(Val: CmpI) && CmpI->hasNoNaNs())
9066 FMF.setNoNaNs();
9067
9068 // Bail out early.
9069 if (CmpI->isEquality())
9070 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
9071
9072 // Deal with type mismatches.
9073 if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
9074 if (Value *C = lookThroughCast(CmpI, V1: TrueVal, V2: FalseVal, CastOp)) {
9075 // If this is a potential fmin/fmax with a cast to integer, then ignore
9076 // -0.0 because there is no corresponding integer value.
9077 if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
9078 FMF.setNoSignedZeros();
9079 return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
9080 TrueVal: cast<CastInst>(Val: TrueVal)->getOperand(i_nocapture: 0), FalseVal: C,
9081 LHS, RHS, Depth);
9082 }
9083 if (Value *C = lookThroughCast(CmpI, V1: FalseVal, V2: TrueVal, CastOp)) {
9084 // If this is a potential fmin/fmax with a cast to integer, then ignore
9085 // -0.0 because there is no corresponding integer value.
9086 if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
9087 FMF.setNoSignedZeros();
9088 return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
9089 TrueVal: C, FalseVal: cast<CastInst>(Val: FalseVal)->getOperand(i_nocapture: 0),
9090 LHS, RHS, Depth);
9091 }
9092 }
9093 return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal,
9094 LHS, RHS, Depth);
9095}
9096
9097CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) {
9098 if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT;
9099 if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT;
9100 if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT;
9101 if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT;
9102 if (SPF == SPF_FMINNUM)
9103 return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
9104 if (SPF == SPF_FMAXNUM)
9105 return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
9106 llvm_unreachable("unhandled!");
9107}
9108
9109Intrinsic::ID llvm::getMinMaxIntrinsic(SelectPatternFlavor SPF) {
9110 switch (SPF) {
9111 case SelectPatternFlavor::SPF_UMIN:
9112 return Intrinsic::umin;
9113 case SelectPatternFlavor::SPF_UMAX:
9114 return Intrinsic::umax;
9115 case SelectPatternFlavor::SPF_SMIN:
9116 return Intrinsic::smin;
9117 case SelectPatternFlavor::SPF_SMAX:
9118 return Intrinsic::smax;
9119 default:
9120 llvm_unreachable("Unexpected SPF");
9121 }
9122}
9123
9124SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) {
9125 if (SPF == SPF_SMIN) return SPF_SMAX;
9126 if (SPF == SPF_UMIN) return SPF_UMAX;
9127 if (SPF == SPF_SMAX) return SPF_SMIN;
9128 if (SPF == SPF_UMAX) return SPF_UMIN;
9129 llvm_unreachable("unhandled!");
9130}
9131
9132Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) {
9133 switch (MinMaxID) {
9134 case Intrinsic::smax: return Intrinsic::smin;
9135 case Intrinsic::smin: return Intrinsic::smax;
9136 case Intrinsic::umax: return Intrinsic::umin;
9137 case Intrinsic::umin: return Intrinsic::umax;
// Note that the following floating-point intrinsics may produce the same result
// for the original and the inverted case even if X != Y, because NaN is handled specially.
9140 case Intrinsic::maximum: return Intrinsic::minimum;
9141 case Intrinsic::minimum: return Intrinsic::maximum;
9142 case Intrinsic::maxnum: return Intrinsic::minnum;
9143 case Intrinsic::minnum: return Intrinsic::maxnum;
9144 case Intrinsic::maximumnum:
9145 return Intrinsic::minimumnum;
9146 case Intrinsic::minimumnum:
9147 return Intrinsic::maximumnum;
9148 default: llvm_unreachable("Unexpected intrinsic");
9149 }
9150}
9151
9152APInt llvm::getMinMaxLimit(SelectPatternFlavor SPF, unsigned BitWidth) {
9153 switch (SPF) {
9154 case SPF_SMAX: return APInt::getSignedMaxValue(numBits: BitWidth);
9155 case SPF_SMIN: return APInt::getSignedMinValue(numBits: BitWidth);
9156 case SPF_UMAX: return APInt::getMaxValue(numBits: BitWidth);
9157 case SPF_UMIN: return APInt::getMinValue(numBits: BitWidth);
9158 default: llvm_unreachable("Unexpected flavor");
9159 }
9160}
9161
9162std::pair<Intrinsic::ID, bool>
9163llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
9164 // Check if VL contains select instructions that can be folded into a min/max
9165 // vector intrinsic and return the intrinsic if it is possible.
9166 // TODO: Support floating point min/max.
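// Also record whether every matched select's compare (the select condition)
// has a single use; the caller receives this flag alongside the intrinsic ID.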
9167 bool AllCmpSingleUse = true;
9168 SelectPatternResult SelectPattern;
9169 SelectPattern.Flavor = SPF_UNKNOWN;
9170 if (all_of(Range&: VL, P: [&SelectPattern, &AllCmpSingleUse](Value *I) {
9171 Value *LHS, *RHS;
9172 auto CurrentPattern = matchSelectPattern(V: I, LHS, RHS);
9173 if (!SelectPatternResult::isMinOrMax(SPF: CurrentPattern.Flavor))
9174 return false;
9175 if (SelectPattern.Flavor != SPF_UNKNOWN &&
9176 SelectPattern.Flavor != CurrentPattern.Flavor)
9177 return false;
9178 SelectPattern = CurrentPattern;
9179 AllCmpSingleUse &=
9180 match(V: I, P: m_Select(C: m_OneUse(SubPattern: m_Value()), L: m_Value(), R: m_Value()));
9181 return true;
9182 })) {
9183 switch (SelectPattern.Flavor) {
9184 case SPF_SMIN:
9185 return {Intrinsic::smin, AllCmpSingleUse};
9186 case SPF_UMIN:
9187 return {Intrinsic::umin, AllCmpSingleUse};
9188 case SPF_SMAX:
9189 return {Intrinsic::smax, AllCmpSingleUse};
9190 case SPF_UMAX:
9191 return {Intrinsic::umax, AllCmpSingleUse};
9192 case SPF_FMAXNUM:
9193 return {Intrinsic::maxnum, AllCmpSingleUse};
9194 case SPF_FMINNUM:
9195 return {Intrinsic::minnum, AllCmpSingleUse};
9196 default:
9197 llvm_unreachable("unexpected select pattern flavor");
9198 }
9199 }
9200 return {Intrinsic::not_intrinsic, false};
9201}
9202
9203template <typename InstTy>
9204static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst,
9205 Value *&Init, Value *&OtherOp) {
9206 // Handle the case of a simple two-predecessor recurrence PHI.
9207 // There's a lot more that could theoretically be done here, but
9208 // this is sufficient to catch some interesting cases.
9209 // TODO: Expand list -- gep, uadd.sat etc.
9210 if (PN->getNumIncomingValues() != 2)
9211 return false;
9212
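// Check both incoming values: one must be an InstTy that uses the PHI as one
// of its first two operands (the recurrence step); the other incoming value is
// then the start/init value.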
9213 for (unsigned I = 0; I != 2; ++I) {
9214 if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(i: I));
9215 Operation && Operation->getNumOperands() >= 2) {
9216 Value *LHS = Operation->getOperand(0);
9217 Value *RHS = Operation->getOperand(1);
9218 if (LHS != PN && RHS != PN)
9219 continue;
9220
9221 Inst = Operation;
9222 Init = PN->getIncomingValue(i: !I);
9223 OtherOp = (LHS == PN) ? RHS : LHS;
9224 return true;
9225 }
9226 }
9227 return false;
9228}
9229
9230bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
9231 Value *&Start, Value *&Step) {
9232 // We try to match a recurrence of the form:
9233 // %iv = [Start, %entry], [%iv.next, %backedge]
9234 // %iv.next = binop %iv, Step
9235 // Or:
9236 // %iv = [Start, %entry], [%iv.next, %backedge]
9237 // %iv.next = binop Step, %iv
9238 return matchTwoInputRecurrence(PN: P, Inst&: BO, Init&: Start, OtherOp&: Step);
9239}
9240
9241bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
9242 Value *&Start, Value *&Step) {
9243 BinaryOperator *BO = nullptr;
9244 P = dyn_cast<PHINode>(Val: I->getOperand(i_nocapture: 0));
9245 if (!P)
9246 P = dyn_cast<PHINode>(Val: I->getOperand(i_nocapture: 1));
9247 return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I;
9248}
9249
9250bool llvm::matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
9251 PHINode *&P, Value *&Init,
9252 Value *&OtherOp) {
9253 // Binary intrinsics only supported for now.
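// The matched shape mirrors matchSimpleRecurrence (the intrinsic name below is
// only illustrative):
// %rec = phi [Init, %entry], [%rec.next, %backedge]
// %rec.next = call @llvm.some.intrinsic(%rec, OtherOp) ; or (OtherOp, %rec)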
9254 if (I->arg_size() != 2 || I->getType() != I->getArgOperand(i: 0)->getType() ||
9255 I->getType() != I->getArgOperand(i: 1)->getType())
9256 return false;
9257
9258 IntrinsicInst *II = nullptr;
9259 P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 0));
9260 if (!P)
9261 P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 1));
9262
9263 return P && matchTwoInputRecurrence(PN: P, Inst&: II, Init, OtherOp) && II == I;
9264}
9265
9266/// Return true if "icmp Pred LHS RHS" is always true.
9267static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
9268 const Value *RHS) {
9269 if (ICmpInst::isTrueWhenEqual(predicate: Pred) && LHS == RHS)
9270 return true;
9271
9272 switch (Pred) {
9273 default:
9274 return false;
9275
9276 case CmpInst::ICMP_SLE: {
9277 const APInt *C;
9278
9279 // LHS s<= LHS +_{nsw} C if C >= 0
9280 // LHS s<= LHS | C if C >= 0
9281 if (match(V: RHS, P: m_NSWAdd(L: m_Specific(V: LHS), R: m_APInt(Res&: C))) ||
9282 match(V: RHS, P: m_Or(L: m_Specific(V: LHS), R: m_APInt(Res&: C))))
9283 return !C->isNegative();
9284
9285 // LHS s<= smax(LHS, V) for any V
9286 if (match(V: RHS, P: m_c_SMax(L: m_Specific(V: LHS), R: m_Value())))
9287 return true;
9288
9289 // smin(RHS, V) s<= RHS for any V
9290 if (match(V: LHS, P: m_c_SMin(L: m_Specific(V: RHS), R: m_Value())))
9291 return true;
9292
9293 // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB)
9294 const Value *X;
9295 const APInt *CLHS, *CRHS;
9296 if (match(V: LHS, P: m_NSWAddLike(L: m_Value(V&: X), R: m_APInt(Res&: CLHS))) &&
9297 match(V: RHS, P: m_NSWAddLike(L: m_Specific(V: X), R: m_APInt(Res&: CRHS))))
9298 return CLHS->sle(RHS: *CRHS);
9299
9300 return false;
9301 }
9302
9303 case CmpInst::ICMP_ULE: {
9304 // LHS u<= LHS +_{nuw} V for any V
9305 if (match(V: RHS, P: m_c_Add(L: m_Specific(V: LHS), R: m_Value())) &&
9306 cast<OverflowingBinaryOperator>(Val: RHS)->hasNoUnsignedWrap())
9307 return true;
9308
9309 // LHS u<= LHS | V for any V
9310 if (match(V: RHS, P: m_c_Or(L: m_Specific(V: LHS), R: m_Value())))
9311 return true;
9312
9313 // LHS u<= umax(LHS, V) for any V
9314 if (match(V: RHS, P: m_c_UMax(L: m_Specific(V: LHS), R: m_Value())))
9315 return true;
9316
9317 // RHS >> V u<= RHS for any V
9318 if (match(V: LHS, P: m_LShr(L: m_Specific(V: RHS), R: m_Value())))
9319 return true;
9320
9321 // RHS u/ C_ugt_1 u<= RHS
9322 const APInt *C;
9323 if (match(V: LHS, P: m_UDiv(L: m_Specific(V: RHS), R: m_APInt(Res&: C))) && C->ugt(RHS: 1))
9324 return true;
9325
9326 // RHS & V u<= RHS for any V
9327 if (match(V: LHS, P: m_c_And(L: m_Specific(V: RHS), R: m_Value())))
9328 return true;
9329
9330 // umin(RHS, V) u<= RHS for any V
9331 if (match(V: LHS, P: m_c_UMin(L: m_Specific(V: RHS), R: m_Value())))
9332 return true;
9333
9334 // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
9335 const Value *X;
9336 const APInt *CLHS, *CRHS;
9337 if (match(V: LHS, P: m_NUWAddLike(L: m_Value(V&: X), R: m_APInt(Res&: CLHS))) &&
9338 match(V: RHS, P: m_NUWAddLike(L: m_Specific(V: X), R: m_APInt(Res&: CRHS))))
9339 return CLHS->ule(RHS: *CRHS);
9340
9341 return false;
9342 }
9343 }
9344}
9345
9346/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
9347/// ALHS ARHS" is true. Otherwise, return std::nullopt.
9348static std::optional<bool>
9349isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
9350 const Value *ARHS, const Value *BLHS, const Value *BRHS) {
9351 switch (Pred) {
9352 default:
9353 return std::nullopt;
9354
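// In each case below we check that the BLHS/BRHS pair is "wider" than the
// ALHS/ARHS pair in the direction of the predicate, so that the known relation
// between ALHS and ARHS carries over to BLHS and BRHS.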
9355 case CmpInst::ICMP_SLT:
9356 case CmpInst::ICMP_SLE:
9357 if (isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: BLHS, RHS: ALHS) &&
9358 isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: ARHS, RHS: BRHS))
9359 return true;
9360 return std::nullopt;
9361
9362 case CmpInst::ICMP_SGT:
9363 case CmpInst::ICMP_SGE:
9364 if (isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: ALHS, RHS: BLHS) &&
9365 isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: BRHS, RHS: ARHS))
9366 return true;
9367 return std::nullopt;
9368
9369 case CmpInst::ICMP_ULT:
9370 case CmpInst::ICMP_ULE:
9371 if (isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: BLHS, RHS: ALHS) &&
9372 isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: ARHS, RHS: BRHS))
9373 return true;
9374 return std::nullopt;
9375
9376 case CmpInst::ICMP_UGT:
9377 case CmpInst::ICMP_UGE:
9378 if (isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: ALHS, RHS: BLHS) &&
9379 isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: BRHS, RHS: ARHS))
9380 return true;
9381 return std::nullopt;
9382 }
9383}
9384
9385/// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
9386/// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
9387/// Otherwise, return std::nullopt if we can't infer anything.
9388static std::optional<bool>
9389isImpliedCondCommonOperandWithCR(CmpPredicate LPred, const ConstantRange &LCR,
9390 CmpPredicate RPred, const ConstantRange &RCR) {
9391 auto CRImpliesPred = [&](ConstantRange CR,
9392 CmpInst::Predicate Pred) -> std::optional<bool> {
// If every value that satisfies lhs also satisfies rhs, lhs implies rhs.
9394 if (CR.icmp(Pred, Other: RCR))
9395 return true;
9396
9397 // If there is no overlap, lhs implies not rhs
9398 if (CR.icmp(Pred: CmpInst::getInversePredicate(pred: Pred), Other: RCR))
9399 return false;
9400
9401 return std::nullopt;
9402 };
9403 if (auto Res = CRImpliesPred(ConstantRange::makeAllowedICmpRegion(Pred: LPred, Other: LCR),
9404 RPred))
9405 return Res;
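// If exactly one of the two predicates carries the samesign flag, normalize:
// flip the signedness of the predicate that has the flag, use the other one
// as a plain predicate, and retry.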
9406 if (LPred.hasSameSign() ^ RPred.hasSameSign()) {
9407 LPred = LPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(Pred: LPred)
9408 : LPred.dropSameSign();
9409 RPred = RPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(Pred: RPred)
9410 : RPred.dropSameSign();
9411 return CRImpliesPred(ConstantRange::makeAllowedICmpRegion(Pred: LPred, Other: LCR),
9412 RPred);
9413 }
9414 return std::nullopt;
9415}
9416
9417/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
9418/// is true. Return false if LHS implies RHS is false. Otherwise, return
9419/// std::nullopt if we can't infer anything.
9420static std::optional<bool>
9421isImpliedCondICmps(CmpPredicate LPred, const Value *L0, const Value *L1,
9422 CmpPredicate RPred, const Value *R0, const Value *R1,
9423 const DataLayout &DL, bool LHSIsTrue) {
9424 // The rest of the logic assumes the LHS condition is true. If that's not the
9425 // case, invert the predicate to make it so.
9426 if (!LHSIsTrue)
9427 LPred = ICmpInst::getInverseCmpPredicate(Pred: LPred);
9428
9429 // We can have non-canonical operands, so try to normalize any common operand
9430 // to L0/R0.
9431 if (L0 == R1) {
9432 std::swap(a&: R0, b&: R1);
9433 RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
9434 }
9435 if (R0 == L1) {
9436 std::swap(a&: L0, b&: L1);
9437 LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
9438 }
9439 if (L1 == R1) {
9440 // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
9441 if (L0 != R0 || match(V: L0, P: m_ImmConstant())) {
9442 std::swap(a&: L0, b&: L1);
9443 LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
9444 std::swap(a&: R0, b&: R1);
9445 RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
9446 }
9447 }
9448
9449 // See if we can infer anything if operand-0 matches and we have at least one
9450 // constant.
9451 const APInt *Unused;
9452 if (L0 == R0 && (match(V: L1, P: m_APInt(Res&: Unused)) || match(V: R1, P: m_APInt(Res&: Unused)))) {
// Potential TODO: We could also use the constant range of L0/R0 to further
// constrain the constant ranges. At the moment this leads to several
// regressions related to not transforming `multi_use(A + C0) eq/ne C1` (see
// discussion: D58633).
9457 ConstantRange LCR = computeConstantRange(
9458 V: L1, ForSigned: ICmpInst::isSigned(predicate: LPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
9459 /*CxtI=*/CtxI: nullptr, /*DT=*/nullptr, Depth: MaxAnalysisRecursionDepth - 1);
9460 ConstantRange RCR = computeConstantRange(
9461 V: R1, ForSigned: ICmpInst::isSigned(predicate: RPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
9462 /*CxtI=*/CtxI: nullptr, /*DT=*/nullptr, Depth: MaxAnalysisRecursionDepth - 1);
9463 // Even if L1/R1 are not both constant, we can still sometimes deduce
9464 // relationship from a single constant. For example X u> Y implies X != 0.
9465 if (auto R = isImpliedCondCommonOperandWithCR(LPred, LCR, RPred, RCR))
9466 return R;
9467 // If both L1/R1 were exact constant ranges and we didn't get anything
9468 // here, we won't be able to deduce this.
9469 if (match(V: L1, P: m_APInt(Res&: Unused)) && match(V: R1, P: m_APInt(Res&: Unused)))
9470 return std::nullopt;
9471 }
9472
9473 // Can we infer anything when the two compares have matching operands?
9474 if (L0 == R0 && L1 == R1)
9475 return ICmpInst::isImpliedByMatchingCmp(Pred1: LPred, Pred2: RPred);
9476
// The following only really makes sense for signed comparisons: "X - Y must be
// non-negative if X >= Y and there is no overflow".
9479 // Take SGT as an example: L0:x > L1:y and C >= 0
9480 // ==> R0:(x -nsw y) < R1:(-C) is false
9481 CmpInst::Predicate SignedLPred = LPred.getPreferredSignedPredicate();
9482 if ((SignedLPred == ICmpInst::ICMP_SGT ||
9483 SignedLPred == ICmpInst::ICMP_SGE) &&
9484 match(V: R0, P: m_NSWSub(L: m_Specific(V: L0), R: m_Specific(V: L1)))) {
9485 if (match(V: R1, P: m_NonPositive()) &&
9486 ICmpInst::isImpliedByMatchingCmp(Pred1: SignedLPred, Pred2: RPred) == false)
9487 return false;
9488 }
9489
9490 // Take SLT as an example: L0:x < L1:y and C <= 0
9491 // ==> R0:(x -nsw y) < R1:(-C) is true
9492 if ((SignedLPred == ICmpInst::ICMP_SLT ||
9493 SignedLPred == ICmpInst::ICMP_SLE) &&
9494 match(V: R0, P: m_NSWSub(L: m_Specific(V: L0), R: m_Specific(V: L1)))) {
9495 if (match(V: R1, P: m_NonNegative()) &&
9496 ICmpInst::isImpliedByMatchingCmp(Pred1: SignedLPred, Pred2: RPred) == true)
9497 return true;
9498 }
9499
9500 // a - b == NonZero -> a != b
9501 // ptrtoint(a) - ptrtoint(b) == NonZero -> a != b
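// (If a - b equals a non-zero constant, a and b cannot be equal; the same
// holds when a and b are ptrtoint/ptrtoaddr casts of R0/R1.)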
9502 const APInt *L1C;
9503 Value *A, *B;
9504 if (LPred == ICmpInst::ICMP_EQ && ICmpInst::isEquality(P: RPred) &&
9505 match(V: L1, P: m_APInt(Res&: L1C)) && !L1C->isZero() &&
9506 match(V: L0, P: m_Sub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
9507 ((A == R0 && B == R1) || (A == R1 && B == R0) ||
9508 (match(V: A, P: m_PtrToIntOrAddr(Op: m_Specific(V: R0))) &&
9509 match(V: B, P: m_PtrToIntOrAddr(Op: m_Specific(V: R1)))) ||
9510 (match(V: A, P: m_PtrToIntOrAddr(Op: m_Specific(V: R1))) &&
9511 match(V: B, P: m_PtrToIntOrAddr(Op: m_Specific(V: R0)))))) {
9512 return RPred.dropSameSign() == ICmpInst::ICMP_NE;
9513 }
9514
9515 // L0 = R0 = L1 + R1, L0 >=u L1 implies R0 >=u R1, L0 <u L1 implies R0 <u R1
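// (L0 >=u L1 holds exactly when the unsigned add L1 + R1 does not wrap, and
// that property is symmetric in the two addends, so R0 >=u R1 as well; the <u
// case is the complement.)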
9516 if (L0 == R0 &&
9517 (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
9518 (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
9519 match(V: L0, P: m_c_Add(L: m_Specific(V: L1), R: m_Specific(V: R1))))
9520 return CmpPredicate::getMatching(A: LPred, B: RPred).has_value();
9521
9522 if (auto P = CmpPredicate::getMatching(A: LPred, B: RPred))
9523 return isImpliedCondOperands(Pred: *P, ALHS: L0, ARHS: L1, BLHS: R0, BRHS: R1);
9524
9525 return std::nullopt;
9526}
9527
9528/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
9529/// is true. Return false if LHS implies RHS is false. Otherwise, return
9530/// std::nullopt if we can't infer anything.
9531static std::optional<bool>
9532isImpliedCondFCmps(FCmpInst::Predicate LPred, const Value *L0, const Value *L1,
9533 FCmpInst::Predicate RPred, const Value *R0, const Value *R1,
9534 const DataLayout &DL, bool LHSIsTrue) {
9535 // The rest of the logic assumes the LHS condition is true. If that's not the
9536 // case, invert the predicate to make it so.
9537 if (!LHSIsTrue)
9538 LPred = FCmpInst::getInversePredicate(pred: LPred);
9539
9540 // We can have non-canonical operands, so try to normalize any common operand
9541 // to L0/R0.
9542 if (L0 == R1) {
9543 std::swap(a&: R0, b&: R1);
9544 RPred = FCmpInst::getSwappedPredicate(pred: RPred);
9545 }
9546 if (R0 == L1) {
9547 std::swap(a&: L0, b&: L1);
9548 LPred = FCmpInst::getSwappedPredicate(pred: LPred);
9549 }
9550 if (L1 == R1) {
9551 // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
9552 if (L0 != R0 || match(V: L0, P: m_ImmConstant())) {
9553 std::swap(a&: L0, b&: L1);
9554 LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
9555 std::swap(a&: R0, b&: R1);
9556 RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
9557 }
9558 }
9559
9560 // Can we infer anything when the two compares have matching operands?
9561 if (L0 == R0 && L1 == R1) {
9562 if ((LPred & RPred) == LPred)
9563 return true;
9564 if ((LPred & ~RPred) == LPred)
9565 return false;
9566 }
9567
9568 // See if we can infer anything if operand-0 matches and we have at least one
9569 // constant.
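// If both constants admit exact fcmp regions, LHS implies RHS when the region
// satisfying LHS is contained in the region satisfying RHS, and implies !RHS
// when it is contained in the region satisfying the inverse of RHS.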
9570 const APFloat *L1C, *R1C;
9571 if (L0 == R0 && match(V: L1, P: m_APFloat(Res&: L1C)) && match(V: R1, P: m_APFloat(Res&: R1C))) {
9572 if (std::optional<ConstantFPRange> DomCR =
9573 ConstantFPRange::makeExactFCmpRegion(Pred: LPred, Other: *L1C)) {
9574 if (std::optional<ConstantFPRange> ImpliedCR =
9575 ConstantFPRange::makeExactFCmpRegion(Pred: RPred, Other: *R1C)) {
9576 if (ImpliedCR->contains(CR: *DomCR))
9577 return true;
9578 }
9579 if (std::optional<ConstantFPRange> ImpliedCR =
9580 ConstantFPRange::makeExactFCmpRegion(
9581 Pred: FCmpInst::getInversePredicate(pred: RPred), Other: *R1C)) {
9582 if (ImpliedCR->contains(CR: *DomCR))
9583 return false;
9584 }
9585 }
9586 }
9587
9588 return std::nullopt;
9589}
9590
9591/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is
9592/// false. Otherwise, return std::nullopt if we can't infer anything. We
9593/// expect the RHS to be an icmp and the LHS to be an 'and', 'or', or a 'select'
9594/// instruction.
9595static std::optional<bool>
9596isImpliedCondAndOr(const Instruction *LHS, CmpPredicate RHSPred,
9597 const Value *RHSOp0, const Value *RHSOp1,
9598 const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
9599 // The LHS must be an 'or', 'and', or a 'select' instruction.
9600 assert((LHS->getOpcode() == Instruction::And ||
9601 LHS->getOpcode() == Instruction::Or ||
9602 LHS->getOpcode() == Instruction::Select) &&
9603 "Expected LHS to be 'and', 'or', or 'select'.");
9604
9605 assert(Depth <= MaxAnalysisRecursionDepth && "Hit recursion limit");
9606
9607 // If the result of an 'or' is false, then we know both legs of the 'or' are
9608 // false. Similarly, if the result of an 'and' is true, then we know both
9609 // legs of the 'and' are true.
9610 const Value *ALHS, *ARHS;
9611 if ((!LHSIsTrue && match(V: LHS, P: m_LogicalOr(L: m_Value(V&: ALHS), R: m_Value(V&: ARHS)))) ||
9612 (LHSIsTrue && match(V: LHS, P: m_LogicalAnd(L: m_Value(V&: ALHS), R: m_Value(V&: ARHS))))) {
9613 // FIXME: Make this non-recursion.
9614 if (std::optional<bool> Implication = isImpliedCondition(
9615 LHS: ALHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth: Depth + 1))
9616 return Implication;
9617 if (std::optional<bool> Implication = isImpliedCondition(
9618 LHS: ARHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth: Depth + 1))
9619 return Implication;
9620 return std::nullopt;
9621 }
9622 return std::nullopt;
9623}
9624
9625std::optional<bool>
9626llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
9627 const Value *RHSOp0, const Value *RHSOp1,
9628 const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
9629 // Bail out when we hit the limit.
9630 if (Depth == MaxAnalysisRecursionDepth)
9631 return std::nullopt;
9632
9633 // A mismatch occurs when we compare a scalar cmp to a vector cmp, for
9634 // example.
9635 if (RHSOp0->getType()->isVectorTy() != LHS->getType()->isVectorTy())
9636 return std::nullopt;
9637
9638 assert(LHS->getType()->isIntOrIntVectorTy(1) &&
9639 "Expected integer type only!");
9640
9641 // Match not
9642 if (match(V: LHS, P: m_Not(V: m_Value(V&: LHS))))
9643 LHSIsTrue = !LHSIsTrue;
9644
9645 // Both LHS and RHS are icmps.
9646 if (RHSOp0->getType()->getScalarType()->isIntOrPtrTy()) {
9647 if (const auto *LHSCmp = dyn_cast<ICmpInst>(Val: LHS))
9648 return isImpliedCondICmps(LPred: LHSCmp->getCmpPredicate(),
9649 L0: LHSCmp->getOperand(i_nocapture: 0), L1: LHSCmp->getOperand(i_nocapture: 1),
9650 RPred: RHSPred, R0: RHSOp0, R1: RHSOp1, DL, LHSIsTrue);
9651 const Value *V;
9652 if (match(V: LHS, P: m_NUWTrunc(Op: m_Value(V))))
9653 return isImpliedCondICmps(LPred: CmpInst::ICMP_NE, L0: V,
9654 L1: ConstantInt::get(Ty: V->getType(), V: 0), RPred: RHSPred,
9655 R0: RHSOp0, R1: RHSOp1, DL, LHSIsTrue);
9656 } else {
9657 assert(RHSOp0->getType()->isFPOrFPVectorTy() &&
9658 "Expected floating point type only!");
9659 if (const auto *LHSCmp = dyn_cast<FCmpInst>(Val: LHS))
9660 return isImpliedCondFCmps(LPred: LHSCmp->getPredicate(), L0: LHSCmp->getOperand(i_nocapture: 0),
9661 L1: LHSCmp->getOperand(i_nocapture: 1), RPred: RHSPred, R0: RHSOp0, R1: RHSOp1,
9662 DL, LHSIsTrue);
9663 }
9664
9665 /// The LHS should be an 'or', 'and', or a 'select' instruction. We expect
9666 /// the RHS to be an icmp.
9667 /// FIXME: Add support for and/or/select on the RHS.
9668 if (const Instruction *LHSI = dyn_cast<Instruction>(Val: LHS)) {
9669 if ((LHSI->getOpcode() == Instruction::And ||
9670 LHSI->getOpcode() == Instruction::Or ||
9671 LHSI->getOpcode() == Instruction::Select))
9672 return isImpliedCondAndOr(LHS: LHSI, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue,
9673 Depth);
9674 }
9675 return std::nullopt;
9676}
9677
9678std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
9679 const DataLayout &DL,
9680 bool LHSIsTrue, unsigned Depth) {
9681 // LHS ==> RHS by definition
9682 if (LHS == RHS)
9683 return LHSIsTrue;
9684
9685 // Match not
9686 bool InvertRHS = false;
9687 if (match(V: RHS, P: m_Not(V: m_Value(V&: RHS)))) {
9688 if (LHS == RHS)
9689 return !LHSIsTrue;
9690 InvertRHS = true;
9691 }
9692
9693 if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(Val: RHS)) {
9694 if (auto Implied = isImpliedCondition(
9695 LHS, RHSPred: RHSCmp->getCmpPredicate(), RHSOp0: RHSCmp->getOperand(i_nocapture: 0),
9696 RHSOp1: RHSCmp->getOperand(i_nocapture: 1), DL, LHSIsTrue, Depth))
9697 return InvertRHS ? !*Implied : *Implied;
9698 return std::nullopt;
9699 }
9700 if (const FCmpInst *RHSCmp = dyn_cast<FCmpInst>(Val: RHS)) {
9701 if (auto Implied = isImpliedCondition(
9702 LHS, RHSPred: RHSCmp->getPredicate(), RHSOp0: RHSCmp->getOperand(i_nocapture: 0),
9703 RHSOp1: RHSCmp->getOperand(i_nocapture: 1), DL, LHSIsTrue, Depth))
9704 return InvertRHS ? !*Implied : *Implied;
9705 return std::nullopt;
9706 }
9707
9708 const Value *V;
9709 if (match(V: RHS, P: m_NUWTrunc(Op: m_Value(V)))) {
9710 if (auto Implied = isImpliedCondition(LHS, RHSPred: CmpInst::ICMP_NE, RHSOp0: V,
9711 RHSOp1: ConstantInt::get(Ty: V->getType(), V: 0), DL,
9712 LHSIsTrue, Depth))
9713 return InvertRHS ? !*Implied : *Implied;
9714 return std::nullopt;
9715 }
9716
9717 if (Depth == MaxAnalysisRecursionDepth)
9718 return std::nullopt;
9719
9720 // LHS ==> (RHS1 || RHS2) if LHS ==> RHS1 or LHS ==> RHS2
9721 // LHS ==> !(RHS1 && RHS2) if LHS ==> !RHS1 or LHS ==> !RHS2
9722 const Value *RHS1, *RHS2;
9723 if (match(V: RHS, P: m_LogicalOr(L: m_Value(V&: RHS1), R: m_Value(V&: RHS2)))) {
9724 if (std::optional<bool> Imp =
9725 isImpliedCondition(LHS, RHS: RHS1, DL, LHSIsTrue, Depth: Depth + 1))
9726 if (*Imp == true)
9727 return !InvertRHS;
9728 if (std::optional<bool> Imp =
9729 isImpliedCondition(LHS, RHS: RHS2, DL, LHSIsTrue, Depth: Depth + 1))
9730 if (*Imp == true)
9731 return !InvertRHS;
9732 }
9733 if (match(V: RHS, P: m_LogicalAnd(L: m_Value(V&: RHS1), R: m_Value(V&: RHS2)))) {
9734 if (std::optional<bool> Imp =
9735 isImpliedCondition(LHS, RHS: RHS1, DL, LHSIsTrue, Depth: Depth + 1))
9736 if (*Imp == false)
9737 return InvertRHS;
9738 if (std::optional<bool> Imp =
9739 isImpliedCondition(LHS, RHS: RHS2, DL, LHSIsTrue, Depth: Depth + 1))
9740 if (*Imp == false)
9741 return InvertRHS;
9742 }
9743
9744 return std::nullopt;
9745}
9746
// Returns a pair (Condition, ConditionIsTrue), where Condition is a branch
// condition dominating ContextI, or nullptr if no such condition is found.
9749static std::pair<Value *, bool>
9750getDomPredecessorCondition(const Instruction *ContextI) {
9751 if (!ContextI || !ContextI->getParent())
9752 return {nullptr, false};
9753
9754 // TODO: This is a poor/cheap way to determine dominance. Should we use a
9755 // dominator tree (eg, from a SimplifyQuery) instead?
9756 const BasicBlock *ContextBB = ContextI->getParent();
9757 const BasicBlock *PredBB = ContextBB->getSinglePredecessor();
9758 if (!PredBB)
9759 return {nullptr, false};
9760
9761 // We need a conditional branch in the predecessor.
9762 Value *PredCond;
9763 BasicBlock *TrueBB, *FalseBB;
9764 if (!match(V: PredBB->getTerminator(), P: m_Br(C: m_Value(V&: PredCond), T&: TrueBB, F&: FalseBB)))
9765 return {nullptr, false};
9766
9767 // The branch should get simplified. Don't bother simplifying this condition.
9768 if (TrueBB == FalseBB)
9769 return {nullptr, false};
9770
9771 assert((TrueBB == ContextBB || FalseBB == ContextBB) &&
9772 "Predecessor block does not point to successor?");
9773
9774 // Is this condition implied by the predecessor condition?
9775 return {PredCond, TrueBB == ContextBB};
9776}
9777
9778std::optional<bool> llvm::isImpliedByDomCondition(const Value *Cond,
9779 const Instruction *ContextI,
9780 const DataLayout &DL) {
9781 assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool");
9782 auto PredCond = getDomPredecessorCondition(ContextI);
9783 if (PredCond.first)
9784 return isImpliedCondition(LHS: PredCond.first, RHS: Cond, DL, LHSIsTrue: PredCond.second);
9785 return std::nullopt;
9786}
9787
9788std::optional<bool> llvm::isImpliedByDomCondition(CmpPredicate Pred,
9789 const Value *LHS,
9790 const Value *RHS,
9791 const Instruction *ContextI,
9792 const DataLayout &DL) {
9793 auto PredCond = getDomPredecessorCondition(ContextI);
9794 if (PredCond.first)
9795 return isImpliedCondition(LHS: PredCond.first, RHSPred: Pred, RHSOp0: LHS, RHSOp1: RHS, DL,
9796 LHSIsTrue: PredCond.second);
9797 return std::nullopt;
9798}
9799
9800static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower,
9801 APInt &Upper, const InstrInfoQuery &IIQ,
9802 bool PreferSignedRange) {
9803 unsigned Width = Lower.getBitWidth();
9804 const APInt *C;
9805 switch (BO.getOpcode()) {
9806 case Instruction::Sub:
9807 if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9808 bool HasNSW = IIQ.hasNoSignedWrap(Op: &BO);
9809 bool HasNUW = IIQ.hasNoUnsignedWrap(Op: &BO);
9810
9811 // If the caller expects a signed compare, then try to use a signed range.
9812 // Otherwise if both no-wraps are set, use the unsigned range because it
9813 // is never larger than the signed range. Example:
9814 // "sub nuw nsw i8 -2, x" is unsigned [0, 254] vs. signed [-128, 126].
9815 // "sub nuw nsw i8 2, x" is unsigned [0, 2] vs. signed [-125, 127].
9816 if (PreferSignedRange && HasNSW && HasNUW)
9817 HasNUW = false;
9818
9819 if (HasNUW) {
9820 // 'sub nuw c, x' produces [0, C].
9821 Upper = *C + 1;
9822 } else if (HasNSW) {
9823 if (C->isNegative()) {
9824 // 'sub nsw -C, x' produces [SINT_MIN, -C - SINT_MIN].
9825 Lower = APInt::getSignedMinValue(numBits: Width);
9826 Upper = *C - APInt::getSignedMaxValue(numBits: Width);
9827 } else {
// Note that sub 0, INT_MIN is not NSW; it technically is a signed wrap.
9829 // 'sub nsw C, x' produces [C - SINT_MAX, SINT_MAX].
9830 Lower = *C - APInt::getSignedMaxValue(numBits: Width);
9831 Upper = APInt::getSignedMinValue(numBits: Width);
9832 }
9833 }
9834 }
9835 break;
9836 case Instruction::Add:
9837 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && !C->isZero()) {
9838 bool HasNSW = IIQ.hasNoSignedWrap(Op: &BO);
9839 bool HasNUW = IIQ.hasNoUnsignedWrap(Op: &BO);
9840
9841 // If the caller expects a signed compare, then try to use a signed
9842 // range. Otherwise if both no-wraps are set, use the unsigned range
9843 // because it is never larger than the signed range. Example: "add nuw
9844 // nsw i8 X, -2" is unsigned [254,255] vs. signed [-128, 125].
9845 if (PreferSignedRange && HasNSW && HasNUW)
9846 HasNUW = false;
9847
9848 if (HasNUW) {
9849 // 'add nuw x, C' produces [C, UINT_MAX].
9850 Lower = *C;
9851 } else if (HasNSW) {
9852 if (C->isNegative()) {
9853 // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C].
9854 Lower = APInt::getSignedMinValue(numBits: Width);
9855 Upper = APInt::getSignedMaxValue(numBits: Width) + *C + 1;
9856 } else {
9857 // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX].
9858 Lower = APInt::getSignedMinValue(numBits: Width) + *C;
9859 Upper = APInt::getSignedMaxValue(numBits: Width) + 1;
9860 }
9861 }
9862 }
9863 break;
9864
9865 case Instruction::And:
9866 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
9867 // 'and x, C' produces [0, C].
9868 Upper = *C + 1;
// X & -X is a power of two or zero, so we can cap the value at the maximum
// power of two.
9871 if (match(V: BO.getOperand(i_nocapture: 0), P: m_Neg(V: m_Specific(V: BO.getOperand(i_nocapture: 1)))) ||
9872 match(V: BO.getOperand(i_nocapture: 1), P: m_Neg(V: m_Specific(V: BO.getOperand(i_nocapture: 0)))))
9873 Upper = APInt::getSignedMinValue(numBits: Width) + 1;
9874 break;
9875
9876 case Instruction::Or:
9877 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
9878 // 'or x, C' produces [C, UINT_MAX].
9879 Lower = *C;
9880 break;
9881
9882 case Instruction::AShr:
9883 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
9884 // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C].
9885 Lower = APInt::getSignedMinValue(numBits: Width).ashr(ShiftAmt: *C);
9886 Upper = APInt::getSignedMaxValue(numBits: Width).ashr(ShiftAmt: *C) + 1;
9887 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9888 unsigned ShiftAmount = Width - 1;
9889 if (!C->isZero() && IIQ.isExact(Op: &BO))
9890 ShiftAmount = C->countr_zero();
9891 if (C->isNegative()) {
9892 // 'ashr C, x' produces [C, C >> (Width-1)]
9893 Lower = *C;
9894 Upper = C->ashr(ShiftAmt: ShiftAmount) + 1;
9895 } else {
9896 // 'ashr C, x' produces [C >> (Width-1), C]
9897 Lower = C->ashr(ShiftAmt: ShiftAmount);
9898 Upper = *C + 1;
9899 }
9900 }
9901 break;
9902
9903 case Instruction::LShr:
9904 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
9905 // 'lshr x, C' produces [0, UINT_MAX >> C].
9906 Upper = APInt::getAllOnes(numBits: Width).lshr(ShiftAmt: *C) + 1;
9907 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9908 // 'lshr C, x' produces [C >> (Width-1), C].
9909 unsigned ShiftAmount = Width - 1;
9910 if (!C->isZero() && IIQ.isExact(Op: &BO))
9911 ShiftAmount = C->countr_zero();
9912 Lower = C->lshr(shiftAmt: ShiftAmount);
9913 Upper = *C + 1;
9914 }
9915 break;
9916
9917 case Instruction::Shl:
9918 if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9919 if (IIQ.hasNoUnsignedWrap(Op: &BO)) {
9920 // 'shl nuw C, x' produces [C, C << CLZ(C)]
9921 Lower = *C;
9922 Upper = Lower.shl(shiftAmt: Lower.countl_zero()) + 1;
9923 } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw?
9924 if (C->isNegative()) {
9925 // 'shl nsw C, x' produces [C << CLO(C)-1, C]
9926 unsigned ShiftAmount = C->countl_one() - 1;
9927 Lower = C->shl(shiftAmt: ShiftAmount);
9928 Upper = *C + 1;
9929 } else {
9930 // 'shl nsw C, x' produces [C, C << CLZ(C)-1]
9931 unsigned ShiftAmount = C->countl_zero() - 1;
9932 Lower = *C;
9933 Upper = C->shl(shiftAmt: ShiftAmount) + 1;
9934 }
9935 } else {
9936 // If lowbit is set, value can never be zero.
9937 if ((*C)[0])
9938 Lower = APInt::getOneBitSet(numBits: Width, BitNo: 0);
// If we are shifting a constant, the largest the result can be is when the
// longest sequence of consecutive ones lands in the high bits (breaking ties
// toward the higher sequence). For now we take a liberal upper bound on this
// by simply popcounting the constant.
// TODO: There may be a bitwise trick for finding the longest/highest
// consecutive sequence of ones (the naive method is an O(Width) loop).
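// For example, for i8 C = 0b00010011 (popcount 3) the bound used below is
// 0b11100000 + 1, which safely covers the true maximum 19 << 3 == 0b10011000.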
9945 Upper = APInt::getHighBitsSet(numBits: Width, hiBitsSet: C->popcount()) + 1;
9946 }
9947 } else if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
9948 Upper = APInt::getBitsSetFrom(numBits: Width, loBit: C->getZExtValue()) + 1;
9949 }
9950 break;
9951
9952 case Instruction::SDiv:
9953 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
9954 APInt IntMin = APInt::getSignedMinValue(numBits: Width);
9955 APInt IntMax = APInt::getSignedMaxValue(numBits: Width);
9956 if (C->isAllOnes()) {
// 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX], since INT_MIN / -1 overflows
// and is therefore excluded.
9959 Lower = IntMin + 1;
9960 Upper = IntMax + 1;
9961 } else if (C->countl_zero() < Width - 1) {
9962 // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C]
9963 // where C != -1 and C != 0 and C != 1
9964 Lower = IntMin.sdiv(RHS: *C);
9965 Upper = IntMax.sdiv(RHS: *C);
9966 if (Lower.sgt(RHS: Upper))
9967 std::swap(a&: Lower, b&: Upper);
9968 Upper = Upper + 1;
9969 assert(Upper != Lower && "Upper part of range has wrapped!");
9970 }
9971 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9972 if (C->isMinSignedValue()) {
9973 // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
9974 Lower = *C;
9975 Upper = Lower.lshr(shiftAmt: 1) + 1;
9976 } else {
9977 // 'sdiv C, x' produces [-|C|, |C|].
9978 Upper = C->abs() + 1;
9979 Lower = (-Upper) + 1;
9980 }
9981 }
9982 break;
9983
9984 case Instruction::UDiv:
9985 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && !C->isZero()) {
9986 // 'udiv x, C' produces [0, UINT_MAX / C].
9987 Upper = APInt::getMaxValue(numBits: Width).udiv(RHS: *C) + 1;
9988 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9989 // 'udiv C, x' produces [0, C].
9990 Upper = *C + 1;
9991 }
9992 break;
9993
9994 case Instruction::SRem:
9995 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
9996 // 'srem x, C' produces (-|C|, |C|).
9997 Upper = C->abs();
9998 Lower = (-Upper) + 1;
9999 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
10000 if (C->isNegative()) {
10001 // 'srem -|C|, x' produces [-|C|, 0].
10002 Upper = 1;
10003 Lower = *C;
10004 } else {
10005 // 'srem |C|, x' produces [0, |C|].
10006 Upper = *C + 1;
10007 }
10008 }
10009 break;
10010
10011 case Instruction::URem:
10012 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
10013 // 'urem x, C' produces [0, C).
10014 Upper = *C;
10015 else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)))
10016 // 'urem C, x' produces [0, C].
10017 Upper = *C + 1;
10018 break;
10019
10020 default:
10021 break;
10022 }
10023}
10024
10025static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II,
10026 bool UseInstrInfo) {
10027 unsigned Width = II.getType()->getScalarSizeInBits();
10028 const APInt *C;
10029 switch (II.getIntrinsicID()) {
10030 case Intrinsic::ctlz:
10031 case Intrinsic::cttz: {
10032 APInt Upper(Width, Width);
10033 if (!UseInstrInfo || !match(V: II.getArgOperand(i: 1), P: m_One()))
10034 Upper += 1;
10035 // Maximum of set/clear bits is the bit width.
10036 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper);
10037 }
10038 case Intrinsic::ctpop:
10039 // Maximum of set/clear bits is the bit width.
10040 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
10041 Upper: APInt(Width, Width) + 1);
10042 case Intrinsic::uadd_sat:
10043 // uadd.sat(x, C) produces [C, UINT_MAX].
10044 if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) ||
10045 match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
10046 return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: Width));
10047 break;
10048 case Intrinsic::sadd_sat:
10049 if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) ||
10050 match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
10051 if (C->isNegative())
10052 // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)].
10053 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
10054 Upper: APInt::getSignedMaxValue(numBits: Width) + *C +
10055 1);
10056
10057 // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX].
10058 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width) + *C,
10059 Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
10060 }
10061 break;
10062 case Intrinsic::usub_sat:
10063 // usub.sat(C, x) produces [0, C].
10064 if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)))
10065 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper: *C + 1);
10066
10067 // usub.sat(x, C) produces [0, UINT_MAX - C].
10068 if (match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
10069 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
10070 Upper: APInt::getMaxValue(numBits: Width) - *C + 1);
10071 break;
10072 case Intrinsic::ssub_sat:
10073 if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
10074 if (C->isNegative())
10075 // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)].
10076 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
10077 Upper: *C - APInt::getSignedMinValue(numBits: Width) +
10078 1);
10079
10080 // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX].
10081 return ConstantRange::getNonEmpty(Lower: *C - APInt::getSignedMaxValue(numBits: Width),
10082 Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
10083 } else if (match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
10084 if (C->isNegative())
10085 // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]:
10086 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width) - *C,
10087 Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
10088
10089 // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C].
10090 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
10091 Upper: APInt::getSignedMaxValue(numBits: Width) - *C +
10092 1);
10093 }
10094 break;
10095 case Intrinsic::umin:
10096 case Intrinsic::umax:
10097 case Intrinsic::smin:
10098 case Intrinsic::smax:
10099 if (!match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) &&
10100 !match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
10101 break;
10102
10103 switch (II.getIntrinsicID()) {
10104 case Intrinsic::umin:
10105 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper: *C + 1);
10106 case Intrinsic::umax:
10107 return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: Width));
10108 case Intrinsic::smin:
10109 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
10110 Upper: *C + 1);
10111 case Intrinsic::smax:
10112 return ConstantRange::getNonEmpty(Lower: *C,
10113 Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
10114 default:
10115 llvm_unreachable("Must be min/max intrinsic");
10116 }
10117 break;
10118 case Intrinsic::abs:
10119 // If abs of SIGNED_MIN is poison, then the result is [0..SIGNED_MAX],
10120 // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
10121 if (match(V: II.getOperand(i_nocapture: 1), P: m_One()))
10122 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
10123 Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
10124
10125 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
10126 Upper: APInt::getSignedMinValue(numBits: Width) + 1);
10127 case Intrinsic::vscale:
10128 if (!II.getParent() || !II.getFunction())
10129 break;
10130 return getVScaleRange(F: II.getFunction(), BitWidth: Width);
10131 default:
10132 break;
10133 }
10134
10135 return ConstantRange::getFull(BitWidth: Width);
10136}
10137
10138static ConstantRange getRangeForSelectPattern(const SelectInst &SI,
10139 const InstrInfoQuery &IIQ) {
10140 unsigned BitWidth = SI.getType()->getScalarSizeInBits();
10141 const Value *LHS = nullptr, *RHS = nullptr;
10142 SelectPatternResult R = matchSelectPattern(V: &SI, LHS, RHS);
10143 if (R.Flavor == SPF_UNKNOWN)
10144 return ConstantRange::getFull(BitWidth);
10145
10146 if (R.Flavor == SelectPatternFlavor::SPF_ABS) {
10147 // If the negation part of the abs (in RHS) has the NSW flag,
10148 // then the result of abs(X) is [0..SIGNED_MAX],
10149 // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
10150 if (match(V: RHS, P: m_Neg(V: m_Specific(V: LHS))) &&
10151 IIQ.hasNoSignedWrap(Op: cast<Instruction>(Val: RHS)))
10152 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth),
10153 Upper: APInt::getSignedMaxValue(numBits: BitWidth) + 1);
10154
10155 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth),
10156 Upper: APInt::getSignedMinValue(numBits: BitWidth) + 1);
10157 }
10158
10159 if (R.Flavor == SelectPatternFlavor::SPF_NABS) {
10160 // The result of -abs(X) is <= 0.
10161 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: BitWidth),
10162 Upper: APInt(BitWidth, 1));
10163 }
10164
10165 const APInt *C;
10166 if (!match(V: LHS, P: m_APInt(Res&: C)) && !match(V: RHS, P: m_APInt(Res&: C)))
10167 return ConstantRange::getFull(BitWidth);
10168
10169 switch (R.Flavor) {
10170 case SPF_UMIN:
10171 return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth), Upper: *C + 1);
10172 case SPF_UMAX:
10173 return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: BitWidth));
10174 case SPF_SMIN:
10175 return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: BitWidth),
10176 Upper: *C + 1);
10177 case SPF_SMAX:
10178 return ConstantRange::getNonEmpty(Lower: *C,
10179 Upper: APInt::getSignedMaxValue(numBits: BitWidth) + 1);
10180 default:
10181 return ConstantRange::getFull(BitWidth);
10182 }
10183}
10184
10185static void setLimitForFPToI(const Instruction *I, APInt &Lower, APInt &Upper) {
// The maximum representable value of a half is 65504. For floats, the maximum
// value is about 3.4e38, which requires roughly 129 bits.
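// For example, fptosi from half to an integer type of at least 17 bits can
// only produce values in [-65504, 65504], so the exclusive upper bound below
// is 65505.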
10188 unsigned BitWidth = I->getType()->getScalarSizeInBits();
10189 if (!I->getOperand(i: 0)->getType()->getScalarType()->isHalfTy())
10190 return;
10191 if (isa<FPToSIInst>(Val: I) && BitWidth >= 17) {
10192 Lower = APInt(BitWidth, -65504, true);
10193 Upper = APInt(BitWidth, 65505);
10194 }
10195
10196 if (isa<FPToUIInst>(Val: I) && BitWidth >= 16) {
10197 // For a fptoui the lower limit is left as 0.
10198 Upper = APInt(BitWidth, 65505);
10199 }
10200}

ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
                                         bool UseInstrInfo, AssumptionCache *AC,
                                         const Instruction *CtxI,
                                         const DominatorTree *DT,
                                         unsigned Depth) {
  assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");

  if (Depth == MaxAnalysisRecursionDepth)
    return ConstantRange::getFull(V->getType()->getScalarSizeInBits());

  if (auto *C = dyn_cast<Constant>(V))
    return C->toConstantRange();

  unsigned BitWidth = V->getType()->getScalarSizeInBits();
  InstrInfoQuery IIQ(UseInstrInfo);
  ConstantRange CR = ConstantRange::getFull(BitWidth);
  if (auto *BO = dyn_cast<BinaryOperator>(V)) {
    APInt Lower = APInt(BitWidth, 0);
    APInt Upper = APInt(BitWidth, 0);
    // TODO: Return ConstantRange.
    setLimitsForBinOp(*BO, Lower, Upper, IIQ, ForSigned);
    CR = ConstantRange::getNonEmpty(Lower, Upper);
  } else if (auto *II = dyn_cast<IntrinsicInst>(V))
    CR = getRangeForIntrinsic(*II, UseInstrInfo);
  else if (auto *SI = dyn_cast<SelectInst>(V)) {
    ConstantRange CRTrue = computeConstantRange(
        SI->getTrueValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth + 1);
    ConstantRange CRFalse = computeConstantRange(
        SI->getFalseValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth + 1);
    CR = CRTrue.unionWith(CRFalse);
    CR = CR.intersectWith(getRangeForSelectPattern(*SI, IIQ));
  } else if (isa<FPToUIInst>(V) || isa<FPToSIInst>(V)) {
    APInt Lower = APInt(BitWidth, 0);
    APInt Upper = APInt(BitWidth, 0);
    // TODO: Return ConstantRange.
    setLimitForFPToI(cast<Instruction>(V), Lower, Upper);
    CR = ConstantRange::getNonEmpty(Lower, Upper);
  } else if (const auto *A = dyn_cast<Argument>(V))
    if (std::optional<ConstantRange> Range = A->getRange())
      CR = *Range;

  if (auto *I = dyn_cast<Instruction>(V)) {
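    // E.g. if V is a load annotated with !range !{i32 0, i32 10}, CR is
    // intersected with [0, 10).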
    if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range))
      CR = CR.intersectWith(getConstantRangeFromMetadata(*Range));

    if (const auto *CB = dyn_cast<CallBase>(V))
      if (std::optional<ConstantRange> Range = CB->getRange())
        CR = CR.intersectWith(*Range);
  }

  if (CtxI && AC) {
    // Try to restrict the range based on information from assumptions.
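    // E.g. an applicable llvm.assume(icmp ult i32 %v, 100), where %v is V,
    // intersects CR with [0, 100).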
    for (auto &AssumeVH : AC->assumptionsFor(V)) {
      if (!AssumeVH)
        continue;
      CallInst *I = cast<CallInst>(AssumeVH);
      assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
             "Got assumption for the wrong function!");
      assert(I->getIntrinsicID() == Intrinsic::assume &&
             "must be an assume intrinsic");

      if (!isValidAssumeForContext(I, CtxI, DT))
        continue;
      Value *Arg = I->getArgOperand(0);
      ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg);
      // Currently we just use information from comparisons.
      if (!Cmp || Cmp->getOperand(0) != V)
        continue;
      // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
      ConstantRange RHS =
          computeConstantRange(Cmp->getOperand(1), /* ForSigned */ false,
                               UseInstrInfo, AC, I, DT, Depth + 1);
      CR = CR.intersectWith(
          ConstantRange::makeAllowedICmpRegion(Cmp->getPredicate(), RHS));
    }
  }

  return CR;
}

static void
addValueAffectedByCondition(Value *V,
                            function_ref<void(Value *)> InsertAffected) {
  assert(V != nullptr);
  if (isa<Argument>(V) || isa<GlobalValue>(V)) {
    InsertAffected(V);
  } else if (auto *I = dyn_cast<Instruction>(V)) {
    InsertAffected(V);

    // Peek through unary operators to find the source of the condition.
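    // E.g. if the condition tests (trunc i64 %x to i32) or
    // (ptrtoint ptr %p to i64), the underlying %x or %p is affected as well.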
    Value *Op;
    if (match(I, m_CombineOr(m_PtrToIntOrAddr(m_Value(Op)),
                             m_Trunc(m_Value(Op))))) {
      if (isa<Instruction>(Op) || isa<Argument>(Op))
        InsertAffected(Op);
    }
  }
}

void llvm::findValuesAffectedByCondition(
    Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
  auto AddAffected = [&InsertAffected](Value *V) {
    addValueAffectedByCondition(V, InsertAffected);
  };

  auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
    if (IsAssume) {
      AddAffected(LHS);
      AddAffected(RHS);
    } else if (match(RHS, m_Constant()))
      AddAffected(LHS);
  };

  SmallVector<Value *, 8> Worklist;
  SmallPtrSet<Value *, 8> Visited;
  Worklist.push_back(Cond);
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(V).second)
      continue;

    CmpPredicate Pred;
    Value *A, *B, *X;

    if (IsAssume) {
      AddAffected(V);
      if (match(V, m_Not(m_Value(X))))
        AddAffected(X);
    }

    if (match(V, m_LogicalOp(m_Value(A), m_Value(B)))) {
      // assume(A && B) is split to -> assume(A); assume(B);
      // assume(!(A || B)) is split to -> assume(!A); assume(!B);
      // Finally, assume(A || B) / assume(!(A && B)) generally don't provide
      // enough information to be worth handling (intersection of information
      // as opposed to union).
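      // E.g. a branch on (icmp ult i32 %x, 8) && (icmp ne i32 %y, 0) visits
      // both operands, so %x and %y end up affected.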
      if (!IsAssume) {
        Worklist.push_back(A);
        Worklist.push_back(B);
      }
    } else if (match(V, m_ICmp(Pred, m_Value(A), m_Value(B)))) {
      bool HasRHSC = match(B, m_ConstantInt());
      if (ICmpInst::isEquality(Pred)) {
        AddAffected(A);
        if (IsAssume)
          AddAffected(B);
        if (HasRHSC) {
          Value *Y;
          // (X << C) or (X >>_s C) or (X >>_u C).
          if (match(A, m_Shift(m_Value(X), m_ConstantInt())))
            AddAffected(X);
          // (X & C) or (X | C).
          else if (match(A, m_And(m_Value(X), m_Value(Y))) ||
                   match(A, m_Or(m_Value(X), m_Value(Y)))) {
            AddAffected(X);
            AddAffected(Y);
          }
          // X - Y
          else if (match(A, m_Sub(m_Value(X), m_Value(Y)))) {
            AddAffected(X);
            AddAffected(Y);
          }
        }
      } else {
        AddCmpOperands(A, B);
        if (HasRHSC) {
          // Handle (A + C1) u< C2, which is the canonical form of
          // A > C3 && A < C4.
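          // E.g. (%a + 5) u< 13 encodes -5 <= %a < 8, so %a itself is worth
          // tracking.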
          if (match(A, m_AddLike(m_Value(X), m_ConstantInt())))
            AddAffected(X);

          if (ICmpInst::isUnsigned(Pred)) {
            Value *Y;
            // X & Y u> C -> X >u C && Y >u C
            // X | Y u< C -> X u< C && Y u< C
            // X nuw+ Y u< C -> X u< C && Y u< C
            if (match(A, m_And(m_Value(X), m_Value(Y))) ||
                match(A, m_Or(m_Value(X), m_Value(Y))) ||
                match(A, m_NUWAdd(m_Value(X), m_Value(Y)))) {
              AddAffected(X);
              AddAffected(Y);
            }
            // X nuw- Y u> C -> X u> C
            if (match(A, m_NUWSub(m_Value(X), m_Value())))
              AddAffected(X);
          }
        }

        // Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
        // by computeKnownFPClass().
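        // E.g. (bitcast float %f to i32) s< 0 holds exactly when %f has its
        // sign bit set, so %f is affected.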
        if (match(A, m_ElementWiseBitCast(m_Value(X)))) {
          if (Pred == ICmpInst::ICMP_SLT && match(B, m_Zero()))
            InsertAffected(X);
          else if (Pred == ICmpInst::ICMP_SGT && match(B, m_AllOnes()))
            InsertAffected(X);
        }
      }

      if (HasRHSC && match(A, m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))
        AddAffected(X);
    } else if (match(V, m_FCmp(Pred, m_Value(A), m_Value(B)))) {
      AddCmpOperands(A, B);

      // fcmp fneg(x), y
      // fcmp fabs(x), y
      // fcmp fneg(fabs(x)), y
      if (match(A, m_FNeg(m_Value(A))))
        AddAffected(A);
      if (match(A, m_FAbs(m_Value(A))))
        AddAffected(A);

    } else if (match(V, m_Intrinsic<Intrinsic::is_fpclass>(m_Value(A),
                                                           m_Value()))) {
      // Handle patterns that computeKnownFPClass() supports.
      AddAffected(A);
    } else if (!IsAssume && match(V, m_Trunc(m_Value(X)))) {
      // Assume is checked here because, for assumes, X has already been added
      // above via addValueAffectedByCondition.
      AddAffected(X);
    } else if (!IsAssume && match(V, m_Not(m_Value(X)))) {
      // Assume is checked here to avoid issues with ephemeral values.
      Worklist.push_back(X);
    }
  }
}

const Value *llvm::stripNullTest(const Value *V) {
  // (X >> C) or/add (X & mask(C) != 0)
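  // E.g. (X u>> 3) | zext((X & 7) != 0) is nonzero iff X is nonzero, so X is
  // returned as the stripped value.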
  if (const auto *BO = dyn_cast<BinaryOperator>(V)) {
    if (BO->getOpcode() == Instruction::Add ||
        BO->getOpcode() == Instruction::Or) {
      const Value *X;
      const APInt *C1, *C2;
      if (match(BO, m_c_BinOp(m_LShr(m_Value(X), m_APInt(C1)),
                              m_ZExt(m_SpecificICmp(
                                  ICmpInst::ICMP_NE,
                                  m_And(m_Deferred(X), m_LowBitMask(C2)),
                                  m_Zero())))) &&
          C2->popcount() == C1->getZExtValue())
        return X;
    }
  }
  return nullptr;
}

Value *llvm::stripNullTest(Value *V) {
  return const_cast<Value *>(stripNullTest(const_cast<const Value *>(V)));
}

bool llvm::collectPossibleValues(const Value *V,
                                 SmallPtrSetImpl<const Constant *> &Constants,
                                 unsigned MaxCount, bool AllowUndefOrPoison) {
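  // Illustrative example: for %v = select i1 %c, i32 1, i32 2 this collects
  // {1, 2}; phis over such selects work too, as long as no more than MaxCount
  // distinct constants are encountered.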
  SmallPtrSet<const Instruction *, 8> Visited;
  SmallVector<const Instruction *, 8> Worklist;
  auto Push = [&](const Value *V) -> bool {
    Constant *C;
    if (match(const_cast<Value *>(V), m_ImmConstant(C))) {
      if (!AllowUndefOrPoison && !isGuaranteedNotToBeUndefOrPoison(C))
        return false;
      // Check existence first to avoid unnecessary allocations.
      if (Constants.contains(C))
        return true;
      if (Constants.size() == MaxCount)
        return false;
      Constants.insert(C);
      return true;
    }

    if (auto *Inst = dyn_cast<Instruction>(V)) {
      if (Visited.insert(Inst).second)
        Worklist.push_back(Inst);
      return true;
    }
    return false;
  };
  if (!Push(V))
    return false;
  while (!Worklist.empty()) {
    const Instruction *CurInst = Worklist.pop_back_val();
    switch (CurInst->getOpcode()) {
    case Instruction::Select:
      if (!Push(CurInst->getOperand(1)))
        return false;
      if (!Push(CurInst->getOperand(2)))
        return false;
      break;
    case Instruction::PHI:
      for (Value *IncomingValue : cast<PHINode>(CurInst)->incoming_values()) {
        // Fast path for recurrence PHI.
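        // E.g. %p = phi i32 [ 0, %entry ], [ %p, %latch ] contributes only 0.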
        if (IncomingValue == CurInst)
          continue;
        if (!Push(IncomingValue))
          return false;
      }
      break;
    default:
      return false;
    }
  }
  return true;
}
10503