1//===- ValueTracking.cpp - Walk computations to compute properties --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains routines that help analyze properties that chains of
10// computations have.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/ValueTracking.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/APInt.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/FloatingPointMode.h"
19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/ScopeExit.h"
21#include "llvm/ADT/SmallPtrSet.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/StringRef.h"
24#include "llvm/ADT/iterator_range.h"
25#include "llvm/Analysis/AliasAnalysis.h"
26#include "llvm/Analysis/AssumeBundleQueries.h"
27#include "llvm/Analysis/AssumptionCache.h"
28#include "llvm/Analysis/ConstantFolding.h"
29#include "llvm/Analysis/DomConditionCache.h"
30#include "llvm/Analysis/FloatingPointPredicateUtils.h"
31#include "llvm/Analysis/GuardUtils.h"
32#include "llvm/Analysis/InstructionSimplify.h"
33#include "llvm/Analysis/Loads.h"
34#include "llvm/Analysis/LoopInfo.h"
35#include "llvm/Analysis/TargetLibraryInfo.h"
36#include "llvm/Analysis/VectorUtils.h"
37#include "llvm/Analysis/WithCache.h"
38#include "llvm/IR/Argument.h"
39#include "llvm/IR/Attributes.h"
40#include "llvm/IR/BasicBlock.h"
41#include "llvm/IR/Constant.h"
42#include "llvm/IR/ConstantFPRange.h"
43#include "llvm/IR/ConstantRange.h"
44#include "llvm/IR/Constants.h"
45#include "llvm/IR/DerivedTypes.h"
46#include "llvm/IR/DiagnosticInfo.h"
47#include "llvm/IR/Dominators.h"
48#include "llvm/IR/EHPersonalities.h"
49#include "llvm/IR/Function.h"
50#include "llvm/IR/GetElementPtrTypeIterator.h"
51#include "llvm/IR/GlobalAlias.h"
52#include "llvm/IR/GlobalValue.h"
53#include "llvm/IR/GlobalVariable.h"
54#include "llvm/IR/InstrTypes.h"
55#include "llvm/IR/Instruction.h"
56#include "llvm/IR/Instructions.h"
57#include "llvm/IR/IntrinsicInst.h"
58#include "llvm/IR/Intrinsics.h"
59#include "llvm/IR/IntrinsicsAArch64.h"
60#include "llvm/IR/IntrinsicsAMDGPU.h"
61#include "llvm/IR/IntrinsicsRISCV.h"
62#include "llvm/IR/IntrinsicsX86.h"
63#include "llvm/IR/LLVMContext.h"
64#include "llvm/IR/Metadata.h"
65#include "llvm/IR/Module.h"
66#include "llvm/IR/Operator.h"
67#include "llvm/IR/PatternMatch.h"
68#include "llvm/IR/Type.h"
69#include "llvm/IR/User.h"
70#include "llvm/IR/Value.h"
71#include "llvm/Support/Casting.h"
72#include "llvm/Support/CommandLine.h"
73#include "llvm/Support/Compiler.h"
74#include "llvm/Support/ErrorHandling.h"
75#include "llvm/Support/KnownBits.h"
76#include "llvm/Support/KnownFPClass.h"
77#include "llvm/Support/MathExtras.h"
78#include "llvm/TargetParser/RISCVTargetParser.h"
79#include <algorithm>
80#include <cassert>
81#include <cstdint>
82#include <optional>
83#include <utility>
84
85using namespace llvm;
86using namespace llvm::PatternMatch;
87
88// Controls the number of uses of the value searched for possible
89// dominating comparisons.
90static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses",
91 cl::Hidden, cl::init(Val: 20));
92
93/// Maximum number of instructions to check between assume and context
94/// instruction.
95static constexpr unsigned MaxInstrsToCheckForFree = 16;
96
97/// Returns the bitwidth of the given scalar or pointer type. For vector types,
98/// returns the element type's bitwidth.
99static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
100 if (unsigned BitWidth = Ty->getScalarSizeInBits())
101 return BitWidth;
102
103 return DL.getPointerTypeSizeInBits(Ty);
104}
105
106// Given the provided Value and, potentially, a context instruction, return
107// the preferred context instruction (if any).
108static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
109 // If we've been provided with a context instruction, then use that (provided
110 // it has been inserted).
111 if (CxtI && CxtI->getParent())
112 return CxtI;
113
114 // If the value is really an already-inserted instruction, then use that.
115 CxtI = dyn_cast<Instruction>(Val: V);
116 if (CxtI && CxtI->getParent())
117 return CxtI;
118
119 return nullptr;
120}
121
122static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
123 const APInt &DemandedElts,
124 APInt &DemandedLHS, APInt &DemandedRHS) {
125 if (isa<ScalableVectorType>(Val: Shuf->getType())) {
126 assert(DemandedElts == APInt(1,1));
127 DemandedLHS = DemandedRHS = DemandedElts;
128 return true;
129 }
130
131 int NumElts =
132 cast<FixedVectorType>(Val: Shuf->getOperand(i_nocapture: 0)->getType())->getNumElements();
133 return llvm::getShuffleDemandedElts(SrcWidth: NumElts, Mask: Shuf->getShuffleMask(),
134 DemandedElts, DemandedLHS, DemandedRHS);
135}
136
137static void computeKnownBits(const Value *V, const APInt &DemandedElts,
138 KnownBits &Known, const SimplifyQuery &Q,
139 unsigned Depth);
140
141void llvm::computeKnownBits(const Value *V, KnownBits &Known,
142 const SimplifyQuery &Q, unsigned Depth) {
143 // Since the number of lanes in a scalable vector is unknown at compile time,
144 // we track one bit which is implicitly broadcast to all lanes. This means
145 // that all lanes in a scalable vector are considered demanded.
146 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
147 APInt DemandedElts =
148 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
149 ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
150}
151
152void llvm::computeKnownBits(const Value *V, KnownBits &Known,
153 const DataLayout &DL, AssumptionCache *AC,
154 const Instruction *CxtI, const DominatorTree *DT,
155 bool UseInstrInfo, unsigned Depth) {
156 computeKnownBits(V, Known,
157 Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
158 Depth);
159}
160
161KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
162 AssumptionCache *AC, const Instruction *CxtI,
163 const DominatorTree *DT, bool UseInstrInfo,
164 unsigned Depth) {
165 return computeKnownBits(
166 V, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
167}
168
169KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
170 const DataLayout &DL, AssumptionCache *AC,
171 const Instruction *CxtI,
172 const DominatorTree *DT, bool UseInstrInfo,
173 unsigned Depth) {
174 return computeKnownBits(
175 V, DemandedElts,
176 Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
177}
178
179static bool haveNoCommonBitsSetSpecialCases(const Value *LHS, const Value *RHS,
180 const SimplifyQuery &SQ) {
181 // Look for an inverted mask: (X & ~M) op (Y & M).
182 {
183 Value *M;
184 if (match(V: LHS, P: m_c_And(L: m_Not(V: m_Value(V&: M)), R: m_Value())) &&
185 match(V: RHS, P: m_c_And(L: m_Specific(V: M), R: m_Value())) &&
186 isGuaranteedNotToBeUndef(V: M, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
187 return true;
188 }
189
190 // X op (Y & ~X)
191 if (match(V: RHS, P: m_c_And(L: m_Not(V: m_Specific(V: LHS)), R: m_Value())) &&
192 isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
193 return true;
194
195 // X op ((X & Y) ^ Y) -- this is the canonical form of the previous pattern
196 // for constant Y.
197 Value *Y;
198 if (match(V: RHS,
199 P: m_c_Xor(L: m_c_And(L: m_Specific(V: LHS), R: m_Value(V&: Y)), R: m_Deferred(V: Y))) &&
200 isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT) &&
201 isGuaranteedNotToBeUndef(V: Y, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
202 return true;
203
204 // Peek through extends to find a 'not' of the other side:
205 // (ext Y) op ext(~Y)
206 if (match(V: LHS, P: m_ZExtOrSExt(Op: m_Value(V&: Y))) &&
207 match(V: RHS, P: m_ZExtOrSExt(Op: m_Not(V: m_Specific(V: Y)))) &&
208 isGuaranteedNotToBeUndef(V: Y, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
209 return true;
210
211 // Look for: (A & B) op ~(A | B)
212 {
213 Value *A, *B;
214 if (match(V: LHS, P: m_And(L: m_Value(V&: A), R: m_Value(V&: B))) &&
215 match(V: RHS, P: m_Not(V: m_c_Or(L: m_Specific(V: A), R: m_Specific(V: B)))) &&
216 isGuaranteedNotToBeUndef(V: A, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT) &&
217 isGuaranteedNotToBeUndef(V: B, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
218 return true;
219 }
220
221 // Look for: (X << V) op (Y >> (BitWidth - V))
222 // or (X >> V) op (Y << (BitWidth - V))
223 {
224 const Value *V;
225 const APInt *R;
226 if (((match(V: RHS, P: m_Shl(L: m_Value(), R: m_Sub(L: m_APInt(Res&: R), R: m_Value(V)))) &&
227 match(V: LHS, P: m_LShr(L: m_Value(), R: m_Specific(V)))) ||
228 (match(V: RHS, P: m_LShr(L: m_Value(), R: m_Sub(L: m_APInt(Res&: R), R: m_Value(V)))) &&
229 match(V: LHS, P: m_Shl(L: m_Value(), R: m_Specific(V))))) &&
230 R->uge(RHS: LHS->getType()->getScalarSizeInBits()))
231 return true;
232 }
233
234 return false;
235}
236
237bool llvm::haveNoCommonBitsSet(const WithCache<const Value *> &LHSCache,
238 const WithCache<const Value *> &RHSCache,
239 const SimplifyQuery &SQ) {
240 const Value *LHS = LHSCache.getValue();
241 const Value *RHS = RHSCache.getValue();
242
243 assert(LHS->getType() == RHS->getType() &&
244 "LHS and RHS should have the same type");
245 assert(LHS->getType()->isIntOrIntVectorTy() &&
246 "LHS and RHS should be integers");
247
248 if (haveNoCommonBitsSetSpecialCases(LHS, RHS, SQ) ||
249 haveNoCommonBitsSetSpecialCases(LHS: RHS, RHS: LHS, SQ))
250 return true;
251
252 return KnownBits::haveNoCommonBitsSet(LHS: LHSCache.getKnownBits(Q: SQ),
253 RHS: RHSCache.getKnownBits(Q: SQ));
254}
255
256bool llvm::isOnlyUsedInZeroComparison(const Instruction *I) {
257 return !I->user_empty() &&
258 all_of(Range: I->users(), P: match_fn(P: m_ICmp(L: m_Value(), R: m_Zero())));
259}
260
261bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *I) {
262 return !I->user_empty() && all_of(Range: I->users(), P: [](const User *U) {
263 CmpPredicate P;
264 return match(V: U, P: m_ICmp(Pred&: P, L: m_Value(), R: m_Zero())) && ICmpInst::isEquality(P);
265 });
266}
267
268bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
269 bool OrZero, AssumptionCache *AC,
270 const Instruction *CxtI,
271 const DominatorTree *DT, bool UseInstrInfo,
272 unsigned Depth) {
273 return ::isKnownToBeAPowerOfTwo(
274 V, OrZero, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo),
275 Depth);
276}
277
278static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
279 const SimplifyQuery &Q, unsigned Depth);
280
281bool llvm::isKnownNonNegative(const Value *V, const SimplifyQuery &SQ,
282 unsigned Depth) {
283 return computeKnownBits(V, Q: SQ, Depth).isNonNegative();
284}
285
286bool llvm::isKnownPositive(const Value *V, const SimplifyQuery &SQ,
287 unsigned Depth) {
288 if (auto *CI = dyn_cast<ConstantInt>(Val: V))
289 return CI->getValue().isStrictlyPositive();
290
291 // If `isKnownNonNegative` ever becomes more sophisticated, make sure to keep
292 // this updated.
293 KnownBits Known = computeKnownBits(V, Q: SQ, Depth);
294 return Known.isNonNegative() &&
295 (Known.isNonZero() || isKnownNonZero(V, Q: SQ, Depth));
296}
297
298bool llvm::isKnownNegative(const Value *V, const SimplifyQuery &SQ,
299 unsigned Depth) {
300 return computeKnownBits(V, Q: SQ, Depth).isNegative();
301}
302
303static bool isKnownNonEqual(const Value *V1, const Value *V2,
304 const APInt &DemandedElts, const SimplifyQuery &Q,
305 unsigned Depth);
306
307bool llvm::isKnownNonEqual(const Value *V1, const Value *V2,
308 const SimplifyQuery &Q, unsigned Depth) {
309 // We don't support looking through casts.
310 if (V1 == V2 || V1->getType() != V2->getType())
311 return false;
312 auto *FVTy = dyn_cast<FixedVectorType>(Val: V1->getType());
313 APInt DemandedElts =
314 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
315 return ::isKnownNonEqual(V1, V2, DemandedElts, Q, Depth);
316}
317
318bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
319 const SimplifyQuery &SQ, unsigned Depth) {
320 KnownBits Known(Mask.getBitWidth());
321 computeKnownBits(V, Known, Q: SQ, Depth);
322 return Mask.isSubsetOf(RHS: Known.Zero);
323}
324
325static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
326 const SimplifyQuery &Q, unsigned Depth);
327
328static unsigned ComputeNumSignBits(const Value *V, const SimplifyQuery &Q,
329 unsigned Depth = 0) {
330 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
331 APInt DemandedElts =
332 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
333 return ComputeNumSignBits(V, DemandedElts, Q, Depth);
334}
335
336unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
337 AssumptionCache *AC, const Instruction *CxtI,
338 const DominatorTree *DT, bool UseInstrInfo,
339 unsigned Depth) {
340 return ::ComputeNumSignBits(
341 V, Q: SimplifyQuery(DL, DT, AC, safeCxtI(V, CxtI), UseInstrInfo), Depth);
342}
343
344unsigned llvm::ComputeMaxSignificantBits(const Value *V, const DataLayout &DL,
345 AssumptionCache *AC,
346 const Instruction *CxtI,
347 const DominatorTree *DT,
348 unsigned Depth) {
349 unsigned SignBits = ComputeNumSignBits(V, DL, AC, CxtI, DT, UseInstrInfo: Depth);
350 return V->getType()->getScalarSizeInBits() - SignBits + 1;
351}
352
353/// Try to detect the lerp pattern: a * (b - c) + c * d
354/// where a >= 0, b >= 0, c >= 0, d >= 0, and b >= c.
355///
356/// In that particular case, we can use the following chain of reasoning:
357///
358/// a * (b - c) + c * d <= a' * (b - c) + a' * c = a' * b where a' = max(a, d)
359///
360/// Since that is true for arbitrary a, b, c and d within our constraints, we
361/// can conclude that:
362///
363/// max(a * (b - c) + c * d) <= max(max(a), max(d)) * max(b) = U
364///
365/// Considering that any result of the lerp would be less or equal to U, it
366/// would have at least the number of leading 0s as in U.
367///
368/// While being quite a specific situation, it is fairly common in computer
369/// graphics in the shape of alpha blending.
370///
371/// Modifies given KnownOut in-place with the inferred information.
372static void computeKnownBitsFromLerpPattern(const Value *Op0, const Value *Op1,
373 const APInt &DemandedElts,
374 KnownBits &KnownOut,
375 const SimplifyQuery &Q,
376 unsigned Depth) {
377
378 Type *Ty = Op0->getType();
379 const unsigned BitWidth = Ty->getScalarSizeInBits();
380
381 // Only handle scalar types for now
382 if (Ty->isVectorTy())
383 return;
384
385 // Try to match: a * (b - c) + c * d.
386 // When a == 1 => A == nullptr, the same applies to d/D as well.
387 const Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
388 const Instruction *SubBC = nullptr;
389
390 const auto MatchSubBC = [&]() {
391 // (b - c) can have two forms that interest us:
392 //
393 // 1. sub nuw %b, %c
394 // 2. xor %c, %b
395 //
396 // For the first case, nuw flag guarantees our requirement b >= c.
397 //
398 // The second case might happen when the analysis can infer that b is a mask
399 // for c and we can transform sub operation into xor (that is usually true
400 // for constant b's). Even though xor is symmetrical, canonicalization
401 // ensures that the constant will be the RHS. We have additional checks
402 // later on to ensure that this xor operation is equivalent to subtraction.
403 return m_Instruction(I&: SubBC, Match: m_CombineOr(L: m_NUWSub(L: m_Value(V&: B), R: m_Value(V&: C)),
404 R: m_Xor(L: m_Value(V&: C), R: m_Value(V&: B))));
405 };
406
407 const auto MatchASubBC = [&]() {
408 // Cases:
409 // - a * (b - c)
410 // - (b - c) * a
411 // - (b - c) <- a implicitly equals 1
412 return m_CombineOr(L: m_c_Mul(L: m_Value(V&: A), R: MatchSubBC()), R: MatchSubBC());
413 };
414
415 const auto MatchCD = [&]() {
416 // Cases:
417 // - d * c
418 // - c * d
419 // - c <- d implicitly equals 1
420 return m_CombineOr(L: m_c_Mul(L: m_Value(V&: D), R: m_Specific(V: C)), R: m_Specific(V: C));
421 };
422
423 const auto Match = [&](const Value *LHS, const Value *RHS) {
424 // We do use m_Specific(C) in MatchCD, so we have to make sure that
425 // it's bound to anything and match(LHS, MatchASubBC()) absolutely
426 // has to evaluate first and return true.
427 //
428 // If Match returns true, it is guaranteed that B != nullptr, C != nullptr.
429 return match(V: LHS, P: MatchASubBC()) && match(V: RHS, P: MatchCD());
430 };
431
432 if (!Match(Op0, Op1) && !Match(Op1, Op0))
433 return;
434
435 const auto ComputeKnownBitsOrOne = [&](const Value *V) {
436 // For some of the values we use the convention of leaving
437 // it nullptr to signify an implicit constant 1.
438 return V ? computeKnownBits(V, DemandedElts, Q, Depth: Depth + 1)
439 : KnownBits::makeConstant(C: APInt(BitWidth, 1));
440 };
441
442 // Check that all operands are non-negative
443 const KnownBits KnownA = ComputeKnownBitsOrOne(A);
444 if (!KnownA.isNonNegative())
445 return;
446
447 const KnownBits KnownD = ComputeKnownBitsOrOne(D);
448 if (!KnownD.isNonNegative())
449 return;
450
451 const KnownBits KnownB = computeKnownBits(V: B, DemandedElts, Q, Depth: Depth + 1);
452 if (!KnownB.isNonNegative())
453 return;
454
455 const KnownBits KnownC = computeKnownBits(V: C, DemandedElts, Q, Depth: Depth + 1);
456 if (!KnownC.isNonNegative())
457 return;
458
459 // If we matched subtraction as xor, we need to actually check that xor
460 // is semantically equivalent to subtraction.
461 //
462 // For that to be true, b has to be a mask for c or that b's known
463 // ones cover all known and possible ones of c.
464 if (SubBC->getOpcode() == Instruction::Xor &&
465 !KnownC.getMaxValue().isSubsetOf(RHS: KnownB.getMinValue()))
466 return;
467
468 const APInt MaxA = KnownA.getMaxValue();
469 const APInt MaxD = KnownD.getMaxValue();
470 const APInt MaxAD = APIntOps::umax(A: MaxA, B: MaxD);
471 const APInt MaxB = KnownB.getMaxValue();
472
473 // We can't infer leading zeros info if the upper-bound estimate wraps.
474 bool Overflow;
475 const APInt UpperBound = MaxAD.umul_ov(RHS: MaxB, Overflow);
476
477 if (Overflow)
478 return;
479
480 // If we know that x <= y and both are positive than x has at least the same
481 // number of leading zeros as y.
482 const unsigned MinimumNumberOfLeadingZeros = UpperBound.countl_zero();
483 KnownOut.Zero.setHighBits(MinimumNumberOfLeadingZeros);
484}
485
486static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
487 bool NSW, bool NUW,
488 const APInt &DemandedElts,
489 KnownBits &KnownOut, KnownBits &Known2,
490 const SimplifyQuery &Q, unsigned Depth) {
491 computeKnownBits(V: Op1, DemandedElts, Known&: KnownOut, Q, Depth: Depth + 1);
492
493 // If one operand is unknown and we have no nowrap information,
494 // the result will be unknown independently of the second operand.
495 if (KnownOut.isUnknown() && !NSW && !NUW)
496 return;
497
498 computeKnownBits(V: Op0, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
499 KnownOut = KnownBits::computeForAddSub(Add, NSW, NUW, LHS: Known2, RHS: KnownOut);
500
501 if (!Add && NSW && !KnownOut.isNonNegative() &&
502 isImpliedByDomCondition(Pred: ICmpInst::ICMP_SLE, LHS: Op1, RHS: Op0, ContextI: Q.CxtI, DL: Q.DL)
503 .value_or(u: false))
504 KnownOut.makeNonNegative();
505
506 if (Add)
507 // Try to match lerp pattern and combine results
508 computeKnownBitsFromLerpPattern(Op0, Op1, DemandedElts, KnownOut, Q, Depth);
509}
510
511static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
512 bool NUW, const APInt &DemandedElts,
513 KnownBits &Known, KnownBits &Known2,
514 const SimplifyQuery &Q, unsigned Depth) {
515 computeKnownBits(V: Op1, DemandedElts, Known, Q, Depth: Depth + 1);
516 computeKnownBits(V: Op0, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
517
518 bool isKnownNegative = false;
519 bool isKnownNonNegative = false;
520 // If the multiplication is known not to overflow, compute the sign bit.
521 if (NSW) {
522 if (Op0 == Op1) {
523 // The product of a number with itself is non-negative.
524 isKnownNonNegative = true;
525 } else {
526 bool isKnownNonNegativeOp1 = Known.isNonNegative();
527 bool isKnownNonNegativeOp0 = Known2.isNonNegative();
528 bool isKnownNegativeOp1 = Known.isNegative();
529 bool isKnownNegativeOp0 = Known2.isNegative();
530 // The product of two numbers with the same sign is non-negative.
531 isKnownNonNegative = (isKnownNegativeOp1 && isKnownNegativeOp0) ||
532 (isKnownNonNegativeOp1 && isKnownNonNegativeOp0);
533 if (!isKnownNonNegative && NUW) {
534 // mul nuw nsw with a factor > 1 is non-negative.
535 KnownBits One = KnownBits::makeConstant(C: APInt(Known.getBitWidth(), 1));
536 isKnownNonNegative = KnownBits::sgt(LHS: Known, RHS: One).value_or(u: false) ||
537 KnownBits::sgt(LHS: Known2, RHS: One).value_or(u: false);
538 }
539
540 // The product of a negative number and a non-negative number is either
541 // negative or zero.
542 if (!isKnownNonNegative)
543 isKnownNegative =
544 (isKnownNegativeOp1 && isKnownNonNegativeOp0 &&
545 Known2.isNonZero()) ||
546 (isKnownNegativeOp0 && isKnownNonNegativeOp1 && Known.isNonZero());
547 }
548 }
549
550 bool SelfMultiply = Op0 == Op1;
551 if (SelfMultiply)
552 SelfMultiply &=
553 isGuaranteedNotToBeUndef(V: Op0, AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT, Depth: Depth + 1);
554 Known = KnownBits::mul(LHS: Known, RHS: Known2, NoUndefSelfMultiply: SelfMultiply);
555
556 if (SelfMultiply) {
557 unsigned SignBits = ComputeNumSignBits(V: Op0, DemandedElts, Q, Depth: Depth + 1);
558 unsigned TyBits = Op0->getType()->getScalarSizeInBits();
559 unsigned OutValidBits = 2 * (TyBits - SignBits + 1);
560
561 if (OutValidBits < TyBits) {
562 APInt KnownZeroMask =
563 APInt::getHighBitsSet(numBits: TyBits, hiBitsSet: TyBits - OutValidBits + 1);
564 Known.Zero |= KnownZeroMask;
565 }
566 }
567
568 // Only make use of no-wrap flags if we failed to compute the sign bit
569 // directly. This matters if the multiplication always overflows, in
570 // which case we prefer to follow the result of the direct computation,
571 // though as the program is invoking undefined behaviour we can choose
572 // whatever we like here.
573 if (isKnownNonNegative && !Known.isNegative())
574 Known.makeNonNegative();
575 else if (isKnownNegative && !Known.isNonNegative())
576 Known.makeNegative();
577}
578
579void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges,
580 KnownBits &Known) {
581 unsigned BitWidth = Known.getBitWidth();
582 unsigned NumRanges = Ranges.getNumOperands() / 2;
583 assert(NumRanges >= 1);
584
585 Known.setAllConflict();
586
587 for (unsigned i = 0; i < NumRanges; ++i) {
588 ConstantInt *Lower =
589 mdconst::extract<ConstantInt>(MD: Ranges.getOperand(I: 2 * i + 0));
590 ConstantInt *Upper =
591 mdconst::extract<ConstantInt>(MD: Ranges.getOperand(I: 2 * i + 1));
592 ConstantRange Range(Lower->getValue(), Upper->getValue());
593 // BitWidth must equal the Ranges BitWidth for the correct number of high
594 // bits to be set.
595 assert(BitWidth == Range.getBitWidth() &&
596 "Known bit width must match range bit width!");
597
598 // The first CommonPrefixBits of all values in Range are equal.
599 unsigned CommonPrefixBits =
600 (Range.getUnsignedMax() ^ Range.getUnsignedMin()).countl_zero();
601 APInt Mask = APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: CommonPrefixBits);
602 APInt UnsignedMax = Range.getUnsignedMax().zextOrTrunc(width: BitWidth);
603 Known.One &= UnsignedMax & Mask;
604 Known.Zero &= ~UnsignedMax & Mask;
605 }
606}
607
608static bool isEphemeralValueOf(const Instruction *I, const Value *E) {
609 SmallVector<const Instruction *, 16> WorkSet(1, I);
610 SmallPtrSet<const Instruction *, 32> Visited;
611 SmallPtrSet<const Instruction *, 16> EphValues;
612
613 // The instruction defining an assumption's condition itself is always
614 // considered ephemeral to that assumption (even if it has other
615 // non-ephemeral users). See r246696's test case for an example.
616 if (is_contained(Range: I->operands(), Element: E))
617 return true;
618
619 while (!WorkSet.empty()) {
620 const Instruction *V = WorkSet.pop_back_val();
621 if (!Visited.insert(Ptr: V).second)
622 continue;
623
624 // If all uses of this value are ephemeral, then so is this value.
625 if (all_of(Range: V->users(), P: [&](const User *U) {
626 return EphValues.count(Ptr: cast<Instruction>(Val: U));
627 })) {
628 if (V == E)
629 return true;
630
631 if (V == I || (!V->mayHaveSideEffects() && !V->isTerminator())) {
632 EphValues.insert(Ptr: V);
633
634 if (const User *U = dyn_cast<User>(Val: V)) {
635 for (const Use &U : U->operands()) {
636 if (const auto *I = dyn_cast<Instruction>(Val: U.get()))
637 WorkSet.push_back(Elt: I);
638 }
639 }
640 }
641 }
642 }
643
644 return false;
645}
646
647// Is this an intrinsic that cannot be speculated but also cannot trap?
648bool llvm::isAssumeLikeIntrinsic(const Instruction *I) {
649 if (const IntrinsicInst *CI = dyn_cast<IntrinsicInst>(Val: I))
650 return CI->isAssumeLikeIntrinsic();
651
652 return false;
653}
654
655bool llvm::isValidAssumeForContext(const Instruction *Inv,
656 const Instruction *CxtI,
657 const DominatorTree *DT,
658 bool AllowEphemerals) {
659 // There are two restrictions on the use of an assume:
660 // 1. The assume must dominate the context (or the control flow must
661 // reach the assume whenever it reaches the context).
662 // 2. The context must not be in the assume's set of ephemeral values
663 // (otherwise we will use the assume to prove that the condition
664 // feeding the assume is trivially true, thus causing the removal of
665 // the assume).
666
667 if (Inv->getParent() == CxtI->getParent()) {
668 // If Inv and CtxI are in the same block, check if the assume (Inv) is first
669 // in the BB.
670 if (Inv->comesBefore(Other: CxtI))
671 return true;
672
673 // Don't let an assume affect itself - this would cause the problems
674 // `isEphemeralValueOf` is trying to prevent, and it would also make
675 // the loop below go out of bounds.
676 if (!AllowEphemerals && Inv == CxtI)
677 return false;
678
679 // The context comes first, but they're both in the same block.
680 // Make sure there is nothing in between that might interrupt
681 // the control flow, not even CxtI itself.
682 // We limit the scan distance between the assume and its context instruction
683 // to avoid a compile-time explosion. This limit is chosen arbitrarily, so
684 // it can be adjusted if needed (could be turned into a cl::opt).
685 auto Range = make_range(x: CxtI->getIterator(), y: Inv->getIterator());
686 if (!isGuaranteedToTransferExecutionToSuccessor(Range, ScanLimit: 15))
687 return false;
688
689 return AllowEphemerals || !isEphemeralValueOf(I: Inv, E: CxtI);
690 }
691
692 // Inv and CxtI are in different blocks.
693 if (DT) {
694 if (DT->dominates(Def: Inv, User: CxtI))
695 return true;
696 } else if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor() ||
697 Inv->getParent()->isEntryBlock()) {
698 // We don't have a DT, but this trivially dominates.
699 return true;
700 }
701
702 return false;
703}
704
705bool llvm::willNotFreeBetween(const Instruction *Assume,
706 const Instruction *CtxI) {
707 // Helper to check if there are any calls in the range that may free memory.
708 auto hasNoFreeCalls = [](auto Range) {
709 for (const auto &[Idx, I] : enumerate(Range)) {
710 if (Idx > MaxInstrsToCheckForFree)
711 return false;
712 if (const auto *CB = dyn_cast<CallBase>(&I))
713 if (!CB->hasFnAttr(Attribute::NoFree))
714 return false;
715 }
716 return true;
717 };
718
719 // Make sure the current function cannot arrange for another thread to free on
720 // its behalf.
721 if (!CtxI->getFunction()->hasNoSync())
722 return false;
723
724 // Handle cross-block case: CtxI in a successor of Assume's block.
725 const BasicBlock *CtxBB = CtxI->getParent();
726 const BasicBlock *AssumeBB = Assume->getParent();
727 BasicBlock::const_iterator CtxIter = CtxI->getIterator();
728 if (CtxBB != AssumeBB) {
729 if (CtxBB->getSinglePredecessor() != AssumeBB)
730 return false;
731
732 if (!hasNoFreeCalls(make_range(x: CtxBB->begin(), y: CtxIter)))
733 return false;
734
735 CtxIter = AssumeBB->end();
736 } else {
737 // Same block case: check that Assume comes before CtxI.
738 if (!Assume->comesBefore(Other: CtxI))
739 return false;
740 }
741
742 // Check if there are any calls between Assume and CtxIter that may free
743 // memory.
744 return hasNoFreeCalls(make_range(x: Assume->getIterator(), y: CtxIter));
745}
746
747// TODO: cmpExcludesZero misses many cases where `RHS` is non-constant but
748// we still have enough information about `RHS` to conclude non-zero. For
749// example Pred=EQ, RHS=isKnownNonZero. cmpExcludesZero is called in loops
750// so the extra compile time may not be worth it, but possibly a second API
751// should be created for use outside of loops.
752static bool cmpExcludesZero(CmpInst::Predicate Pred, const Value *RHS) {
753 // v u> y implies v != 0.
754 if (Pred == ICmpInst::ICMP_UGT)
755 return true;
756
757 // Special-case v != 0 to also handle v != null.
758 if (Pred == ICmpInst::ICMP_NE)
759 return match(V: RHS, P: m_Zero());
760
761 // All other predicates - rely on generic ConstantRange handling.
762 const APInt *C;
763 auto Zero = APInt::getZero(numBits: RHS->getType()->getScalarSizeInBits());
764 if (match(V: RHS, P: m_APInt(Res&: C))) {
765 ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(Pred, Other: *C);
766 return !TrueValues.contains(Val: Zero);
767 }
768
769 auto *VC = dyn_cast<ConstantDataVector>(Val: RHS);
770 if (VC == nullptr)
771 return false;
772
773 for (unsigned ElemIdx = 0, NElem = VC->getNumElements(); ElemIdx < NElem;
774 ++ElemIdx) {
775 ConstantRange TrueValues = ConstantRange::makeExactICmpRegion(
776 Pred, Other: VC->getElementAsAPInt(i: ElemIdx));
777 if (TrueValues.contains(Val: Zero))
778 return false;
779 }
780 return true;
781}
782
783static void breakSelfRecursivePHI(const Use *U, const PHINode *PHI,
784 Value *&ValOut, Instruction *&CtxIOut,
785 const PHINode **PhiOut = nullptr) {
786 ValOut = U->get();
787 if (ValOut == PHI)
788 return;
789 CtxIOut = PHI->getIncomingBlock(U: *U)->getTerminator();
790 if (PhiOut)
791 *PhiOut = PHI;
792 Value *V;
793 // If the Use is a select of this phi, compute analysis on other arm to break
794 // recursion.
795 // TODO: Min/Max
796 if (match(V: ValOut, P: m_Select(C: m_Value(), L: m_Specific(V: PHI), R: m_Value(V))) ||
797 match(V: ValOut, P: m_Select(C: m_Value(), L: m_Value(V), R: m_Specific(V: PHI))))
798 ValOut = V;
799
800 // Same for select, if this phi is 2-operand phi, compute analysis on other
801 // incoming value to break recursion.
802 // TODO: We could handle any number of incoming edges as long as we only have
803 // two unique values.
804 if (auto *IncPhi = dyn_cast<PHINode>(Val: ValOut);
805 IncPhi && IncPhi->getNumIncomingValues() == 2) {
806 for (int Idx = 0; Idx < 2; ++Idx) {
807 if (IncPhi->getIncomingValue(i: Idx) == PHI) {
808 ValOut = IncPhi->getIncomingValue(i: 1 - Idx);
809 if (PhiOut)
810 *PhiOut = IncPhi;
811 CtxIOut = IncPhi->getIncomingBlock(i: 1 - Idx)->getTerminator();
812 break;
813 }
814 }
815 }
816}
817
/// Return true if a dominating llvm.assume implies that \p V is non-zero
/// (non-null for pointers) at the query's context instruction.
static bool isKnownNonZeroFromAssume(const Value *V, const SimplifyQuery &Q) {
  // Use of assumptions is context-sensitive. If we don't have a context, we
  // cannot use them!
  if (!Q.AC || !Q.CxtI)
    return false;

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Val&: Elem.Assume);
    assert(I->getFunction() == Q.CxtI->getFunction() &&
           "Got assumption for the wrong function!");

    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      // Operand-bundle assumption (e.g. nonnull/dereferenceable on a ptr).
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              Assume&: *I, BOI: I->bundle_op_info_begin()[Elem.Index])) {
        if (RK.WasOn != V)
          continue;
        bool AssumeImpliesNonNull = [&]() {
          // nonnull directly implies a non-null pointer.
          if (RK.AttrKind == Attribute::NonNull)
            return true;

          // dereferenceable(N) with N != 0 implies non-null, but only in
          // address spaces where null is not a defined (dereferenceable)
          // pointer.
          if (RK.AttrKind == Attribute::Dereferenceable) {
            if (NullPointerIsDefined(F: Q.CxtI->getFunction(),
                                     AS: V->getType()->getPointerAddressSpace()))
              return false;
            assert(RK.IRArgValue &&
                   "Dereferenceable attribute without IR argument?");

            auto *CI = dyn_cast<ConstantInt>(Val: RK.IRArgValue);
            return CI && !CI->isZero();
          }

          return false;
        }();
        if (AssumeImpliesNonNull && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
          return true;
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop for once for each value queried resulting in a
    // runtime of ~O(#assumes * #values).

    // assume(icmp Pred V, RHS): V (possibly through ptrtoint) is non-zero
    // if the predicate/RHS pair excludes zero.
    Value *RHS;
    CmpPredicate Pred;
    auto m_V = m_CombineOr(L: m_Specific(V), R: m_PtrToInt(Op: m_Specific(V)));
    if (!match(V: I->getArgOperand(i: 0), P: m_c_ICmp(Pred, L: m_V, R: m_Value(V&: RHS))))
      continue;

    if (cmpExcludesZero(Pred, RHS) && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
      return true;
  }

  return false;
}
878
/// Refine \p Known (the known bits of \p V) given that `icmp Pred LHS, RHS`
/// is known to hold. Handles LHS patterns that involve V directly, through
/// ptrtoint, or via simple bitwise/shift/add forms, with a constant RHS
/// (plus the pointer-vs-null special case).
static void computeKnownBitsFromCmp(const Value *V, CmpInst::Predicate Pred,
                                    Value *LHS, Value *RHS, KnownBits &Known,
                                    const SimplifyQuery &Q) {
  if (RHS->getType()->isPointerTy()) {
    // Handle comparison of pointer to null explicitly, as it will not be
    // covered by the m_APInt() logic below.
    if (LHS == V && match(V: RHS, P: m_Zero())) {
      switch (Pred) {
      case ICmpInst::ICMP_EQ:
        Known.setAllZero();
        break;
      case ICmpInst::ICMP_SGE:
      case ICmpInst::ICMP_SGT:
        Known.makeNonNegative();
        break;
      case ICmpInst::ICMP_SLT:
        Known.makeNegative();
        break;
      default:
        break;
      }
    }
    return;
  }

  unsigned BitWidth = Known.getBitWidth();
  auto m_V =
      m_CombineOr(L: m_Specific(V), R: m_PtrToIntSameSize(DL: Q.DL, Op: m_Specific(V)));

  Value *Y;
  const APInt *Mask, *C;
  // Everything below requires a constant RHS.
  if (!match(V: RHS, P: m_APInt(Res&: C)))
    return;

  uint64_t ShAmt;
  switch (Pred) {
  case ICmpInst::ICMP_EQ:
    // assume(V = C)
    if (match(V: LHS, P: m_V)) {
      Known = Known.unionWith(RHS: KnownBits::makeConstant(C: *C));
      // assume(V & Mask = C)
    } else if (match(V: LHS, P: m_c_And(L: m_V, R: m_Value(V&: Y)))) {
      // For one bits in Mask, we can propagate bits from C to V.
      Known.One |= *C;
      if (match(V: Y, P: m_APInt(Res&: Mask)))
        Known.Zero |= ~*C & *Mask;
      // assume(V | Mask = C)
    } else if (match(V: LHS, P: m_c_Or(L: m_V, R: m_Value(V&: Y)))) {
      // For zero bits in Mask, we can propagate bits from C to V.
      Known.Zero |= ~*C;
      if (match(V: Y, P: m_APInt(Res&: Mask)))
        Known.One |= *C & ~*Mask;
      // assume(V << ShAmt = C)
    } else if (match(V: LHS, P: m_Shl(L: m_V, R: m_ConstantInt(V&: ShAmt))) &&
               ShAmt < BitWidth) {
      // For those bits in C that are known, we can propagate them to known
      // bits in V shifted to the right by ShAmt.
      KnownBits RHSKnown = KnownBits::makeConstant(C: *C);
      RHSKnown >>= ShAmt;
      Known = Known.unionWith(RHS: RHSKnown);
      // assume(V >> ShAmt = C)
    } else if (match(V: LHS, P: m_Shr(L: m_V, R: m_ConstantInt(V&: ShAmt))) &&
               ShAmt < BitWidth) {
      // For those bits in RHS that are known, we can propagate them to known
      // bits in V shifted to the right by C.
      KnownBits RHSKnown = KnownBits::makeConstant(C: *C);
      RHSKnown <<= ShAmt;
      Known = Known.unionWith(RHS: RHSKnown);
    }
    break;
  case ICmpInst::ICMP_NE: {
    // assume (V & B != 0) where B is a power of 2
    const APInt *BPow2;
    if (C->isZero() && match(V: LHS, P: m_And(L: m_V, R: m_Power2(V&: BPow2))))
      Known.One |= *BPow2;
    break;
  }
  default: {
    // Inequalities: derive a constant range for V (possibly compensating for
    // an add-like offset), plus special forms that bound the high bits.
    const APInt *Offset = nullptr;
    if (match(V: LHS, P: m_CombineOr(L: m_V, R: m_AddLike(L: m_V, R: m_APInt(Res&: Offset))))) {
      ConstantRange LHSRange = ConstantRange::makeAllowedICmpRegion(Pred, Other: *C);
      if (Offset)
        LHSRange = LHSRange.sub(Other: *Offset);
      Known = Known.unionWith(RHS: LHSRange.toKnownBits());
    }
    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
      // X & Y u> C -> X u> C && Y u> C
      // X nuw- Y u> C -> X u> C
      if (match(V: LHS, P: m_c_And(L: m_V, R: m_Value())) ||
          match(V: LHS, P: m_NUWSub(L: m_V, R: m_Value())))
        Known.One.setHighBits(
            (*C + (Pred == ICmpInst::ICMP_UGT)).countLeadingOnes());
    }
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
      // X | Y u< C -> X u< C && Y u< C
      // X nuw+ Y u< C -> X u< C && Y u< C
      if (match(V: LHS, P: m_c_Or(L: m_V, R: m_Value())) ||
          match(V: LHS, P: m_c_NUWAdd(L: m_V, R: m_Value()))) {
        Known.Zero.setHighBits(
            (*C - (Pred == ICmpInst::ICMP_ULT)).countLeadingZeros());
      }
    }
  } break;
  }
}
984
/// Refine \p Known for \p V using the condition \p Cmp, or its negation when
/// \p Invert is set.
static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
                                         KnownBits &Known,
                                         const SimplifyQuery &SQ, bool Invert) {
  ICmpInst::Predicate Pred =
      Invert ? Cmp->getInversePredicate() : Cmp->getPredicate();
  Value *LHS = Cmp->getOperand(i_nocapture: 0);
  Value *RHS = Cmp->getOperand(i_nocapture: 1);

  // Handle icmp pred (trunc V), C
  if (match(V: LHS, P: m_Trunc(Op: m_Specific(V)))) {
    // Compute known bits at the truncated width, then widen back to V's
    // width: a nuw trunc proves the dropped high bits were zero (zext);
    // otherwise they stay unknown (anyext).
    KnownBits DstKnown(LHS->getType()->getScalarSizeInBits());
    computeKnownBitsFromCmp(V: LHS, Pred, LHS, RHS, Known&: DstKnown, Q: SQ);
    if (cast<TruncInst>(Val: LHS)->hasNoUnsignedWrap())
      Known = Known.unionWith(RHS: DstKnown.zext(BitWidth: Known.getBitWidth()));
    else
      Known = Known.unionWith(RHS: DstKnown.anyext(BitWidth: Known.getBitWidth()));
    return;
  }

  computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, Q: SQ);
}
1006
/// Refine \p Known for \p V assuming \p Cond evaluates to true (or to false
/// when \p Invert is set). Recurses through logical and/or and `not`, and
/// handles icmp and trunc-to-i1 conditions.
static void computeKnownBitsFromCond(const Value *V, Value *Cond,
                                     KnownBits &Known, const SimplifyQuery &SQ,
                                     bool Invert, unsigned Depth) {
  Value *A, *B;
  if (Depth < MaxAnalysisRecursionDepth &&
      match(V: Cond, P: m_LogicalOp(L: m_Value(V&: A), R: m_Value(V&: B)))) {
    KnownBits Known2(Known.getBitWidth());
    KnownBits Known3(Known.getBitWidth());
    computeKnownBitsFromCond(V, Cond: A, Known&: Known2, SQ, Invert, Depth: Depth + 1);
    computeKnownBitsFromCond(V, Cond: B, Known&: Known3, SQ, Invert, Depth: Depth + 1);
    // A true "and" (or a false "or", when inverted) means both operands
    // hold, so facts from both sides combine (union); otherwise only facts
    // common to both sides are valid (intersection).
    if (Invert ? match(V: Cond, P: m_LogicalOr(L: m_Value(), R: m_Value()))
               : match(V: Cond, P: m_LogicalAnd(L: m_Value(), R: m_Value())))
      Known2 = Known2.unionWith(RHS: Known3);
    else
      Known2 = Known2.intersectWith(RHS: Known3);
    Known = Known.unionWith(RHS: Known2);
    return;
  }

  if (auto *Cmp = dyn_cast<ICmpInst>(Val: Cond)) {
    computeKnownBitsFromICmpCond(V, Cmp, Known, SQ, Invert);
    return;
  }

  // Condition is (trunc V): the low bit of V equals the condition's value.
  if (match(V: Cond, P: m_Trunc(Op: m_Specific(V)))) {
    KnownBits DstKnown(1);
    if (Invert) {
      DstKnown.setAllZero();
    } else {
      DstKnown.setAllOnes();
    }
    // With nuw, the truncated-away high bits are known zero; otherwise they
    // remain unknown.
    if (cast<TruncInst>(Val: Cond)->hasNoUnsignedWrap()) {
      Known = Known.unionWith(RHS: DstKnown.zext(BitWidth: Known.getBitWidth()));
      return;
    }
    Known = Known.unionWith(RHS: DstKnown.anyext(BitWidth: Known.getBitWidth()));
    return;
  }

  if (Depth < MaxAnalysisRecursionDepth && match(V: Cond, P: m_Not(V: m_Value(V&: A))))
    computeKnownBitsFromCond(V, Cond: A, Known, SQ, Invert: !Invert, Depth: Depth + 1);
}
1049
/// Refine \p Known for \p V using contextual facts: an injected condition
/// (Q.CC), dominating branch conditions (Q.DC/Q.DT), and llvm.assume calls
/// (Q.AC) that are valid at Q.CxtI.
void llvm::computeKnownBitsFromContext(const Value *V, KnownBits &Known,
                                       const SimplifyQuery &Q, unsigned Depth) {
  // Handle injected condition.
  if (Q.CC && Q.CC->AffectedValues.contains(Ptr: V))
    computeKnownBitsFromCond(V, Cond: Q.CC->Cond, Known, SQ: Q, Invert: Q.CC->Invert, Depth);

  // Everything below is context-sensitive; without a context instruction we
  // cannot use dominating conditions or assumptions.
  if (!Q.CxtI)
    return;

  if (Q.DC && Q.DT) {
    // Handle dominating conditions.
    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
      // If the true edge dominates the context, the condition holds as-is.
      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
      if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, Cond: BI->getCondition(), Known, SQ: Q,
                                 /*Invert*/ false, Depth);

      // If the false edge dominates the context, the negation holds.
      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
      if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
        computeKnownBitsFromCond(V, Cond: BI->getCondition(), Known, SQ: Q,
                                 /*Invert*/ true, Depth);
    }

    // Conflicting bits mean this point is unreachable; don't report
    // contradictory facts.
    if (Known.hasConflict())
      Known.resetAll();
  }

  if (!Q.AC)
    return;

  unsigned BitWidth = Known.getBitWidth();

  // Note that the patterns below need to be kept in sync with the code
  // in AssumptionCache::updateAffectedValues.

  for (AssumptionCache::ResultElem &Elem : Q.AC->assumptionsFor(V)) {
    if (!Elem.Assume)
      continue;

    AssumeInst *I = cast<AssumeInst>(Val&: Elem.Assume);
    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
           "Got assumption for the wrong function!");

    if (Elem.Index != AssumptionCache::ExprResultIdx) {
      // Operand-bundle assumption: only alignment bundles on pointers yield
      // known bits (the low log2(align) bits are zero).
      if (!V->getType()->isPointerTy())
        continue;
      if (RetainedKnowledge RK = getKnowledgeFromBundle(
              Assume&: *I, BOI: I->bundle_op_info_begin()[Elem.Index])) {
        // Allow AllowEphemerals in isValidAssumeForContext, as the CxtI might
        // be the producer of the pointer in the bundle. At the moment, align
        // assumptions aren't optimized away.
        if (RK.WasOn == V && RK.AttrKind == Attribute::Alignment &&
            isPowerOf2_64(Value: RK.ArgValue) &&
            isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT, /*AllowEphemerals*/ true))
          Known.Zero.setLowBits(Log2_64(Value: RK.ArgValue));
      }
      continue;
    }

    // Warning: This loop can end up being somewhat performance sensitive.
    // We're running this loop for once for each value queried resulting in a
    // runtime of ~O(#assumes * #values).

    Value *Arg = I->getArgOperand(i: 0);

    // assume(V) with V of type i1: V must be true.
    if (Arg == V && isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllOnes();
      return;
    }
    // assume(!V): V must be false.
    if (match(V: Arg, P: m_Not(V: m_Specific(V))) &&
        isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
      assert(BitWidth == 1 && "assume operand is not i1?");
      (void)BitWidth;
      Known.setAllZero();
      return;
    }
    // assume(trunc V): the low bit of V is set; with nuw, V must equal 1.
    auto *Trunc = dyn_cast<TruncInst>(Val: Arg);
    if (Trunc && Trunc->getOperand(i_nocapture: 0) == V &&
        isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT)) {
      if (Trunc->hasNoUnsignedWrap()) {
        Known = KnownBits::makeConstant(C: APInt(BitWidth, 1));
        return;
      }
      Known.One.setBit(0);
      return;
    }

    // The remaining tests are all recursive, so bail out if we hit the limit.
    if (Depth == MaxAnalysisRecursionDepth)
      continue;

    ICmpInst *Cmp = dyn_cast<ICmpInst>(Val: Arg);
    if (!Cmp)
      continue;

    if (!isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
      continue;

    computeKnownBitsFromICmpCond(V, Cmp, Known, SQ: Q, /*Invert=*/false);
  }

  // Conflicting assumption: Undefined behavior will occur on this execution
  // path.
  if (Known.hasConflict())
    Known.resetAll();
}
1158
1159/// Compute known bits from a shift operator, including those with a
1160/// non-constant shift amount. Known is the output of this function. Known2 is a
1161/// pre-allocated temporary with the same bit width as Known and on return
1162/// contains the known bit of the shift value source. KF is an
1163/// operator-specific function that, given the known-bits and a shift amount,
1164/// compute the implied known-bits of the shift operator's result respectively
1165/// for that shift amount. The results from calling KF are conservatively
1166/// combined for all permitted shift amounts.
1167static void computeKnownBitsFromShiftOperator(
1168 const Operator *I, const APInt &DemandedElts, KnownBits &Known,
1169 KnownBits &Known2, const SimplifyQuery &Q, unsigned Depth,
1170 function_ref<KnownBits(const KnownBits &, const KnownBits &, bool)> KF) {
1171 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1172 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1173 // To limit compile-time impact, only query isKnownNonZero() if we know at
1174 // least something about the shift amount.
1175 bool ShAmtNonZero =
1176 Known.isNonZero() ||
1177 (Known.getMaxValue().ult(RHS: Known.getBitWidth()) &&
1178 isKnownNonZero(V: I->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1));
1179 Known = KF(Known2, Known, ShAmtNonZero);
1180}
1181
/// Compute the known bits of `I = and/or/xor(LHS, RHS)` from the operands'
/// known bits, including idioms (and(x, -x), xor(x, x-1), op(x, x +/- odd))
/// that pin down extra bits beyond the plain bitwise combination.
static KnownBits
getKnownBitsFromAndXorOr(const Operator *I, const APInt &DemandedElts,
                         const KnownBits &KnownLHS, const KnownBits &KnownRHS,
                         const SimplifyQuery &Q, unsigned Depth) {
  unsigned BitWidth = KnownLHS.getBitWidth();
  KnownBits KnownOut(BitWidth);
  bool IsAnd = false;
  bool HasKnownOne = !KnownLHS.One.isZero() || !KnownRHS.One.isZero();
  Value *X = nullptr, *Y = nullptr;

  switch (I->getOpcode()) {
  case Instruction::And:
    KnownOut = KnownLHS & KnownRHS;
    IsAnd = true;
    // and(x, -x) is common idioms that will clear all but lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: instcombine often reassociates independent `and` which can hide
    // this pattern. Try to match and(x, and(-x, y)) / and(and(x, y), -x).
    if (HasKnownOne && match(V: I, P: m_c_And(L: m_Value(V&: X), R: m_Neg(V: m_Deferred(V: X))))) {
      // -(-x) == x so using whichever (LHS/RHS) gets us a better result.
      if (KnownLHS.countMaxTrailingZeros() <= KnownRHS.countMaxTrailingZeros())
        KnownOut = KnownLHS.blsi();
      else
        KnownOut = KnownRHS.blsi();
    }
    break;
  case Instruction::Or:
    KnownOut = KnownLHS | KnownRHS;
    break;
  case Instruction::Xor:
    KnownOut = KnownLHS ^ KnownRHS;
    // xor(x, x-1) is common idioms that will clear all but lowest set
    // bit. If we have a single known bit in x, we can clear all bits
    // above it.
    // TODO: xor(x, x-1) is often rewritting as xor(x, x-C) where C !=
    // -1 but for the purpose of demanded bits (xor(x, x-C) &
    // Demanded) == (xor(x, x-1) & Demanded). Extend the xor pattern
    // to use arbitrary C if xor(x, x-C) as the same as xor(x, x-1).
    if (HasKnownOne &&
        match(V: I, P: m_c_Xor(L: m_Value(V&: X), R: m_Add(L: m_Deferred(V: X), R: m_AllOnes())))) {
      const KnownBits &XBits = I->getOperand(i: 0) == X ? KnownLHS : KnownRHS;
      KnownOut = XBits.blsmsk();
    }
    break;
  default:
    llvm_unreachable("Invalid Op used in 'analyzeKnownBitsFromAndXorOr'");
  }

  // and(x, add (x, -1)) is a common idiom that always clears the low bit;
  // xor/or(x, add (x, -1)) is an idiom that will always set the low bit.
  // here we handle the more general case of adding any odd number by
  // matching the form and/xor/or(x, add(x, y)) where y is odd.
  // TODO: This could be generalized to clearing any bit set in y where the
  // following bit is known to be unset in y.
  if (!KnownOut.Zero[0] && !KnownOut.One[0] &&
      (match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_c_Add(L: m_Deferred(V: X), R: m_Value(V&: Y)))) ||
       match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_Sub(L: m_Deferred(V: X), R: m_Value(V&: Y)))) ||
       match(V: I, P: m_c_BinOp(L: m_Value(V&: X), R: m_Sub(L: m_Value(V&: Y), R: m_Deferred(V: X)))))) {
    KnownBits KnownY(BitWidth);
    computeKnownBits(V: Y, DemandedElts, Known&: KnownY, Q, Depth: Depth + 1);
    // Y odd (lowest bit known one): x and x+/-y differ in bit 0.
    if (KnownY.countMinTrailingOnes() > 0) {
      if (IsAnd)
        KnownOut.Zero.setBit(0);
      else
        KnownOut.One.setBit(0);
    }
  }
  return KnownOut;
}
1252
1253static KnownBits computeKnownBitsForHorizontalOperation(
1254 const Operator *I, const APInt &DemandedElts, const SimplifyQuery &Q,
1255 unsigned Depth,
1256 const function_ref<KnownBits(const KnownBits &, const KnownBits &)>
1257 KnownBitsFunc) {
1258 APInt DemandedEltsLHS, DemandedEltsRHS;
1259 getHorizDemandedEltsForFirstOperand(VectorBitWidth: Q.DL.getTypeSizeInBits(Ty: I->getType()),
1260 DemandedElts, DemandedLHS&: DemandedEltsLHS,
1261 DemandedRHS&: DemandedEltsRHS);
1262
1263 const auto ComputeForSingleOpFunc =
1264 [Depth, &Q, KnownBitsFunc](const Value *Op, APInt &DemandedEltsOp) {
1265 return KnownBitsFunc(
1266 computeKnownBits(V: Op, DemandedElts: DemandedEltsOp, Q, Depth: Depth + 1),
1267 computeKnownBits(V: Op, DemandedElts: DemandedEltsOp << 1, Q, Depth: Depth + 1));
1268 };
1269
1270 if (DemandedEltsRHS.isZero())
1271 return ComputeForSingleOpFunc(I->getOperand(i: 0), DemandedEltsLHS);
1272 if (DemandedEltsLHS.isZero())
1273 return ComputeForSingleOpFunc(I->getOperand(i: 1), DemandedEltsRHS);
1274
1275 return ComputeForSingleOpFunc(I->getOperand(i: 0), DemandedEltsLHS)
1276 .intersectWith(RHS: ComputeForSingleOpFunc(I->getOperand(i: 1), DemandedEltsRHS));
1277}
1278
1279// Public so this can be used in `SimplifyDemandedUseBits`.
1280KnownBits llvm::analyzeKnownBitsFromAndXorOr(const Operator *I,
1281 const KnownBits &KnownLHS,
1282 const KnownBits &KnownRHS,
1283 const SimplifyQuery &SQ,
1284 unsigned Depth) {
1285 auto *FVTy = dyn_cast<FixedVectorType>(Val: I->getType());
1286 APInt DemandedElts =
1287 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
1288
1289 return getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS, KnownRHS, Q: SQ,
1290 Depth);
1291}
1292
1293ConstantRange llvm::getVScaleRange(const Function *F, unsigned BitWidth) {
1294 Attribute Attr = F->getFnAttribute(Kind: Attribute::VScaleRange);
1295 // Without vscale_range, we only know that vscale is non-zero.
1296 if (!Attr.isValid())
1297 return ConstantRange(APInt(BitWidth, 1), APInt::getZero(numBits: BitWidth));
1298
1299 unsigned AttrMin = Attr.getVScaleRangeMin();
1300 // Minimum is larger than vscale width, result is always poison.
1301 if ((unsigned)llvm::bit_width(Value: AttrMin) > BitWidth)
1302 return ConstantRange::getEmpty(BitWidth);
1303
1304 APInt Min(BitWidth, AttrMin);
1305 std::optional<unsigned> AttrMax = Attr.getVScaleRangeMax();
1306 if (!AttrMax || (unsigned)llvm::bit_width(Value: *AttrMax) > BitWidth)
1307 return ConstantRange(Min, APInt::getZero(numBits: BitWidth));
1308
1309 return ConstantRange(Min, APInt(BitWidth, *AttrMax) + 1);
1310}
1311
/// Refine \p Known (the known bits of select arm \p Arm) with what the
/// select condition \p Cond implies about Arm when this arm is chosen
/// (condition true, or false when \p Invert is set).
void llvm::adjustKnownBitsForSelectArm(KnownBits &Known, Value *Cond,
                                       Value *Arm, bool Invert,
                                       const SimplifyQuery &Q, unsigned Depth) {
  // If we have a constant arm, we are done.
  if (Known.isConstant())
    return;

  // See what condition implies about the bits of the select arm.
  KnownBits CondRes(Known.getBitWidth());
  computeKnownBitsFromCond(V: Arm, Cond, Known&: CondRes, SQ: Q, Invert, Depth: Depth + 1);
  // If we don't get any information from the condition, no reason to
  // proceed.
  if (CondRes.isUnknown())
    return;

  // We can have conflict if the condition is dead. I.e if we have
  // (x | 64) < 32 ? (x | 64) : y
  // we will have conflict at bit 6 from the condition/the `or`.
  // In that case just return. Its not particularly important
  // what we do, as this select is going to be simplified soon.
  CondRes = CondRes.unionWith(RHS: Known);
  if (CondRes.hasConflict())
    return;

  // Finally make sure the information we found is valid. This is relatively
  // expensive so it's left for the very end.
  // (Refining with the condition is only sound if Arm cannot be undef.)
  if (!isGuaranteedNotToBeUndef(V: Arm, AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT, Depth: Depth + 1))
    return;

  // Finally, we know we get information from the condition and its valid,
  // so return it.
  Known = CondRes;
}
1345
1346// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow).
1347// Returns the input and lower/upper bounds.
1348static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
1349 const APInt *&CLow, const APInt *&CHigh) {
1350 assert(isa<Operator>(Select) &&
1351 cast<Operator>(Select)->getOpcode() == Instruction::Select &&
1352 "Input should be a Select!");
1353
1354 const Value *LHS = nullptr, *RHS = nullptr;
1355 SelectPatternFlavor SPF = matchSelectPattern(V: Select, LHS, RHS).Flavor;
1356 if (SPF != SPF_SMAX && SPF != SPF_SMIN)
1357 return false;
1358
1359 if (!match(V: RHS, P: m_APInt(Res&: CLow)))
1360 return false;
1361
1362 const Value *LHS2 = nullptr, *RHS2 = nullptr;
1363 SelectPatternFlavor SPF2 = matchSelectPattern(V: LHS, LHS&: LHS2, RHS&: RHS2).Flavor;
1364 if (getInverseMinMaxFlavor(SPF) != SPF2)
1365 return false;
1366
1367 if (!match(V: RHS2, P: m_APInt(Res&: CHigh)))
1368 return false;
1369
1370 if (SPF == SPF_SMIN)
1371 std::swap(a&: CLow, b&: CHigh);
1372
1373 In = LHS2;
1374 return CLow->sle(RHS: *CHigh);
1375}
1376
1377static bool isSignedMinMaxIntrinsicClamp(const IntrinsicInst *II,
1378 const APInt *&CLow,
1379 const APInt *&CHigh) {
1380 assert((II->getIntrinsicID() == Intrinsic::smin ||
1381 II->getIntrinsicID() == Intrinsic::smax) &&
1382 "Must be smin/smax");
1383
1384 Intrinsic::ID InverseID = getInverseMinMaxIntrinsic(MinMaxID: II->getIntrinsicID());
1385 auto *InnerII = dyn_cast<IntrinsicInst>(Val: II->getArgOperand(i: 0));
1386 if (!InnerII || InnerII->getIntrinsicID() != InverseID ||
1387 !match(V: II->getArgOperand(i: 1), P: m_APInt(Res&: CLow)) ||
1388 !match(V: InnerII->getArgOperand(i: 1), P: m_APInt(Res&: CHigh)))
1389 return false;
1390
1391 if (II->getIntrinsicID() == Intrinsic::smin)
1392 std::swap(a&: CLow, b&: CHigh);
1393 return CLow->sle(RHS: *CHigh);
1394}
1395
1396static void unionWithMinMaxIntrinsicClamp(const IntrinsicInst *II,
1397 KnownBits &Known) {
1398 const APInt *CLow, *CHigh;
1399 if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
1400 Known = Known.unionWith(
1401 RHS: ConstantRange::getNonEmpty(Lower: *CLow, Upper: *CHigh + 1).toKnownBits());
1402}
1403
1404static void computeKnownBitsFromOperator(const Operator *I,
1405 const APInt &DemandedElts,
1406 KnownBits &Known,
1407 const SimplifyQuery &Q,
1408 unsigned Depth) {
1409 unsigned BitWidth = Known.getBitWidth();
1410
1411 KnownBits Known2(BitWidth);
1412 switch (I->getOpcode()) {
1413 default: break;
1414 case Instruction::Load:
1415 if (MDNode *MD =
1416 Q.IIQ.getMetadata(I: cast<LoadInst>(Val: I), KindID: LLVMContext::MD_range))
1417 computeKnownBitsFromRangeMetadata(Ranges: *MD, Known);
1418 break;
1419 case Instruction::And:
1420 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1421 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1422
1423 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1424 break;
1425 case Instruction::Or:
1426 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1427 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1428
1429 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1430 break;
1431 case Instruction::Xor:
1432 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
1433 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1434
1435 Known = getKnownBitsFromAndXorOr(I, DemandedElts, KnownLHS: Known2, KnownRHS: Known, Q, Depth);
1436 break;
1437 case Instruction::Mul: {
1438 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1439 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1440 computeKnownBitsMul(Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1441 DemandedElts, Known, Known2, Q, Depth);
1442 break;
1443 }
1444 case Instruction::UDiv: {
1445 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1446 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1447 Known =
1448 KnownBits::udiv(LHS: Known, RHS: Known2, Exact: Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)));
1449 break;
1450 }
1451 case Instruction::SDiv: {
1452 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1453 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1454 Known =
1455 KnownBits::sdiv(LHS: Known, RHS: Known2, Exact: Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)));
1456 break;
1457 }
1458 case Instruction::Select: {
1459 auto ComputeForArm = [&](Value *Arm, bool Invert) {
1460 KnownBits Res(Known.getBitWidth());
1461 computeKnownBits(V: Arm, DemandedElts, Known&: Res, Q, Depth: Depth + 1);
1462 adjustKnownBitsForSelectArm(Known&: Res, Cond: I->getOperand(i: 0), Arm, Invert, Q, Depth);
1463 return Res;
1464 };
1465 // Only known if known in both the LHS and RHS.
1466 Known =
1467 ComputeForArm(I->getOperand(i: 1), /*Invert=*/false)
1468 .intersectWith(RHS: ComputeForArm(I->getOperand(i: 2), /*Invert=*/true));
1469 break;
1470 }
1471 case Instruction::FPTrunc:
1472 case Instruction::FPExt:
1473 case Instruction::FPToUI:
1474 case Instruction::FPToSI:
1475 case Instruction::SIToFP:
1476 case Instruction::UIToFP:
1477 break; // Can't work with floating point.
1478 case Instruction::PtrToInt:
1479 case Instruction::PtrToAddr:
1480 case Instruction::IntToPtr:
1481 // Fall through and handle them the same as zext/trunc.
1482 [[fallthrough]];
1483 case Instruction::ZExt:
1484 case Instruction::Trunc: {
1485 Type *SrcTy = I->getOperand(i: 0)->getType();
1486
1487 unsigned SrcBitWidth;
1488 // Note that we handle pointer operands here because of inttoptr/ptrtoint
1489 // which fall through here.
1490 Type *ScalarTy = SrcTy->getScalarType();
1491 SrcBitWidth = ScalarTy->isPointerTy() ?
1492 Q.DL.getPointerTypeSizeInBits(ScalarTy) :
1493 Q.DL.getTypeSizeInBits(Ty: ScalarTy);
1494
1495 assert(SrcBitWidth && "SrcBitWidth can't be zero");
1496 Known = Known.anyextOrTrunc(BitWidth: SrcBitWidth);
1497 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1498 if (auto *Inst = dyn_cast<PossiblyNonNegInst>(Val: I);
1499 Inst && Inst->hasNonNeg() && !Known.isNegative())
1500 Known.makeNonNegative();
1501 Known = Known.zextOrTrunc(BitWidth);
1502 break;
1503 }
1504 case Instruction::BitCast: {
1505 Type *SrcTy = I->getOperand(i: 0)->getType();
1506 if (SrcTy->isIntOrPtrTy() &&
1507 // TODO: For now, not handling conversions like:
1508 // (bitcast i64 %x to <2 x i32>)
1509 !I->getType()->isVectorTy()) {
1510 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
1511 break;
1512 }
1513
1514 const Value *V;
1515 // Handle bitcast from floating point to integer.
1516 if (match(V: I, P: m_ElementWiseBitCast(Op: m_Value(V))) &&
1517 V->getType()->isFPOrFPVectorTy()) {
1518 Type *FPType = V->getType()->getScalarType();
1519 KnownFPClass Result =
1520 computeKnownFPClass(V, DemandedElts, InterestedClasses: fcAllFlags, SQ: Q, Depth: Depth + 1);
1521 FPClassTest FPClasses = Result.KnownFPClasses;
1522
1523 // TODO: Treat it as zero/poison if the use of I is unreachable.
1524 if (FPClasses == fcNone)
1525 break;
1526
1527 if (Result.isKnownNever(Mask: fcNormal | fcSubnormal | fcNan)) {
1528 Known.setAllConflict();
1529
1530 if (FPClasses & fcInf)
1531 Known = Known.intersectWith(RHS: KnownBits::makeConstant(
1532 C: APFloat::getInf(Sem: FPType->getFltSemantics()).bitcastToAPInt()));
1533
1534 if (FPClasses & fcZero)
1535 Known = Known.intersectWith(RHS: KnownBits::makeConstant(
1536 C: APInt::getZero(numBits: FPType->getScalarSizeInBits())));
1537
1538 Known.Zero.clearSignBit();
1539 Known.One.clearSignBit();
1540 }
1541
1542 if (Result.SignBit) {
1543 if (*Result.SignBit)
1544 Known.makeNegative();
1545 else
1546 Known.makeNonNegative();
1547 }
1548
1549 break;
1550 }
1551
1552 // Handle cast from vector integer type to scalar or vector integer.
1553 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcTy);
1554 if (!SrcVecTy || !SrcVecTy->getElementType()->isIntegerTy() ||
1555 !I->getType()->isIntOrIntVectorTy() ||
1556 isa<ScalableVectorType>(Val: I->getType()))
1557 break;
1558
1559 unsigned NumElts = DemandedElts.getBitWidth();
1560 bool IsLE = Q.DL.isLittleEndian();
1561 // Look through a cast from narrow vector elements to wider type.
1562 // Examples: v4i32 -> v2i64, v3i8 -> v24
1563 unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
1564 if (BitWidth % SubBitWidth == 0) {
1565 // Known bits are automatically intersected across demanded elements of a
1566 // vector. So for example, if a bit is computed as known zero, it must be
1567 // zero across all demanded elements of the vector.
1568 //
1569 // For this bitcast, each demanded element of the output is sub-divided
1570 // across a set of smaller vector elements in the source vector. To get
1571 // the known bits for an entire element of the output, compute the known
1572 // bits for each sub-element sequentially. This is done by shifting the
1573 // one-set-bit demanded elements parameter across the sub-elements for
1574 // consecutive calls to computeKnownBits. We are using the demanded
1575 // elements parameter as a mask operator.
1576 //
1577 // The known bits of each sub-element are then inserted into place
1578 // (dependent on endian) to form the full result of known bits.
1579 unsigned SubScale = BitWidth / SubBitWidth;
1580 APInt SubDemandedElts = APInt::getZero(numBits: NumElts * SubScale);
1581 for (unsigned i = 0; i != NumElts; ++i) {
1582 if (DemandedElts[i])
1583 SubDemandedElts.setBit(i * SubScale);
1584 }
1585
1586 KnownBits KnownSrc(SubBitWidth);
1587 for (unsigned i = 0; i != SubScale; ++i) {
1588 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: SubDemandedElts.shl(shiftAmt: i), Known&: KnownSrc, Q,
1589 Depth: Depth + 1);
1590 unsigned ShiftElt = IsLE ? i : SubScale - 1 - i;
1591 Known.insertBits(SubBits: KnownSrc, BitPosition: ShiftElt * SubBitWidth);
1592 }
1593 }
1594 // Look through a cast from wider vector elements to narrow type.
1595 // Examples: v2i64 -> v4i32
1596 if (SubBitWidth % BitWidth == 0) {
1597 unsigned SubScale = SubBitWidth / BitWidth;
1598 KnownBits KnownSrc(SubBitWidth);
1599 APInt SubDemandedElts =
1600 APIntOps::ScaleBitMask(A: DemandedElts, NewBitWidth: NumElts / SubScale);
1601 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: SubDemandedElts, Known&: KnownSrc, Q,
1602 Depth: Depth + 1);
1603
1604 Known.setAllConflict();
1605 for (unsigned i = 0; i != NumElts; ++i) {
1606 if (DemandedElts[i]) {
1607 unsigned Shifts = IsLE ? i : NumElts - 1 - i;
1608 unsigned Offset = (Shifts % SubScale) * BitWidth;
1609 Known = Known.intersectWith(RHS: KnownSrc.extractBits(NumBits: BitWidth, BitPosition: Offset));
1610 if (Known.isUnknown())
1611 break;
1612 }
1613 }
1614 }
1615 break;
1616 }
1617 case Instruction::SExt: {
1618 // Compute the bits in the result that are not present in the input.
1619 unsigned SrcBitWidth = I->getOperand(i: 0)->getType()->getScalarSizeInBits();
1620
1621 Known = Known.trunc(BitWidth: SrcBitWidth);
1622 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1623 // If the sign bit of the input is known set or clear, then we know the
1624 // top bits of the result.
1625 Known = Known.sext(BitWidth);
1626 break;
1627 }
1628 case Instruction::Shl: {
1629 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1630 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1631 auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1632 bool ShAmtNonZero) {
1633 return KnownBits::shl(LHS: KnownVal, RHS: KnownAmt, NUW, NSW, ShAmtNonZero);
1634 };
1635 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1636 KF);
// Trailing zeros of a left-shifted constant never decrease.
1638 const APInt *C;
1639 if (match(V: I->getOperand(i: 0), P: m_APInt(Res&: C)))
1640 Known.Zero.setLowBits(C->countr_zero());
1641 break;
1642 }
1643 case Instruction::LShr: {
1644 bool Exact = Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I));
1645 auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1646 bool ShAmtNonZero) {
1647 return KnownBits::lshr(LHS: KnownVal, RHS: KnownAmt, ShAmtNonZero, Exact);
1648 };
1649 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1650 KF);
// Leading zeros of a right-shifted constant never decrease.
1652 const APInt *C;
1653 if (match(V: I->getOperand(i: 0), P: m_APInt(Res&: C)))
1654 Known.Zero.setHighBits(C->countl_zero());
1655 break;
1656 }
1657 case Instruction::AShr: {
1658 bool Exact = Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I));
1659 auto KF = [Exact](const KnownBits &KnownVal, const KnownBits &KnownAmt,
1660 bool ShAmtNonZero) {
1661 return KnownBits::ashr(LHS: KnownVal, RHS: KnownAmt, ShAmtNonZero, Exact);
1662 };
1663 computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Q, Depth,
1664 KF);
1665 break;
1666 }
1667 case Instruction::Sub: {
1668 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1669 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1670 computeKnownBitsAddSub(Add: false, Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1671 DemandedElts, KnownOut&: Known, Known2, Q, Depth);
1672 break;
1673 }
1674 case Instruction::Add: {
1675 bool NSW = Q.IIQ.hasNoSignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1676 bool NUW = Q.IIQ.hasNoUnsignedWrap(Op: cast<OverflowingBinaryOperator>(Val: I));
1677 computeKnownBitsAddSub(Add: true, Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1), NSW, NUW,
1678 DemandedElts, KnownOut&: Known, Known2, Q, Depth);
1679 break;
1680 }
1681 case Instruction::SRem:
1682 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1683 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1684 Known = KnownBits::srem(LHS: Known, RHS: Known2);
1685 break;
1686
1687 case Instruction::URem:
1688 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
1689 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
1690 Known = KnownBits::urem(LHS: Known, RHS: Known2);
1691 break;
1692 case Instruction::Alloca:
1693 Known.Zero.setLowBits(Log2(A: cast<AllocaInst>(Val: I)->getAlign()));
1694 break;
1695 case Instruction::GetElementPtr: {
1696 // Analyze all of the subscripts of this getelementptr instruction
1697 // to determine if we can prove known low zero bits.
1698 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
1699 // Accumulate the constant indices in a separate variable
1700 // to minimize the number of calls to computeForAddSub.
1701 unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Ty: I->getType());
1702 APInt AccConstIndices(IndexWidth, 0);
1703
1704 auto AddIndexToKnown = [&](KnownBits IndexBits) {
1705 if (IndexWidth == BitWidth) {
1706 // Note that inbounds does *not* guarantee nsw for the addition, as only
1707 // the offset is signed, while the base address is unsigned.
1708 Known = KnownBits::add(LHS: Known, RHS: IndexBits);
1709 } else {
1710 // If the index width is smaller than the pointer width, only add the
1711 // value to the low bits.
1712 assert(IndexWidth < BitWidth &&
1713 "Index width can't be larger than pointer width");
1714 Known.insertBits(SubBits: KnownBits::add(LHS: Known.trunc(BitWidth: IndexWidth), RHS: IndexBits), BitPosition: 0);
1715 }
1716 };
1717
1718 gep_type_iterator GTI = gep_type_begin(GEP: I);
1719 for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
1720 // TrailZ can only become smaller, short-circuit if we hit zero.
1721 if (Known.isUnknown())
1722 break;
1723
1724 Value *Index = I->getOperand(i);
1725
1726 // Handle case when index is zero.
1727 Constant *CIndex = dyn_cast<Constant>(Val: Index);
1728 if (CIndex && CIndex->isNullValue())
1729 continue;
1730
1731 if (StructType *STy = GTI.getStructTypeOrNull()) {
1732 // Handle struct member offset arithmetic.
1733
1734 assert(CIndex &&
1735 "Access to structure field must be known at compile time");
1736
1737 if (CIndex->getType()->isVectorTy())
1738 Index = CIndex->getSplatValue();
1739
1740 unsigned Idx = cast<ConstantInt>(Val: Index)->getZExtValue();
1741 const StructLayout *SL = Q.DL.getStructLayout(Ty: STy);
1742 uint64_t Offset = SL->getElementOffset(Idx);
1743 AccConstIndices += Offset;
1744 continue;
1745 }
1746
1747 // Handle array index arithmetic.
1748 Type *IndexedTy = GTI.getIndexedType();
1749 if (!IndexedTy->isSized()) {
1750 Known.resetAll();
1751 break;
1752 }
1753
1754 TypeSize Stride = GTI.getSequentialElementStride(DL: Q.DL);
1755 uint64_t StrideInBytes = Stride.getKnownMinValue();
1756 if (!Stride.isScalable()) {
1757 // Fast path for constant offset.
1758 if (auto *CI = dyn_cast<ConstantInt>(Val: Index)) {
1759 AccConstIndices +=
1760 CI->getValue().sextOrTrunc(width: IndexWidth) * StrideInBytes;
1761 continue;
1762 }
1763 }
1764
1765 KnownBits IndexBits =
1766 computeKnownBits(V: Index, Q, Depth: Depth + 1).sextOrTrunc(BitWidth: IndexWidth);
1767 KnownBits ScalingFactor(IndexWidth);
1768 // Multiply by current sizeof type.
1769 // &A[i] == A + i * sizeof(*A[i]).
1770 if (Stride.isScalable()) {
1771 // For scalable types the only thing we know about sizeof is
1772 // that this is a multiple of the minimum size.
1773 ScalingFactor.Zero.setLowBits(llvm::countr_zero(Val: StrideInBytes));
1774 } else {
1775 ScalingFactor =
1776 KnownBits::makeConstant(C: APInt(IndexWidth, StrideInBytes));
1777 }
1778 AddIndexToKnown(KnownBits::mul(LHS: IndexBits, RHS: ScalingFactor));
1779 }
1780 if (!Known.isUnknown() && !AccConstIndices.isZero())
1781 AddIndexToKnown(KnownBits::makeConstant(C: AccConstIndices));
1782 break;
1783 }
1784 case Instruction::PHI: {
1785 const PHINode *P = cast<PHINode>(Val: I);
1786 BinaryOperator *BO = nullptr;
1787 Value *R = nullptr, *L = nullptr;
1788 if (matchSimpleRecurrence(P, BO, Start&: R, Step&: L)) {
1789 // Handle the case of a simple two-predecessor recurrence PHI.
1790 // There's a lot more that could theoretically be done here, but
1791 // this is sufficient to catch some interesting cases.
1792 unsigned Opcode = BO->getOpcode();
1793
1794 switch (Opcode) {
1795 // If this is a shift recurrence, we know the bits being shifted in. We
1796 // can combine that with information about the start value of the
1797 // recurrence to conclude facts about the result. If this is a udiv
1798 // recurrence, we know that the result can never exceed either the
1799 // numerator or the start value, whichever is greater.
1800 case Instruction::LShr:
1801 case Instruction::AShr:
1802 case Instruction::Shl:
1803 case Instruction::UDiv:
1804 if (BO->getOperand(i_nocapture: 0) != I)
1805 break;
1806 [[fallthrough]];
1807
1808 // For a urem recurrence, the result can never exceed the start value. The
1809 // phi could either be the numerator or the denominator.
1810 case Instruction::URem: {
1811 // We have matched a recurrence of the form:
1812 // %iv = [R, %entry], [%iv.next, %backedge]
1813 // %iv.next = shift_op %iv, L
1814
1815 // Recurse with the phi context to avoid concern about whether facts
1816 // inferred hold at original context instruction. TODO: It may be
1817 // correct to use the original context. IF warranted, explore and
1818 // add sufficient tests to cover.
1819 SimplifyQuery RecQ = Q.getWithoutCondContext();
1820 RecQ.CxtI = P;
1821 computeKnownBits(V: R, DemandedElts, Known&: Known2, Q: RecQ, Depth: Depth + 1);
1822 switch (Opcode) {
1823 case Instruction::Shl:
// A shl recurrence will only increase the trailing zeros
1825 Known.Zero.setLowBits(Known2.countMinTrailingZeros());
1826 break;
1827 case Instruction::LShr:
1828 case Instruction::UDiv:
1829 case Instruction::URem:
1830 // lshr, udiv, and urem recurrences will preserve the leading zeros of
1831 // the start value.
1832 Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1833 break;
1834 case Instruction::AShr:
1835 // An ashr recurrence will extend the initial sign bit
1836 Known.Zero.setHighBits(Known2.countMinLeadingZeros());
1837 Known.One.setHighBits(Known2.countMinLeadingOnes());
1838 break;
1839 }
1840 break;
1841 }
1842
1843 // Check for operations that have the property that if
1844 // both their operands have low zero bits, the result
1845 // will have low zero bits.
1846 case Instruction::Add:
1847 case Instruction::Sub:
1848 case Instruction::And:
1849 case Instruction::Or:
1850 case Instruction::Mul: {
1851 // Change the context instruction to the "edge" that flows into the
1852 // phi. This is important because that is where the value is actually
1853 // "evaluated" even though it is used later somewhere else. (see also
1854 // D69571).
1855 SimplifyQuery RecQ = Q.getWithoutCondContext();
1856
1857 unsigned OpNum = P->getOperand(i_nocapture: 0) == R ? 0 : 1;
1858 Instruction *RInst = P->getIncomingBlock(i: OpNum)->getTerminator();
1859 Instruction *LInst = P->getIncomingBlock(i: 1 - OpNum)->getTerminator();
1860
1861 // Ok, we have a PHI of the form L op= R. Check for low
1862 // zero bits.
1863 RecQ.CxtI = RInst;
1864 computeKnownBits(V: R, DemandedElts, Known&: Known2, Q: RecQ, Depth: Depth + 1);
1865
1866 // We need to take the minimum number of known bits
1867 KnownBits Known3(BitWidth);
1868 RecQ.CxtI = LInst;
1869 computeKnownBits(V: L, DemandedElts, Known&: Known3, Q: RecQ, Depth: Depth + 1);
1870
1871 Known.Zero.setLowBits(std::min(a: Known2.countMinTrailingZeros(),
1872 b: Known3.countMinTrailingZeros()));
1873
1874 auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(Val: BO);
1875 if (!OverflowOp || !Q.IIQ.hasNoSignedWrap(Op: OverflowOp))
1876 break;
1877
1878 switch (Opcode) {
1879 // If initial value of recurrence is nonnegative, and we are adding
1880 // a nonnegative number with nsw, the result can only be nonnegative
1881 // or poison value regardless of the number of times we execute the
1882 // add in phi recurrence. If initial value is negative and we are
1883 // adding a negative number with nsw, the result can only be
1884 // negative or poison value. Similar arguments apply to sub and mul.
1885 //
1886 // (add non-negative, non-negative) --> non-negative
1887 // (add negative, negative) --> negative
1888 case Instruction::Add: {
1889 if (Known2.isNonNegative() && Known3.isNonNegative())
1890 Known.makeNonNegative();
1891 else if (Known2.isNegative() && Known3.isNegative())
1892 Known.makeNegative();
1893 break;
1894 }
1895
1896 // (sub nsw non-negative, negative) --> non-negative
1897 // (sub nsw negative, non-negative) --> negative
1898 case Instruction::Sub: {
1899 if (BO->getOperand(i_nocapture: 0) != I)
1900 break;
1901 if (Known2.isNonNegative() && Known3.isNegative())
1902 Known.makeNonNegative();
1903 else if (Known2.isNegative() && Known3.isNonNegative())
1904 Known.makeNegative();
1905 break;
1906 }
1907
1908 // (mul nsw non-negative, non-negative) --> non-negative
1909 case Instruction::Mul:
1910 if (Known2.isNonNegative() && Known3.isNonNegative())
1911 Known.makeNonNegative();
1912 break;
1913
1914 default:
1915 break;
1916 }
1917 break;
1918 }
1919
1920 default:
1921 break;
1922 }
1923 }
1924
1925 // Unreachable blocks may have zero-operand PHI nodes.
1926 if (P->getNumIncomingValues() == 0)
1927 break;
1928
1929 // Otherwise take the unions of the known bit sets of the operands,
1930 // taking conservative care to avoid excessive recursion.
1931 if (Depth < MaxAnalysisRecursionDepth - 1 && Known.isUnknown()) {
// Skip if every incoming value refers back to the phi itself.
1933 if (isa_and_nonnull<UndefValue>(Val: P->hasConstantValue()))
1934 break;
1935
1936 Known.setAllConflict();
1937 for (const Use &U : P->operands()) {
1938 Value *IncValue;
1939 const PHINode *CxtPhi;
1940 Instruction *CxtI;
1941 breakSelfRecursivePHI(U: &U, PHI: P, ValOut&: IncValue, CtxIOut&: CxtI, PhiOut: &CxtPhi);
1942 // Skip direct self references.
1943 if (IncValue == P)
1944 continue;
1945
1946 // Change the context instruction to the "edge" that flows into the
1947 // phi. This is important because that is where the value is actually
1948 // "evaluated" even though it is used later somewhere else. (see also
1949 // D69571).
1950 SimplifyQuery RecQ = Q.getWithoutCondContext().getWithInstruction(I: CxtI);
1951
1952 Known2 = KnownBits(BitWidth);
1953
1954 // Recurse, but cap the recursion to one level, because we don't
1955 // want to waste time spinning around in loops.
1956 // TODO: See if we can base recursion limiter on number of incoming phi
1957 // edges so we don't overly clamp analysis.
1958 computeKnownBits(V: IncValue, DemandedElts, Known&: Known2, Q: RecQ,
1959 Depth: MaxAnalysisRecursionDepth - 1);
1960
1961 // See if we can further use a conditional branch into the phi
1962 // to help us determine the range of the value.
1963 if (!Known2.isConstant()) {
1964 CmpPredicate Pred;
1965 const APInt *RHSC;
1966 BasicBlock *TrueSucc, *FalseSucc;
1967 // TODO: Use RHS Value and compute range from its known bits.
1968 if (match(V: RecQ.CxtI,
1969 P: m_Br(C: m_c_ICmp(Pred, L: m_Specific(V: IncValue), R: m_APInt(Res&: RHSC)),
1970 T: m_BasicBlock(V&: TrueSucc), F: m_BasicBlock(V&: FalseSucc)))) {
1971 // Check for cases of duplicate successors.
1972 if ((TrueSucc == CxtPhi->getParent()) !=
1973 (FalseSucc == CxtPhi->getParent())) {
1974 // If we're using the false successor, invert the predicate.
1975 if (FalseSucc == CxtPhi->getParent())
1976 Pred = CmpInst::getInversePredicate(pred: Pred);
1977 // Get the knownbits implied by the incoming phi condition.
1978 auto CR = ConstantRange::makeExactICmpRegion(Pred, Other: *RHSC);
1979 KnownBits KnownUnion = Known2.unionWith(RHS: CR.toKnownBits());
// We can have conflicts here if we are analyzing deadcode (it's
// impossible for us to reach this BB based on the icmp).
1982 if (KnownUnion.hasConflict()) {
1983 // No reason to continue analyzing in a known dead region, so
1984 // just resetAll and break. This will cause us to also exit the
1985 // outer loop.
1986 Known.resetAll();
1987 break;
1988 }
1989 Known2 = KnownUnion;
1990 }
1991 }
1992 }
1993
1994 Known = Known.intersectWith(RHS: Known2);
1995 // If all bits have been ruled out, there's no need to check
1996 // more operands.
1997 if (Known.isUnknown())
1998 break;
1999 }
2000 }
2001 break;
2002 }
2003 case Instruction::Call:
2004 case Instruction::Invoke: {
2005 // If range metadata is attached to this call, set known bits from that,
2006 // and then intersect with known bits based on other properties of the
2007 // function.
2008 if (MDNode *MD =
2009 Q.IIQ.getMetadata(I: cast<Instruction>(Val: I), KindID: LLVMContext::MD_range))
2010 computeKnownBitsFromRangeMetadata(Ranges: *MD, Known);
2011
2012 const auto *CB = cast<CallBase>(Val: I);
2013
2014 if (std::optional<ConstantRange> Range = CB->getRange())
2015 Known = Known.unionWith(RHS: Range->toKnownBits());
2016
2017 if (const Value *RV = CB->getReturnedArgOperand()) {
2018 if (RV->getType() == I->getType()) {
2019 computeKnownBits(V: RV, Known&: Known2, Q, Depth: Depth + 1);
2020 Known = Known.unionWith(RHS: Known2);
2021 // If the function doesn't return properly for all input values
2022 // (e.g. unreachable exits) then there might be conflicts between the
2023 // argument value and the range metadata. Simply discard the known bits
2024 // in case of conflicts.
2025 if (Known.hasConflict())
2026 Known.resetAll();
2027 }
2028 }
2029 if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I)) {
2030 switch (II->getIntrinsicID()) {
2031 default:
2032 break;
2033 case Intrinsic::abs: {
2034 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2035 bool IntMinIsPoison = match(V: II->getArgOperand(i: 1), P: m_One());
2036 Known = Known.unionWith(RHS: Known2.abs(IntMinIsPoison));
2037 break;
2038 }
2039 case Intrinsic::bitreverse:
2040 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2041 Known = Known.unionWith(RHS: Known2.reverseBits());
2042 break;
2043 case Intrinsic::bswap:
2044 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2045 Known = Known.unionWith(RHS: Known2.byteSwap());
2046 break;
2047 case Intrinsic::ctlz: {
2048 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2049 // If we have a known 1, its position is our upper bound.
2050 unsigned PossibleLZ = Known2.countMaxLeadingZeros();
2051 // If this call is poison for 0 input, the result will be less than 2^n.
2052 if (II->getArgOperand(i: 1) == ConstantInt::getTrue(Context&: II->getContext()))
2053 PossibleLZ = std::min(a: PossibleLZ, b: BitWidth - 1);
2054 unsigned LowBits = llvm::bit_width(Value: PossibleLZ);
2055 Known.Zero.setBitsFrom(LowBits);
2056 break;
2057 }
2058 case Intrinsic::cttz: {
2059 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2060 // If we have a known 1, its position is our upper bound.
2061 unsigned PossibleTZ = Known2.countMaxTrailingZeros();
2062 // If this call is poison for 0 input, the result will be less than 2^n.
2063 if (II->getArgOperand(i: 1) == ConstantInt::getTrue(Context&: II->getContext()))
2064 PossibleTZ = std::min(a: PossibleTZ, b: BitWidth - 1);
2065 unsigned LowBits = llvm::bit_width(Value: PossibleTZ);
2066 Known.Zero.setBitsFrom(LowBits);
2067 break;
2068 }
2069 case Intrinsic::ctpop: {
2070 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2071 // We can bound the space the count needs. Also, bits known to be zero
2072 // can't contribute to the population.
2073 unsigned BitsPossiblySet = Known2.countMaxPopulation();
2074 unsigned LowBits = llvm::bit_width(Value: BitsPossiblySet);
2075 Known.Zero.setBitsFrom(LowBits);
2076 // TODO: we could bound KnownOne using the lower bound on the number
2077 // of bits which might be set provided by popcnt KnownOne2.
2078 break;
2079 }
2080 case Intrinsic::fshr:
2081 case Intrinsic::fshl: {
2082 const APInt *SA;
2083 if (!match(V: I->getOperand(i: 2), P: m_APInt(Res&: SA)))
2084 break;
2085
2086 // Normalize to funnel shift left.
2087 uint64_t ShiftAmt = SA->urem(RHS: BitWidth);
2088 if (II->getIntrinsicID() == Intrinsic::fshr)
2089 ShiftAmt = BitWidth - ShiftAmt;
2090
2091 KnownBits Known3(BitWidth);
2092 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2093 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known3, Q, Depth: Depth + 1);
2094
2095 Known2 <<= ShiftAmt;
2096 Known3 >>= BitWidth - ShiftAmt;
2097 Known = Known2.unionWith(RHS: Known3);
2098 break;
2099 }
2100 case Intrinsic::clmul:
2101 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2102 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2103 Known = KnownBits::clmul(LHS: Known, RHS: Known2);
2104 break;
2105 case Intrinsic::uadd_sat:
2106 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2107 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2108 Known = KnownBits::uadd_sat(LHS: Known, RHS: Known2);
2109 break;
2110 case Intrinsic::usub_sat:
2111 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2112 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2113 Known = KnownBits::usub_sat(LHS: Known, RHS: Known2);
2114 break;
2115 case Intrinsic::sadd_sat:
2116 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2117 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2118 Known = KnownBits::sadd_sat(LHS: Known, RHS: Known2);
2119 break;
2120 case Intrinsic::ssub_sat:
2121 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2122 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2123 Known = KnownBits::ssub_sat(LHS: Known, RHS: Known2);
2124 break;
2125 // Vec reverse preserves bits from input vec.
2126 case Intrinsic::vector_reverse:
2127 computeKnownBits(V: I->getOperand(i: 0), DemandedElts: DemandedElts.reverseBits(), Known, Q,
2128 Depth: Depth + 1);
2129 break;
2130 // for min/max/and/or reduce, any bit common to each element in the
2131 // input vec is set in the output.
2132 case Intrinsic::vector_reduce_and:
2133 case Intrinsic::vector_reduce_or:
2134 case Intrinsic::vector_reduce_umax:
2135 case Intrinsic::vector_reduce_umin:
2136 case Intrinsic::vector_reduce_smax:
2137 case Intrinsic::vector_reduce_smin:
2138 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2139 break;
2140 case Intrinsic::vector_reduce_xor: {
2141 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2142 // The zeros common to all vecs are zero in the output.
2143 // If the number of elements is odd, then the common ones remain. If the
2144 // number of elements is even, then the common ones becomes zeros.
2145 auto *VecTy = cast<VectorType>(Val: I->getOperand(i: 0)->getType());
2146 // Even, so the ones become zeros.
2147 bool EvenCnt = VecTy->getElementCount().isKnownEven();
2148 if (EvenCnt)
2149 Known.Zero |= Known.One;
2150 // Maybe even element count so need to clear ones.
2151 if (VecTy->isScalableTy() || EvenCnt)
2152 Known.One.clearAllBits();
2153 break;
2154 }
2155 case Intrinsic::vector_reduce_add: {
2156 auto *VecTy = dyn_cast<FixedVectorType>(Val: I->getOperand(i: 0)->getType());
2157 if (!VecTy)
2158 break;
2159 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2160 Known = Known.reduceAdd(NumElts: VecTy->getNumElements());
2161 break;
2162 }
2163 case Intrinsic::umin:
2164 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2165 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2166 Known = KnownBits::umin(LHS: Known, RHS: Known2);
2167 break;
2168 case Intrinsic::umax:
2169 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2170 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2171 Known = KnownBits::umax(LHS: Known, RHS: Known2);
2172 break;
2173 case Intrinsic::smin:
2174 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2175 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2176 Known = KnownBits::smin(LHS: Known, RHS: Known2);
2177 unionWithMinMaxIntrinsicClamp(II, Known);
2178 break;
2179 case Intrinsic::smax:
2180 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2181 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2182 Known = KnownBits::smax(LHS: Known, RHS: Known2);
2183 unionWithMinMaxIntrinsicClamp(II, Known);
2184 break;
2185 case Intrinsic::ptrmask: {
2186 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2187
2188 const Value *Mask = I->getOperand(i: 1);
2189 Known2 = KnownBits(Mask->getType()->getScalarSizeInBits());
2190 computeKnownBits(V: Mask, DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2191 // TODO: 1-extend would be more precise.
2192 Known &= Known2.anyextOrTrunc(BitWidth);
2193 break;
2194 }
2195 case Intrinsic::x86_sse2_pmulh_w:
2196 case Intrinsic::x86_avx2_pmulh_w:
2197 case Intrinsic::x86_avx512_pmulh_w_512:
2198 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2199 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2200 Known = KnownBits::mulhs(LHS: Known, RHS: Known2);
2201 break;
2202 case Intrinsic::x86_sse2_pmulhu_w:
2203 case Intrinsic::x86_avx2_pmulhu_w:
2204 case Intrinsic::x86_avx512_pmulhu_w_512:
2205 computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);
2206 computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Known&: Known2, Q, Depth: Depth + 1);
2207 Known = KnownBits::mulhu(LHS: Known, RHS: Known2);
2208 break;
2209 case Intrinsic::x86_sse42_crc32_64_64:
2210 Known.Zero.setBitsFrom(32);
2211 break;
2212 case Intrinsic::x86_ssse3_phadd_d_128:
2213 case Intrinsic::x86_ssse3_phadd_w_128:
2214 case Intrinsic::x86_avx2_phadd_d:
2215 case Intrinsic::x86_avx2_phadd_w: {
2216 Known = computeKnownBitsForHorizontalOperation(
2217 I, DemandedElts, Q, Depth,
2218 KnownBitsFunc: [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
2219 return KnownBits::add(LHS: KnownLHS, RHS: KnownRHS);
2220 });
2221 break;
2222 }
2223 case Intrinsic::x86_ssse3_phadd_sw_128:
2224 case Intrinsic::x86_avx2_phadd_sw: {
2225 Known = computeKnownBitsForHorizontalOperation(
2226 I, DemandedElts, Q, Depth, KnownBitsFunc: KnownBits::sadd_sat);
2227 break;
2228 }
2229 case Intrinsic::x86_ssse3_phsub_d_128:
2230 case Intrinsic::x86_ssse3_phsub_w_128:
2231 case Intrinsic::x86_avx2_phsub_d:
2232 case Intrinsic::x86_avx2_phsub_w: {
2233 Known = computeKnownBitsForHorizontalOperation(
2234 I, DemandedElts, Q, Depth,
2235 KnownBitsFunc: [](const KnownBits &KnownLHS, const KnownBits &KnownRHS) {
2236 return KnownBits::sub(LHS: KnownLHS, RHS: KnownRHS);
2237 });
2238 break;
2239 }
2240 case Intrinsic::x86_ssse3_phsub_sw_128:
2241 case Intrinsic::x86_avx2_phsub_sw: {
2242 Known = computeKnownBitsForHorizontalOperation(
2243 I, DemandedElts, Q, Depth, KnownBitsFunc: KnownBits::ssub_sat);
2244 break;
2245 }
2246 case Intrinsic::riscv_vsetvli:
2247 case Intrinsic::riscv_vsetvlimax: {
2248 bool HasAVL = II->getIntrinsicID() == Intrinsic::riscv_vsetvli;
2249 const ConstantRange Range = getVScaleRange(F: II->getFunction(), BitWidth);
2250 uint64_t SEW = RISCVVType::decodeVSEW(
2251 VSEW: cast<ConstantInt>(Val: II->getArgOperand(i: HasAVL))->getZExtValue());
2252 RISCVVType::VLMUL VLMUL = static_cast<RISCVVType::VLMUL>(
2253 cast<ConstantInt>(Val: II->getArgOperand(i: 1 + HasAVL))->getZExtValue());
2254 uint64_t MaxVLEN =
2255 Range.getUnsignedMax().getZExtValue() * RISCV::RVVBitsPerBlock;
2256 uint64_t MaxVL = MaxVLEN / RISCVVType::getSEWLMULRatio(SEW, VLMul: VLMUL);
2257
2258 // Result of vsetvli must be not larger than AVL.
2259 if (HasAVL)
2260 if (auto *CI = dyn_cast<ConstantInt>(Val: II->getArgOperand(i: 0)))
2261 MaxVL = std::min(a: MaxVL, b: CI->getZExtValue());
2262
2263 unsigned KnownZeroFirstBit = Log2_32(Value: MaxVL) + 1;
2264 if (BitWidth > KnownZeroFirstBit)
2265 Known.Zero.setBitsFrom(KnownZeroFirstBit);
2266 break;
2267 }
2268 case Intrinsic::vscale: {
2269 if (!II->getParent() || !II->getFunction())
2270 break;
2271
2272 Known = getVScaleRange(F: II->getFunction(), BitWidth).toKnownBits();
2273 break;
2274 }
2275 }
2276 }
2277 break;
2278 }
2279 case Instruction::ShuffleVector: {
2280 if (auto *Splat = getSplatValue(V: I)) {
2281 computeKnownBits(V: Splat, Known, Q, Depth: Depth + 1);
2282 break;
2283 }
2284
2285 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: I);
2286 // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
2287 if (!Shuf) {
2288 Known.resetAll();
2289 return;
2290 }
2291 // For undef elements, we don't know anything about the common state of
2292 // the shuffle result.
2293 APInt DemandedLHS, DemandedRHS;
2294 if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
2295 Known.resetAll();
2296 return;
2297 }
2298 Known.setAllConflict();
2299 if (!!DemandedLHS) {
2300 const Value *LHS = Shuf->getOperand(i_nocapture: 0);
2301 computeKnownBits(V: LHS, DemandedElts: DemandedLHS, Known, Q, Depth: Depth + 1);
2302 // If we don't know any bits, early out.
2303 if (Known.isUnknown())
2304 break;
2305 }
2306 if (!!DemandedRHS) {
2307 const Value *RHS = Shuf->getOperand(i_nocapture: 1);
2308 computeKnownBits(V: RHS, DemandedElts: DemandedRHS, Known&: Known2, Q, Depth: Depth + 1);
2309 Known = Known.intersectWith(RHS: Known2);
2310 }
2311 break;
2312 }
2313 case Instruction::InsertElement: {
2314 if (isa<ScalableVectorType>(Val: I->getType())) {
2315 Known.resetAll();
2316 return;
2317 }
2318 const Value *Vec = I->getOperand(i: 0);
2319 const Value *Elt = I->getOperand(i: 1);
2320 auto *CIdx = dyn_cast<ConstantInt>(Val: I->getOperand(i: 2));
2321 unsigned NumElts = DemandedElts.getBitWidth();
2322 APInt DemandedVecElts = DemandedElts;
2323 bool NeedsElt = true;
// If we know the index we are inserting to, clear it from the Vec check.
2325 if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
2326 DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
2327 NeedsElt = DemandedElts[CIdx->getZExtValue()];
2328 }
2329
2330 Known.setAllConflict();
2331 if (NeedsElt) {
2332 computeKnownBits(V: Elt, Known, Q, Depth: Depth + 1);
2333 // If we don't know any bits, early out.
2334 if (Known.isUnknown())
2335 break;
2336 }
2337
2338 if (!DemandedVecElts.isZero()) {
2339 computeKnownBits(V: Vec, DemandedElts: DemandedVecElts, Known&: Known2, Q, Depth: Depth + 1);
2340 Known = Known.intersectWith(RHS: Known2);
2341 }
2342 break;
2343 }
2344 case Instruction::ExtractElement: {
2345 // Look through extract element. If the index is non-constant or
2346 // out-of-range demand all elements, otherwise just the extracted element.
2347 const Value *Vec = I->getOperand(i: 0);
2348 const Value *Idx = I->getOperand(i: 1);
2349 auto *CIdx = dyn_cast<ConstantInt>(Val: Idx);
2350 if (isa<ScalableVectorType>(Val: Vec->getType())) {
2351 // FIXME: there's probably *something* we can do with scalable vectors
2352 Known.resetAll();
2353 break;
2354 }
2355 unsigned NumElts = cast<FixedVectorType>(Val: Vec->getType())->getNumElements();
2356 APInt DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
2357 if (CIdx && CIdx->getValue().ult(RHS: NumElts))
2358 DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
2359 computeKnownBits(V: Vec, DemandedElts: DemandedVecElts, Known, Q, Depth: Depth + 1);
2360 break;
2361 }
2362 case Instruction::ExtractValue:
2363 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I->getOperand(i: 0))) {
2364 const ExtractValueInst *EVI = cast<ExtractValueInst>(Val: I);
2365 if (EVI->getNumIndices() != 1) break;
2366 if (EVI->getIndices()[0] == 0) {
2367 switch (II->getIntrinsicID()) {
2368 default: break;
2369 case Intrinsic::uadd_with_overflow:
2370 case Intrinsic::sadd_with_overflow:
2371 computeKnownBitsAddSub(
2372 Add: true, Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), /*NSW=*/false,
2373 /* NUW=*/false, DemandedElts, KnownOut&: Known, Known2, Q, Depth);
2374 break;
2375 case Intrinsic::usub_with_overflow:
2376 case Intrinsic::ssub_with_overflow:
2377 computeKnownBitsAddSub(
2378 Add: false, Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), /*NSW=*/false,
2379 /* NUW=*/false, DemandedElts, KnownOut&: Known, Known2, Q, Depth);
2380 break;
2381 case Intrinsic::umul_with_overflow:
2382 case Intrinsic::smul_with_overflow:
2383 computeKnownBitsMul(Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1), NSW: false,
2384 NUW: false, DemandedElts, Known, Known2, Q, Depth);
2385 break;
2386 }
2387 }
2388 }
2389 break;
2390 case Instruction::Freeze:
2391 if (isGuaranteedNotToBePoison(V: I->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
2392 Depth: Depth + 1))
2393 computeKnownBits(V: I->getOperand(i: 0), Known, Q, Depth: Depth + 1);
2394 break;
2395 }
2396}
2397
2398/// Determine which bits of V are known to be either zero or one and return
2399/// them.
2400KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
2401 const SimplifyQuery &Q, unsigned Depth) {
2402 KnownBits Known(getBitWidth(Ty: V->getType(), DL: Q.DL));
2403 ::computeKnownBits(V, DemandedElts, Known, Q, Depth);
2404 return Known;
2405}
2406
2407/// Determine which bits of V are known to be either zero or one and return
2408/// them.
2409KnownBits llvm::computeKnownBits(const Value *V, const SimplifyQuery &Q,
2410 unsigned Depth) {
2411 KnownBits Known(getBitWidth(Ty: V->getType(), DL: Q.DL));
2412 computeKnownBits(V, Known, Q, Depth);
2413 return Known;
2414}
2415
/// Determine which bits of V are known to be either zero or one and return
/// them in the Known bit set.
///
/// NOTE: we cannot consider 'undef' to be "IsZero" here. The problem is that
/// we cannot optimize based on the assumption that it is zero without changing
/// it to be an explicit zero. If we don't change it to zero, other code could
/// be optimized based on the contradictory assumption that it is non-zero.
/// Because instcombine aggressively folds operations with undef args anyway,
/// this won't lose us code quality.
///
/// This function is defined on values with integer type, values with pointer
/// type, and vectors of integers. In the case
/// where V is a vector, known zero, and known one values are the
/// same width as the vector element, and the bit is set only if it is true
/// for all of the demanded elements in the vector specified by DemandedElts.
void computeKnownBits(const Value *V, const APInt &DemandedElts,
                      KnownBits &Known, const SimplifyQuery &Q,
                      unsigned Depth) {
  // An all-zero demanded-element mask means no lane matters; report
  // everything unknown rather than tripping the asserts below.
  if (!DemandedElts) {
    // No demanded elts, better to assume we don't know anything.
    Known.resetAll();
    return;
  }

  assert(V && "No Value?");
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

#ifndef NDEBUG
  Type *Ty = V->getType();
  unsigned BitWidth = Known.getBitWidth();

  assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
         "Not integer or pointer type!");

  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
    assert(
        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
        "DemandedElt width should equal the fixed vector number of elements");
  } else {
    assert(DemandedElts == APInt(1, 1) &&
           "DemandedElt width should be 1 for scalars or scalable vectors");
  }

  Type *ScalarTy = Ty->getScalarType();
  if (ScalarTy->isPointerTy()) {
    assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
           "V and Known should have same BitWidth");
  } else {
    assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
           "V and Known should have same BitWidth");
  }
#endif

  const APInt *C;
  if (match(V, P: m_APInt(Res&: C))) {
    // We know all of the bits for a scalar constant or a splat vector constant!
    Known = KnownBits::makeConstant(C: *C);
    return;
  }
  // Null and aggregate-zero are all-zeros.
  if (isa<ConstantPointerNull>(Val: V) || isa<ConstantAggregateZero>(Val: V)) {
    Known.setAllZero();
    return;
  }
  // Handle a constant vector by taking the intersection of the known bits of
  // each element.
  if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(Val: V)) {
    assert(!isa<ScalableVectorType>(V->getType()));
    // We know that CDV must be a vector of integers. Take the intersection of
    // each element.
    Known.setAllConflict();
    for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
      // Undemanded lanes do not constrain the result.
      if (!DemandedElts[i])
        continue;
      APInt Elt = CDV->getElementAsAPInt(i);
      Known.Zero &= ~Elt;
      Known.One &= Elt;
    }
    if (Known.hasConflict())
      Known.resetAll();
    return;
  }

  if (const auto *CV = dyn_cast<ConstantVector>(Val: V)) {
    assert(!isa<ScalableVectorType>(V->getType()));
    // We know that CV must be a vector of integers. Take the intersection of
    // each element.
    Known.setAllConflict();
    for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
      if (!DemandedElts[i])
        continue;
      Constant *Element = CV->getAggregateElement(Elt: i);
      // Poison lanes can be chosen freely, so they impose no constraint.
      if (isa<PoisonValue>(Val: Element))
        continue;
      auto *ElementCI = dyn_cast_or_null<ConstantInt>(Val: Element);
      if (!ElementCI) {
        Known.resetAll();
        return;
      }
      const APInt &Elt = ElementCI->getValue();
      Known.Zero &= ~Elt;
      Known.One &= Elt;
    }
    if (Known.hasConflict())
      Known.resetAll();
    return;
  }

  // Start out not knowing anything.
  Known.resetAll();

  // We can't imply anything about undefs.
  if (isa<UndefValue>(Val: V))
    return;

  // There's no point in looking through other users of ConstantData for
  // assumptions. Confirm that we've handled them all.
  assert(!isa<ConstantData>(V) && "Unhandled constant data!");

  // An argument's range attribute, when present, seeds the answer.
  if (const auto *A = dyn_cast<Argument>(Val: V))
    if (std::optional<ConstantRange> Range = A->getRange())
      Known = Range->toKnownBits();

  // All recursive calls that increase depth must come after this.
  if (Depth == MaxAnalysisRecursionDepth)
    return;

  // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has
  // the bits of its aliasee.
  if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(Val: V)) {
    if (!GA->isInterposable())
      computeKnownBits(V: GA->getAliasee(), Known, Q, Depth: Depth + 1);
    return;
  }

  if (const Operator *I = dyn_cast<Operator>(Val: V))
    computeKnownBitsFromOperator(I, DemandedElts, Known, Q, Depth);
  else if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: V)) {
    if (std::optional<ConstantRange> CR = GV->getAbsoluteSymbolRange())
      Known = CR->toKnownBits();
  }

  // Aligned pointers have trailing zeros - refine Known.Zero set
  if (isa<PointerType>(Val: V->getType())) {
    Align Alignment = V->getPointerAlignment(DL: Q.DL);
    Known.Zero.setLowBits(Log2(A: Alignment));
  }

  // computeKnownBitsFromContext strictly refines Known.
  // Therefore, we run them after computeKnownBitsFromOperator.

  // Check whether we can determine known bits from context such as assumes.
  computeKnownBitsFromContext(V, Known, Q, Depth);
}
2570
/// Try to detect a recurrence that the value of the induction variable is
/// always a power of two (or zero).
///
/// \p Q is taken by mutable reference: the context instruction is retargeted
/// to the block feeding each value examined, so callers should pass a copy.
static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
                                   SimplifyQuery &Q, unsigned Depth) {
  BinaryOperator *BO = nullptr;
  Value *Start = nullptr, *Step = nullptr;
  if (!matchSimpleRecurrence(P: PN, BO, Start, Step))
    return false;

  // Initial value must be a power of two.
  for (const Use &U : PN->operands()) {
    if (U.get() == Start) {
      // Initial value comes from a different BB, need to adjust context
      // instruction for analysis.
      Q.CxtI = PN->getIncomingBlock(U)->getTerminator();
      if (!isKnownToBeAPowerOfTwo(V: Start, OrZero, Q, Depth))
        return false;
    }
  }

  // Except for Mul, the induction variable must be on the left side of the
  // increment expression, otherwise its value can be arbitrary.
  if (BO->getOpcode() != Instruction::Mul && BO->getOperand(i_nocapture: 1) != Step)
    return false;

  Q.CxtI = BO->getParent()->getTerminator();
  switch (BO->getOpcode()) {
  case Instruction::Mul:
    // Power of two is closed under multiplication.
    return (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: BO) ||
            Q.IIQ.hasNoSignedWrap(Op: BO)) &&
           isKnownToBeAPowerOfTwo(V: Step, OrZero, Q, Depth);
  case Instruction::SDiv:
    // Start value must not be signmask for signed division, so simply being a
    // power of two is not sufficient, and it has to be a constant.
    if (!match(V: Start, P: m_Power2()) || match(V: Start, P: m_SignMask()))
      return false;
    [[fallthrough]];
  case Instruction::UDiv:
    // Divisor must be a power of two.
    // If OrZero is false, cannot guarantee induction variable is non-zero after
    // division, same for Shr, unless it is exact division.
    return (OrZero || Q.IIQ.isExact(Op: BO)) &&
           isKnownToBeAPowerOfTwo(V: Step, OrZero: false, Q, Depth);
  case Instruction::Shl:
    // Shl preserves the single set bit as long as it doesn't wrap out, which
    // the nuw/nsw flags guarantee.
    return OrZero || Q.IIQ.hasNoUnsignedWrap(Op: BO) || Q.IIQ.hasNoSignedWrap(Op: BO);
  case Instruction::AShr:
    // As with SDiv, a signmask start would behave differently under an
    // arithmetic shift, so require a constant non-signmask power of two.
    if (!match(V: Start, P: m_Power2()) || match(V: Start, P: m_SignMask()))
      return false;
    [[fallthrough]];
  case Instruction::LShr:
    return OrZero || Q.IIQ.isExact(Op: BO);
  default:
    return false;
  }
}
2627
2628/// Return true if we can infer that \p V is known to be a power of 2 from
2629/// dominating condition \p Cond (e.g., ctpop(V) == 1).
2630static bool isImpliedToBeAPowerOfTwoFromCond(const Value *V, bool OrZero,
2631 const Value *Cond,
2632 bool CondIsTrue) {
2633 CmpPredicate Pred;
2634 const APInt *RHSC;
2635 if (!match(V: Cond, P: m_ICmp(Pred, L: m_Intrinsic<Intrinsic::ctpop>(Op0: m_Specific(V)),
2636 R: m_APInt(Res&: RHSC))))
2637 return false;
2638 if (!CondIsTrue)
2639 Pred = ICmpInst::getInversePredicate(pred: Pred);
2640 // ctpop(V) u< 2
2641 if (OrZero && Pred == ICmpInst::ICMP_ULT && *RHSC == 2)
2642 return true;
2643 // ctpop(V) == 1
2644 return Pred == ICmpInst::ICMP_EQ && *RHSC == 1;
2645}
2646
/// Return true if the given value is known to have exactly one
/// bit set when defined. For vectors return true if every element is known to
/// be a power of two when defined. Supports values with integer or pointer
/// types and vectors of integers.
///
/// When \p OrZero is set, a value known to be zero is also accepted.
bool llvm::isKnownToBeAPowerOfTwo(const Value *V, bool OrZero,
                                  const SimplifyQuery &Q, unsigned Depth) {
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

  if (isa<Constant>(Val: V))
    return OrZero ? match(V, P: m_Power2OrZero()) : match(V, P: m_Power2());

  // i1 is by definition a power of 2 or zero.
  if (OrZero && V->getType()->getScalarSizeInBits() == 1)
    return true;

  // Try to infer from assumptions.
  if (Q.AC && Q.CxtI) {
    for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
      if (!AssumeVH)
        continue;
      CallInst *I = cast<CallInst>(Val&: AssumeVH);
      if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond: I->getArgOperand(i: 0),
                                           /*CondIsTrue=*/true) &&
          isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
        return true;
    }
  }

  // Handle dominating conditions.
  if (Q.DC && Q.CxtI && Q.DT) {
    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
      Value *Cond = BI->getCondition();

      // True edge: the condition holds on this path.
      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
      if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
                                           /*CondIsTrue=*/true) &&
          Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
        return true;

      // False edge: the negated condition holds on this path.
      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
      if (isImpliedToBeAPowerOfTwoFromCond(V, OrZero, Cond,
                                           /*CondIsTrue=*/false) &&
          Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
        return true;
    }
  }

  auto *I = dyn_cast<Instruction>(Val: V);
  if (!I)
    return false;

  if (Q.CxtI && match(V, P: m_VScale())) {
    const Function *F = Q.CxtI->getFunction();
    // The vscale_range indicates vscale is a power-of-two.
    return F->hasFnAttribute(Kind: Attribute::VScaleRange);
  }

  // 1 << X is clearly a power of two if the one is not shifted off the end. If
  // it is shifted off the end then the result is undefined.
  if (match(V: I, P: m_Shl(L: m_One(), R: m_Value())))
    return true;

  // (signmask) >>l X is clearly a power of two if the one is not shifted off
  // the bottom. If it is shifted off the bottom then the result is undefined.
  if (match(V: I, P: m_LShr(L: m_SignMask(), R: m_Value())))
    return true;

  // The remaining tests are all recursive, so bail out if we hit the limit.
  if (Depth++ == MaxAnalysisRecursionDepth)
    return false;

  switch (I->getOpcode()) {
  case Instruction::ZExt:
    // Zero-extension adds only zero bits; it preserves the property exactly.
    return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
  case Instruction::Trunc:
    // Truncation may drop the single set bit, leaving zero, so this is only
    // valid in OrZero mode.
    return OrZero && isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
  case Instruction::Shl:
    if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: I) || Q.IIQ.hasNoSignedWrap(Op: I))
      return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
    return false;
  case Instruction::LShr:
    if (OrZero || Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)))
      return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
    return false;
  case Instruction::UDiv:
    // An exact udiv of a power of two divides out low zero bits only.
    if (Q.IIQ.isExact(Op: cast<BinaryOperator>(Val: I)))
      return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth);
    return false;
  case Instruction::Mul:
    return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth) &&
           isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth) &&
           (OrZero || isKnownNonZero(V: I, Q, Depth));
  case Instruction::And:
    // A power of two and'd with anything is a power of two or zero.
    if (OrZero &&
        (isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), /*OrZero*/ true, Q, Depth) ||
         isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), /*OrZero*/ true, Q, Depth)))
      return true;
    // X & (-X) is always a power of two or zero.
    if (match(V: I->getOperand(i: 0), P: m_Neg(V: m_Specific(V: I->getOperand(i: 1)))) ||
        match(V: I->getOperand(i: 1), P: m_Neg(V: m_Specific(V: I->getOperand(i: 0)))))
      return OrZero || isKnownNonZero(V: I->getOperand(i: 0), Q, Depth);
    return false;
  case Instruction::Add: {
    // Adding a power-of-two or zero to the same power-of-two or zero yields
    // either the original power-of-two, a larger power-of-two or zero.
    const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(Val: V);
    if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: VOBO) ||
        Q.IIQ.hasNoSignedWrap(Op: VOBO)) {
      // (X & P2) + P2, with P2 a power of two (or zero), can only be P2 or
      // 2*P2 since (X & P2) is either 0 or P2.
      if (match(V: I->getOperand(i: 0),
                P: m_c_And(L: m_Specific(V: I->getOperand(i: 1)), R: m_Value())) &&
          isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth))
        return true;
      if (match(V: I->getOperand(i: 1),
                P: m_c_And(L: m_Specific(V: I->getOperand(i: 0)), R: m_Value())) &&
          isKnownToBeAPowerOfTwo(V: I->getOperand(i: 0), OrZero, Q, Depth))
        return true;

      unsigned BitWidth = V->getType()->getScalarSizeInBits();
      KnownBits LHSBits(BitWidth);
      computeKnownBits(V: I->getOperand(i: 0), Known&: LHSBits, Q, Depth);

      KnownBits RHSBits(BitWidth);
      computeKnownBits(V: I->getOperand(i: 1), Known&: RHSBits, Q, Depth);
      // If i8 V is a power of two or zero:
      // ZeroBits: 1 1 1 0 1 1 1 1
      // ~ZeroBits: 0 0 0 1 0 0 0 0
      if ((~(LHSBits.Zero & RHSBits.Zero)).isPowerOf2())
        // If OrZero isn't set, we cannot give back a zero result.
        // Make sure either the LHS or RHS has a bit set.
        if (OrZero || RHSBits.One.getBoolValue() || LHSBits.One.getBoolValue())
          return true;
    }

    // LShr(UINT_MAX, Y) + 1 is a power of two (if add is nuw) or zero.
    if (OrZero || Q.IIQ.hasNoUnsignedWrap(Op: VOBO))
      if (match(V: I, P: m_Add(L: m_LShr(L: m_AllOnes(), R: m_Value()), R: m_One())))
        return true;
    return false;
  }
  case Instruction::Select:
    // A select yields one of its two arms, so both must have the property.
    return isKnownToBeAPowerOfTwo(V: I->getOperand(i: 1), OrZero, Q, Depth) &&
           isKnownToBeAPowerOfTwo(V: I->getOperand(i: 2), OrZero, Q, Depth);
  case Instruction::PHI: {
    // A PHI node is power of two if all incoming values are power of two, or if
    // it is an induction variable where in each step its value is a power of
    // two.
    auto *PN = cast<PHINode>(Val: I);
    SimplifyQuery RecQ = Q.getWithoutCondContext();

    // Check if it is an induction variable and always power of two.
    if (isPowerOfTwoRecurrence(PN, OrZero, Q&: RecQ, Depth))
      return true;

    // Recursively check all incoming values. Limit recursion to 2 levels, so
    // that search complexity is limited to number of operands^2.
    unsigned NewDepth = std::max(a: Depth, b: MaxAnalysisRecursionDepth - 1);
    return llvm::all_of(Range: PN->operands(), P: [&](const Use &U) {
      // Value is power of 2 if it is coming from PHI node itself by induction.
      if (U.get() == PN)
        return true;

      // Change the context instruction to the incoming block where it is
      // evaluated.
      RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
      return isKnownToBeAPowerOfTwo(V: U.get(), OrZero, Q: RecQ, Depth: NewDepth);
    });
  }
  case Instruction::Invoke:
  case Instruction::Call: {
    if (auto *II = dyn_cast<IntrinsicInst>(Val: I)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::umax:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::smin:
        // min/max return one of their operands, so both must qualify.
        return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 1), OrZero, Q, Depth) &&
               isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
      // bswap/bitreverse just move around bits, but don't change any 1s/0s
      // thus dont change pow2/non-pow2 status.
      case Intrinsic::bitreverse:
      case Intrinsic::bswap:
        return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
      case Intrinsic::fshr:
      case Intrinsic::fshl:
        // If Op0 == Op1, this is a rotate. is_pow2(rotate(x, y)) == is_pow2(x)
        if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1))
          return isKnownToBeAPowerOfTwo(V: II->getArgOperand(i: 0), OrZero, Q, Depth);
        break;
      default:
        break;
      }
    }
    return false;
  }
  default:
    return false;
  }
}
2846
/// Test whether a GEP's result is known to be non-null.
///
/// Uses properties inherent in a GEP to try to determine whether it is known
/// to be non-null.
///
/// Currently this routine does not support vector GEPs.
///
/// A false return means the analysis could not prove non-nullness, not that
/// the result may actually be null.
static bool isGEPKnownNonNull(const GEPOperator *GEP, const SimplifyQuery &Q,
                              unsigned Depth) {
  const Function *F = nullptr;
  if (const Instruction *I = dyn_cast<Instruction>(Val: GEP))
    F = I->getFunction();

  // If the gep is nuw or inbounds with invalid null pointer, then the GEP
  // may be null iff the base pointer is null and the offset is zero.
  if (!GEP->hasNoUnsignedWrap() &&
      !(GEP->isInBounds() &&
        !NullPointerIsDefined(F, AS: GEP->getPointerAddressSpace())))
    return false;

  // FIXME: Support vector-GEPs.
  assert(GEP->getType()->isPointerTy() && "We only support plain pointer GEP");

  // If the base pointer is non-null, we cannot walk to a null address with an
  // inbounds GEP in address space zero.
  if (isKnownNonZero(V: GEP->getPointerOperand(), Q, Depth))
    return true;

  // Walk the GEP operands and see if any operand introduces a non-zero offset.
  // If so, then the GEP cannot produce a null pointer, as doing so would
  // inherently violate the inbounds contract within address space zero.
  for (gep_type_iterator GTI = gep_type_begin(GEP), GTE = gep_type_end(GEP);
       GTI != GTE; ++GTI) {
    // Struct types are easy -- they must always be indexed by a constant.
    if (StructType *STy = GTI.getStructTypeOrNull()) {
      ConstantInt *OpC = cast<ConstantInt>(Val: GTI.getOperand());
      unsigned ElementIdx = OpC->getZExtValue();
      const StructLayout *SL = Q.DL.getStructLayout(Ty: STy);
      uint64_t ElementOffset = SL->getElementOffset(Idx: ElementIdx);
      if (ElementOffset > 0)
        return true;
      continue;
    }

    // If we have a zero-sized type, the index doesn't matter. Keep looping.
    if (GTI.getSequentialElementStride(DL: Q.DL).isZero())
      continue;

    // Fast path the constant operand case both for efficiency and so we don't
    // increment Depth when just zipping down an all-constant GEP.
    if (ConstantInt *OpC = dyn_cast<ConstantInt>(Val: GTI.getOperand())) {
      if (!OpC->isZero())
        return true;
      continue;
    }

    // We post-increment Depth here because while isKnownNonZero increments it
    // as well, when we pop back up that increment won't persist. We don't want
    // to recurse 10k times just because we have 10k GEP operands. We don't
    // bail completely out because we want to handle constant GEPs regardless
    // of depth.
    if (Depth++ >= MaxAnalysisRecursionDepth)
      continue;

    if (isKnownNonZero(V: GTI.getOperand(), Q, Depth))
      return true;
  }

  return false;
}
2916
/// Scan the uses of \p V for evidence that dominates \p CtxI and proves \p V
/// cannot be null there: nonnull call-argument attributes, loads/stores
/// through V where null is not a defined address, div/rem by V, and branches
/// on comparisons that exclude zero.
static bool isKnownNonNullFromDominatingCondition(const Value *V,
                                                  const Instruction *CtxI,
                                                  const DominatorTree *DT) {
  assert(!isa<Constant>(V) && "Called for constant?");

  if (!CtxI || !DT)
    return false;

  unsigned NumUsesExplored = 0;
  for (auto &U : V->uses()) {
    // Avoid massive lists
    if (NumUsesExplored >= DomConditionsMaxUses)
      break;
    NumUsesExplored++;

    const Instruction *UI = cast<Instruction>(Val: U.getUser());
    // If the value is used as an argument to a call or invoke, then argument
    // attributes may provide an answer about null-ness.
    if (V->getType()->isPointerTy()) {
      if (const auto *CB = dyn_cast<CallBase>(Val: UI)) {
        if (CB->isArgOperand(U: &U) &&
            CB->paramHasNonNullAttr(ArgNo: CB->getArgOperandNo(U: &U),
                                    /*AllowUndefOrPoison=*/false) &&
            DT->dominates(Def: CB, User: CtxI))
          return true;
      }
    }

    // If the value is used as a load/store, then the pointer must be non null.
    if (V == getLoadStorePointerOperand(V: UI)) {
      if (!NullPointerIsDefined(F: UI->getFunction(),
                                AS: V->getType()->getPointerAddressSpace()) &&
          DT->dominates(Def: UI, User: CtxI))
        return true;
    }

    // A division or remainder with V as the divisor implies V is non-zero at
    // any context the div/rem is valid for.
    if ((match(V: UI, P: m_IDiv(L: m_Value(), R: m_Specific(V))) ||
         match(V: UI, P: m_IRem(L: m_Value(), R: m_Specific(V)))) &&
        isValidAssumeForContext(Inv: UI, CxtI: CtxI, DT))
      return true;

    // Consider only compare instructions uniquely controlling a branch
    Value *RHS;
    CmpPredicate Pred;
    if (!match(V: UI, P: m_c_ICmp(Pred, L: m_Specific(V), R: m_Value(V&: RHS))))
      continue;

    bool NonNullIfTrue;
    if (cmpExcludesZero(Pred, RHS))
      NonNullIfTrue = true;
    else if (cmpExcludesZero(Pred: CmpInst::getInversePredicate(pred: Pred), RHS))
      NonNullIfTrue = false;
    else
      continue;

    // Walk from the compare through users (looking through logical ands) to
    // find a conditional branch or guard that establishes the fact.
    SmallVector<const User *, 4> WorkList;
    SmallPtrSet<const User *, 4> Visited;
    for (const auto *CmpU : UI->users()) {
      assert(WorkList.empty() && "Should be!");
      if (Visited.insert(Ptr: CmpU).second)
        WorkList.push_back(Elt: CmpU);

      while (!WorkList.empty()) {
        auto *Curr = WorkList.pop_back_val();

        // If a user is an AND, add all its users to the work list. We only
        // propagate "pred != null" condition through AND because it is only
        // correct to assume that all conditions of AND are met in true branch.
        // TODO: Support similar logic of OR and EQ predicate?
        if (NonNullIfTrue)
          if (match(V: Curr, P: m_LogicalAnd(L: m_Value(), R: m_Value()))) {
            for (const auto *CurrU : Curr->users())
              if (Visited.insert(Ptr: CurrU).second)
                WorkList.push_back(Elt: CurrU);
            continue;
          }

        if (const BranchInst *BI = dyn_cast<BranchInst>(Val: Curr)) {
          assert(BI->isConditional() && "uses a comparison!");

          BasicBlock *NonNullSuccessor =
              BI->getSuccessor(i: NonNullIfTrue ? 0 : 1);
          BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor);
          if (Edge.isSingleEdge() && DT->dominates(BBE: Edge, BB: CtxI->getParent()))
            return true;
        } else if (NonNullIfTrue && isGuard(U: Curr) &&
                   DT->dominates(Def: cast<Instruction>(Val: Curr), User: CtxI)) {
          return true;
        }
      }
    }
  }

  return false;
}
3012
3013/// Does the 'Range' metadata (which must be a valid MD_range operand list)
3014/// ensure that the value it's attached to is never Value? 'RangeType' is
3015/// is the type of the value described by the range.
3016static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value) {
3017 const unsigned NumRanges = Ranges->getNumOperands() / 2;
3018 assert(NumRanges >= 1);
3019 for (unsigned i = 0; i < NumRanges; ++i) {
3020 ConstantInt *Lower =
3021 mdconst::extract<ConstantInt>(MD: Ranges->getOperand(I: 2 * i + 0));
3022 ConstantInt *Upper =
3023 mdconst::extract<ConstantInt>(MD: Ranges->getOperand(I: 2 * i + 1));
3024 ConstantRange Range(Lower->getValue(), Upper->getValue());
3025 if (Range.contains(Val: Value))
3026 return false;
3027 }
3028 return true;
3029}
3030
3031/// Try to detect a recurrence that monotonically increases/decreases from a
3032/// non-zero starting value. These are common as induction variables.
3033static bool isNonZeroRecurrence(const PHINode *PN) {
3034 BinaryOperator *BO = nullptr;
3035 Value *Start = nullptr, *Step = nullptr;
3036 const APInt *StartC, *StepC;
3037 if (!matchSimpleRecurrence(P: PN, BO, Start, Step) ||
3038 !match(V: Start, P: m_APInt(Res&: StartC)) || StartC->isZero())
3039 return false;
3040
3041 switch (BO->getOpcode()) {
3042 case Instruction::Add:
3043 // Starting from non-zero and stepping away from zero can never wrap back
3044 // to zero.
3045 return BO->hasNoUnsignedWrap() ||
3046 (BO->hasNoSignedWrap() && match(V: Step, P: m_APInt(Res&: StepC)) &&
3047 StartC->isNegative() == StepC->isNegative());
3048 case Instruction::Mul:
3049 return (BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) &&
3050 match(V: Step, P: m_APInt(Res&: StepC)) && !StepC->isZero();
3051 case Instruction::Shl:
3052 return BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap();
3053 case Instruction::AShr:
3054 case Instruction::LShr:
3055 return BO->isExact();
3056 default:
3057 return false;
3058 }
3059}
3060
3061static bool matchOpWithOpEqZero(Value *Op0, Value *Op1) {
3062 return match(V: Op0, P: m_ZExtOrSExt(Op: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ,
3063 L: m_Specific(V: Op1), R: m_Zero()))) ||
3064 match(V: Op1, P: m_ZExtOrSExt(Op: m_SpecificICmp(MatchPred: ICmpInst::ICMP_EQ,
3065 L: m_Specific(V: Op0), R: m_Zero())));
3066}
3067
/// Return true if X + Y is known to be non-zero, given the add's nsw/nuw
/// flags and the demanded vector elements.
static bool isNonZeroAdd(const APInt &DemandedElts, const SimplifyQuery &Q,
                         unsigned BitWidth, Value *X, Value *Y, bool NSW,
                         bool NUW, unsigned Depth) {
  // (X + (X != 0)) is non zero
  if (matchOpWithOpEqZero(Op0: X, Op1: Y))
    return true;

  // With nuw the sum cannot wrap, so it is zero only if both operands are.
  if (NUW)
    return isKnownNonZero(V: Y, DemandedElts, Q, Depth) ||
           isKnownNonZero(V: X, DemandedElts, Q, Depth);

  KnownBits XKnown = computeKnownBits(V: X, DemandedElts, Q, Depth);
  KnownBits YKnown = computeKnownBits(V: Y, DemandedElts, Q, Depth);

  // If X and Y are both non-negative (as signed values) then their sum is not
  // zero unless both X and Y are zero.
  if (XKnown.isNonNegative() && YKnown.isNonNegative())
    if (isKnownNonZero(V: Y, DemandedElts, Q, Depth) ||
        isKnownNonZero(V: X, DemandedElts, Q, Depth))
      return true;

  // If X and Y are both negative (as signed values) then their sum is not
  // zero unless both X and Y equal INT_MIN.
  if (XKnown.isNegative() && YKnown.isNegative()) {
    APInt Mask = APInt::getSignedMaxValue(numBits: BitWidth);
    // The sign bit of X is set. If some other bit is set then X is not equal
    // to INT_MIN.
    if (XKnown.One.intersects(RHS: Mask))
      return true;
    // The sign bit of Y is set. If some other bit is set then Y is not equal
    // to INT_MIN.
    if (YKnown.One.intersects(RHS: Mask))
      return true;
  }

  // The sum of a non-negative number and a power of two is not zero.
  if (XKnown.isNonNegative() &&
      isKnownToBeAPowerOfTwo(V: Y, /*OrZero*/ false, Q, Depth))
    return true;
  if (YKnown.isNonNegative() &&
      isKnownToBeAPowerOfTwo(V: X, /*OrZero*/ false, Q, Depth))
    return true;

  // Finally, fall back to bit arithmetic on the known bits of the sum.
  return KnownBits::add(LHS: XKnown, RHS: YKnown, NSW, NUW).isNonZero();
}
3113
3114static bool isNonZeroSub(const APInt &DemandedElts, const SimplifyQuery &Q,
3115 unsigned BitWidth, Value *X, Value *Y,
3116 unsigned Depth) {
3117 // (X - (X != 0)) is non zero
3118 // ((X != 0) - X) is non zero
3119 if (matchOpWithOpEqZero(Op0: X, Op1: Y))
3120 return true;
3121
3122 // TODO: Move this case into isKnownNonEqual().
3123 if (auto *C = dyn_cast<Constant>(Val: X))
3124 if (C->isNullValue() && isKnownNonZero(V: Y, DemandedElts, Q, Depth))
3125 return true;
3126
3127 return ::isKnownNonEqual(V1: X, V2: Y, DemandedElts, Q, Depth);
3128}
3129
/// Return true if X * Y is known to be non-zero, given the multiply's
/// nsw/nuw flags and the demanded vector elements.
static bool isNonZeroMul(const APInt &DemandedElts, const SimplifyQuery &Q,
                         unsigned BitWidth, Value *X, Value *Y, bool NSW,
                         bool NUW, unsigned Depth) {
  // If X and Y are non-zero then so is X * Y as long as the multiplication
  // does not overflow.
  if (NSW || NUW)
    return isKnownNonZero(V: X, DemandedElts, Q, Depth) &&
           isKnownNonZero(V: Y, DemandedElts, Q, Depth);

  // If either X or Y is odd, then if the other is non-zero the result can't
  // be zero.
  KnownBits XKnown = computeKnownBits(V: X, DemandedElts, Q, Depth);
  if (XKnown.One[0])
    return isKnownNonZero(V: Y, DemandedElts, Q, Depth);

  KnownBits YKnown = computeKnownBits(V: Y, DemandedElts, Q, Depth);
  if (YKnown.One[0])
    return XKnown.isNonZero() || isKnownNonZero(V: X, DemandedElts, Q, Depth);

  // If there exists any subset of X (sX) and subset of Y (sY) s.t sX * sY is
  // non-zero, then X * Y is non-zero. We can find sX and sY by just taking
  // the lowest known One of X and Y. If they are non-zero, the result
  // must be non-zero. We can check if LSB(X) * LSB(Y) != 0 by doing
  // X.countMaxTrailingZeros + Y.countMaxTrailingZeros < BitWidth.
  return (XKnown.countMaxTrailingZeros() + YKnown.countMaxTrailingZeros()) <
         BitWidth;
}
3157
/// Return true if the shift instruction \p I (shl/lshr/ashr) is known to
/// produce a non-zero value. \p KnownVal holds the known bits of the shifted
/// operand.
static bool isNonZeroShift(const Operator *I, const APInt &DemandedElts,
                           const SimplifyQuery &Q, const KnownBits &KnownVal,
                           unsigned Depth) {
  // Applies the instruction's own shift direction to a constant.
  auto ShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
    switch (I->getOpcode()) {
    case Instruction::Shl:
      return Lhs.shl(ShiftAmt: Rhs);
    case Instruction::LShr:
      return Lhs.lshr(ShiftAmt: Rhs);
    case Instruction::AShr:
      return Lhs.ashr(ShiftAmt: Rhs);
    default:
      llvm_unreachable("Unknown Shift Opcode");
    }
  };

  // Applies the opposite shift direction; used below to isolate the bits the
  // instruction would shift out.
  auto InvShiftOp = [&](const APInt &Lhs, const APInt &Rhs) {
    switch (I->getOpcode()) {
    case Instruction::Shl:
      return Lhs.lshr(ShiftAmt: Rhs);
    case Instruction::LShr:
    case Instruction::AShr:
      return Lhs.shl(ShiftAmt: Rhs);
    default:
      llvm_unreachable("Unknown Shift Opcode");
    }
  };

  if (KnownVal.isUnknown())
    return false;

  KnownBits KnownCnt =
      computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Q, Depth);
  APInt MaxShift = KnownCnt.getMaxValue();
  unsigned NumBits = KnownVal.getBitWidth();
  // Conservatively give up if the shift amount could be >= the bit width.
  if (MaxShift.uge(RHS: NumBits))
    return false;

  // If some known-one bit survives even the maximal possible shift, the
  // result is certainly non-zero.
  if (!ShiftOp(KnownVal.One, MaxShift).isZero())
    return true;

  // If all of the bits shifted out are known to be zero, and Val is known
  // non-zero then at least one non-zero bit must remain.
  if (InvShiftOp(KnownVal.Zero, NumBits - MaxShift)
          .eq(RHS: InvShiftOp(APInt::getAllOnes(numBits: NumBits), NumBits - MaxShift)) &&
      isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth))
    return true;

  return false;
}
3208
/// Return true if the value produced by operator \p I is known to be non-zero
/// in every vector lane selected by \p DemandedElts (a single bit for
/// scalars). Each opcode gets a dedicated rule; anything unhandled falls
/// through to a generic computeKnownBits query at the bottom.
static bool isKnownNonZeroFromOperator(const Operator *I,
                                       const APInt &DemandedElts,
                                       const SimplifyQuery &Q, unsigned Depth) {
  unsigned BitWidth = getBitWidth(Ty: I->getType()->getScalarType(), DL: Q.DL);
  switch (I->getOpcode()) {
  case Instruction::Alloca:
    // Alloca never returns null, malloc might.
    return I->getType()->getPointerAddressSpace() == 0;
  case Instruction::GetElementPtr:
    if (I->getType()->isPointerTy())
      return isGEPKnownNonNull(GEP: cast<GEPOperator>(Val: I), Q, Depth);
    break;
  case Instruction::BitCast: {
    // We need to be a bit careful here. We can only peek through the bitcast
    // if the scalar size of elements in the operand are smaller than and a
    // multiple of the size they are casting to. Take three cases:
    //
    // 1) Unsafe:
    //        bitcast <2 x i16> %NonZero to <4 x i8>
    //
    //    %NonZero can have 2 non-zero i16 elements, but isKnownNonZero on a
    //    <4 x i8> requires that all 4 i8 elements be non-zero which isn't
    //    guaranteed (imagine just sign bit set in the 2 i16 elements).
    //
    // 2) Unsafe:
    //        bitcast <4 x i3> %NonZero to <3 x i4>
    //
    //    Even though the scalar size of the src (`i3`) is smaller than the
    //    scalar size of the dst `i4`, because `i3` is not a multiple of `i4`
    //    it's possible for the `3 x i4` elements to be zero because there are
    //    some elements in the destination that don't contain any full src
    //    element.
    //
    // 3) Safe:
    //        bitcast <4 x i8> %NonZero to <2 x i16>
    //
    //    This is always safe as non-zero in the 4 i8 elements implies
    //    non-zero in the combination of any two adjacent ones. Since i8 is a
    //    multiple of i16, each i16 is guaranteed to have 2 full i8 elements.
    //    This all implies the 2 i16 elements are non-zero.
    Type *FromTy = I->getOperand(i: 0)->getType();
    if ((FromTy->isIntOrIntVectorTy() || FromTy->isPtrOrPtrVectorTy()) &&
        (BitWidth % getBitWidth(Ty: FromTy->getScalarType(), DL: Q.DL)) == 0)
      return isKnownNonZero(V: I->getOperand(i: 0), Q, Depth);
  } break;
  case Instruction::IntToPtr:
    // Note that we have to take special care to avoid looking through
    // truncating casts, e.g., int2ptr/ptr2int with appropriate sizes, as well
    // as casts that can alter the value, e.g., AddrSpaceCasts.
    if (!isa<ScalableVectorType>(Val: I->getType()) &&
        Q.DL.getTypeSizeInBits(Ty: I->getOperand(i: 0)->getType()).getFixedValue() <=
            Q.DL.getTypeSizeInBits(Ty: I->getType()).getFixedValue())
      return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
    break;
  case Instruction::PtrToAddr:
    // isKnownNonZero() for pointers refers to the address bits being non-zero,
    // so we can directly forward.
    return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
  case Instruction::PtrToInt:
    // For ptrtoint, make sure the result size is >= the address size. If the
    // address is non-zero, any larger value is also non-zero.
    if (Q.DL.getAddressSizeInBits(Ty: I->getOperand(i: 0)->getType()) <=
        I->getType()->getScalarSizeInBits())
      return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
    break;
  case Instruction::Trunc:
    // nuw/nsw trunc preserves zero/non-zero status of input.
    if (auto *TI = dyn_cast<TruncInst>(Val: I))
      if (TI->hasNoSignedWrap() || TI->hasNoUnsignedWrap())
        return isKnownNonZero(V: TI->getOperand(i_nocapture: 0), DemandedElts, Q, Depth);
    break;

    // Iff x - y != 0, then x ^ y != 0
    // Therefore we can do the same exact checks
  case Instruction::Xor:
  case Instruction::Sub:
    return isNonZeroSub(DemandedElts, Q, BitWidth, X: I->getOperand(i: 0),
                        Y: I->getOperand(i: 1), Depth);
  case Instruction::Or:
    // (X | (X != 0)) is non zero
    if (matchOpWithOpEqZero(Op0: I->getOperand(i: 0), Op1: I->getOperand(i: 1)))
      return true;
    // X | Y != 0 if X != Y.
    if (isKnownNonEqual(V1: I->getOperand(i: 0), V2: I->getOperand(i: 1), DemandedElts, Q,
                        Depth))
      return true;
    // X | Y != 0 if X != 0 or Y != 0.
    return isKnownNonZero(V: I->getOperand(i: 1), DemandedElts, Q, Depth) ||
           isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
  case Instruction::SExt:
  case Instruction::ZExt:
    // ext X != 0 if X != 0.
    return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);

  case Instruction::Shl: {
    // shl nsw/nuw can't remove any non-zero bits.
    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(Val: I);
    if (Q.IIQ.hasNoUnsignedWrap(Op: BO) || Q.IIQ.hasNoSignedWrap(Op: BO))
      return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);

    // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined
    // if the lowest bit is shifted off the end.
    KnownBits Known(BitWidth);
    computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Known, Q, Depth);
    if (Known.One[0])
      return true;

    return isNonZeroShift(I, DemandedElts, Q, KnownVal: Known, Depth);
  }
  case Instruction::LShr:
  case Instruction::AShr: {
    // shr exact can only shift out zero bits.
    const PossiblyExactOperator *BO = cast<PossiblyExactOperator>(Val: I);
    if (BO->isExact())
      return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);

    // shr X, Y != 0 if X is negative. Note that the value of the shift is not
    // defined if the sign bit is shifted off the end.
    KnownBits Known =
        computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
    if (Known.isNegative())
      return true;

    return isNonZeroShift(I, DemandedElts, Q, KnownVal: Known, Depth);
  }
  case Instruction::UDiv:
  case Instruction::SDiv: {
    // X / Y
    // div exact can only produce a zero if the dividend is zero.
    if (cast<PossiblyExactOperator>(Val: I)->isExact())
      return isKnownNonZero(V: I->getOperand(i: 0), DemandedElts, Q, Depth);

    KnownBits XKnown =
        computeKnownBits(V: I->getOperand(i: 0), DemandedElts, Q, Depth);
    // If X is fully unknown we won't be able to figure anything out so don't
    // bother computing knownbits for Y.
    if (XKnown.isUnknown())
      return false;

    KnownBits YKnown =
        computeKnownBits(V: I->getOperand(i: 1), DemandedElts, Q, Depth);
    if (I->getOpcode() == Instruction::SDiv) {
      // For signed division need to compare abs value of the operands.
      XKnown = XKnown.abs(/*IntMinIsPoison*/ false);
      YKnown = YKnown.abs(/*IntMinIsPoison*/ false);
    }
    // If X u>= Y then div is non zero (0/0 is UB).
    std::optional<bool> XUgeY = KnownBits::uge(LHS: XKnown, RHS: YKnown);
    // If X is total unknown or X u< Y we won't be able to prove non-zero
    // with compute known bits so just return early.
    return XUgeY && *XUgeY;
  }
  case Instruction::Add: {
    // X + Y.

    // If Add has nuw wrap flag, then if either X or Y is non-zero the result is
    // non-zero.
    auto *BO = cast<OverflowingBinaryOperator>(Val: I);
    return isNonZeroAdd(DemandedElts, Q, BitWidth, X: I->getOperand(i: 0),
                        Y: I->getOperand(i: 1), NSW: Q.IIQ.hasNoSignedWrap(Op: BO),
                        NUW: Q.IIQ.hasNoUnsignedWrap(Op: BO), Depth);
  }
  case Instruction::Mul: {
    const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(Val: I);
    return isNonZeroMul(DemandedElts, Q, BitWidth, X: I->getOperand(i: 0),
                        Y: I->getOperand(i: 1), NSW: Q.IIQ.hasNoSignedWrap(Op: BO),
                        NUW: Q.IIQ.hasNoUnsignedWrap(Op: BO), Depth);
  }
  case Instruction::Select: {
    // (C ? X : Y) != 0 if X != 0 and Y != 0.

    // First check if the arm is non-zero using `isKnownNonZero`. If that fails,
    // then see if the select condition implies the arm is non-zero. For example
    // (X != 0 ? X : Y), we know the true arm is non-zero as the `X` "return" is
    // dominated by `X != 0`.
    auto SelectArmIsNonZero = [&](bool IsTrueArm) {
      Value *Op;
      Op = IsTrueArm ? I->getOperand(i: 1) : I->getOperand(i: 2);
      // Op is trivially non-zero.
      if (isKnownNonZero(V: Op, DemandedElts, Q, Depth))
        return true;

      // The condition of the select dominates the true/false arm. Check if the
      // condition implies that a given arm is non-zero.
      Value *X;
      CmpPredicate Pred;
      if (!match(V: I->getOperand(i: 0), P: m_c_ICmp(Pred, L: m_Specific(V: Op), R: m_Value(V&: X))))
        return false;

      // On the false arm the condition is false, so use the inverse predicate.
      if (!IsTrueArm)
        Pred = ICmpInst::getInversePredicate(pred: Pred);

      return cmpExcludesZero(Pred, RHS: X);
    };

    if (SelectArmIsNonZero(/* IsTrueArm */ true) &&
        SelectArmIsNonZero(/* IsTrueArm */ false))
      return true;
    break;
  }
  case Instruction::PHI: {
    auto *PN = cast<PHINode>(Val: I);
    if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN))
      return true;

    // Check if all incoming values are non-zero using recursion.
    SimplifyQuery RecQ = Q.getWithoutCondContext();
    // Jump (almost) straight to the recursion limit so each incoming value is
    // examined at most one level deep, keeping PHI analysis cheap.
    unsigned NewDepth = std::max(a: Depth, b: MaxAnalysisRecursionDepth - 1);
    return llvm::all_of(Range: PN->operands(), P: [&](const Use &U) {
      // A self-referential incoming value cannot make the PHI zero.
      if (U.get() == PN)
        return true;
      RecQ.CxtI = PN->getIncomingBlock(U)->getTerminator();
      // Check if the branch on the phi excludes zero.
      CmpPredicate Pred;
      Value *X;
      BasicBlock *TrueSucc, *FalseSucc;
      if (match(V: RecQ.CxtI,
                P: m_Br(C: m_c_ICmp(Pred, L: m_Specific(V: U.get()), R: m_Value(V&: X)),
                     T: m_BasicBlock(V&: TrueSucc), F: m_BasicBlock(V&: FalseSucc)))) {
        // Check for cases of duplicate successors.
        if ((TrueSucc == PN->getParent()) != (FalseSucc == PN->getParent())) {
          // If we're using the false successor, invert the predicate.
          if (FalseSucc == PN->getParent())
            Pred = CmpInst::getInversePredicate(pred: Pred);
          if (cmpExcludesZero(Pred, RHS: X))
            return true;
        }
      }
      // Finally recurse on the edge and check it directly.
      return isKnownNonZero(V: U.get(), DemandedElts, Q: RecQ, Depth: NewDepth);
    });
  }
  case Instruction::InsertElement: {
    if (isa<ScalableVectorType>(Val: I->getType()))
      break;

    const Value *Vec = I->getOperand(i: 0);
    const Value *Elt = I->getOperand(i: 1);
    auto *CIdx = dyn_cast<ConstantInt>(Val: I->getOperand(i: 2));

    unsigned NumElts = DemandedElts.getBitWidth();
    APInt DemandedVecElts = DemandedElts;
    bool SkipElt = false;
    // If we know the index we are inserting to, clear it from Vec check.
    if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
      DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
      SkipElt = !DemandedElts[CIdx->getZExtValue()];
    }

    // Result is non-zero if Elt is non-zero and rest of the demanded elts in
    // Vec are non-zero.
    return (SkipElt || isKnownNonZero(V: Elt, Q, Depth)) &&
           (DemandedVecElts.isZero() ||
            isKnownNonZero(V: Vec, DemandedElts: DemandedVecElts, Q, Depth));
  }
  case Instruction::ExtractElement:
    if (const auto *EEI = dyn_cast<ExtractElementInst>(Val: I)) {
      const Value *Vec = EEI->getVectorOperand();
      const Value *Idx = EEI->getIndexOperand();
      auto *CIdx = dyn_cast<ConstantInt>(Val: Idx);
      if (auto *VecTy = dyn_cast<FixedVectorType>(Val: Vec->getType())) {
        unsigned NumElts = VecTy->getNumElements();
        // With a constant in-range index only that lane matters; otherwise
        // every lane must be non-zero.
        APInt DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
        if (CIdx && CIdx->getValue().ult(RHS: NumElts))
          DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
        return isKnownNonZero(V: Vec, DemandedElts: DemandedVecElts, Q, Depth);
      }
    }
    break;
  case Instruction::ShuffleVector: {
    auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!Shuf)
      break;
    APInt DemandedLHS, DemandedRHS;
    // For undef elements, we don't know anything about the common state of
    // the shuffle result.
    if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
      break;
    // If demanded elements for both vecs are non-zero, the shuffle is non-zero.
    return (DemandedRHS.isZero() ||
            isKnownNonZero(V: Shuf->getOperand(i_nocapture: 1), DemandedElts: DemandedRHS, Q, Depth)) &&
           (DemandedLHS.isZero() ||
            isKnownNonZero(V: Shuf->getOperand(i_nocapture: 0), DemandedElts: DemandedLHS, Q, Depth));
  }
  case Instruction::Freeze:
    // freeze(X) is non-zero when X is non-zero and X cannot be poison (a
    // frozen poison value could be anything, including zero).
    return isKnownNonZero(V: I->getOperand(i: 0), Q, Depth) &&
           isGuaranteedNotToBePoison(V: I->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
                                     Depth);
  case Instruction::Load: {
    auto *LI = cast<LoadInst>(Val: I);
    // A Load tagged with nonnull or dereferenceable with null pointer undefined
    // is never null.
    if (auto *PtrT = dyn_cast<PointerType>(Val: I->getType())) {
      if (Q.IIQ.getMetadata(I: LI, KindID: LLVMContext::MD_nonnull) ||
          (Q.IIQ.getMetadata(I: LI, KindID: LLVMContext::MD_dereferenceable) &&
           !NullPointerIsDefined(F: LI->getFunction(), AS: PtrT->getAddressSpace())))
        return true;
    } else if (MDNode *Ranges = Q.IIQ.getMetadata(I: LI, KindID: LLVMContext::MD_range)) {
      return rangeMetadataExcludesValue(Ranges, Value: APInt::getZero(numBits: BitWidth));
    }

    // No need to fall through to computeKnownBits as range metadata is already
    // handled in isKnownNonZero.
    return false;
  }
  case Instruction::ExtractValue: {
    // The scalar result of a with-overflow intrinsic follows the rules of the
    // corresponding binop with no wrap flags.
    const WithOverflowInst *WO;
    if (match(V: I, P: m_ExtractValue<0>(V: m_WithOverflowInst(I&: WO)))) {
      switch (WO->getBinaryOp()) {
      default:
        break;
      case Instruction::Add:
        return isNonZeroAdd(DemandedElts, Q, BitWidth, X: WO->getArgOperand(i: 0),
                            Y: WO->getArgOperand(i: 1),
                            /*NSW=*/false,
                            /*NUW=*/false, Depth);
      case Instruction::Sub:
        return isNonZeroSub(DemandedElts, Q, BitWidth, X: WO->getArgOperand(i: 0),
                            Y: WO->getArgOperand(i: 1), Depth);
      case Instruction::Mul:
        return isNonZeroMul(DemandedElts, Q, BitWidth, X: WO->getArgOperand(i: 0),
                            Y: WO->getArgOperand(i: 1),
                            /*NSW=*/false, /*NUW=*/false, Depth);
        break;
      }
    }
    break;
  }
  case Instruction::Call:
  case Instruction::Invoke: {
    const auto *Call = cast<CallBase>(Val: I);
    if (I->getType()->isPointerTy()) {
      if (Call->isReturnNonNull())
        return true;
      if (const auto *RP = getArgumentAliasingToReturnedPointer(Call, MustPreserveNullness: true))
        return isKnownNonZero(V: RP, Q, Depth);
    } else {
      if (MDNode *Ranges = Q.IIQ.getMetadata(I: Call, KindID: LLVMContext::MD_range))
        return rangeMetadataExcludesValue(Ranges, Value: APInt::getZero(numBits: BitWidth));
      if (std::optional<ConstantRange> Range = Call->getRange()) {
        const APInt ZeroValue(Range->getBitWidth(), 0);
        if (!Range->contains(Val: ZeroValue))
          return true;
      }
      if (const Value *RV = Call->getReturnedArgOperand())
        if (RV->getType() == I->getType() && isKnownNonZero(V: RV, Q, Depth))
          return true;
    }

    if (auto *II = dyn_cast<IntrinsicInst>(Val: I)) {
      switch (II->getIntrinsicID()) {
      // These intrinsics preserve zero-ness of their first argument.
      case Intrinsic::sshl_sat:
      case Intrinsic::ushl_sat:
      case Intrinsic::abs:
      case Intrinsic::bitreverse:
      case Intrinsic::bswap:
      case Intrinsic::ctpop:
        return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
        // NB: We don't do usub_sat here as in any case we can prove its
        // non-zero, we will fold it to `sub nuw` in InstCombine.
      case Intrinsic::ssub_sat:
        // For most types, if x != y then ssub.sat x, y != 0. But
        // ssub.sat.i1 0, -1 = 0, because 1 saturates to 0. This means
        // isNonZeroSub will do the wrong thing for ssub.sat.i1.
        if (BitWidth == 1)
          return false;
        return isNonZeroSub(DemandedElts, Q, BitWidth, X: II->getArgOperand(i: 0),
                            Y: II->getArgOperand(i: 1), Depth);
      case Intrinsic::sadd_sat:
        return isNonZeroAdd(DemandedElts, Q, BitWidth, X: II->getArgOperand(i: 0),
                            Y: II->getArgOperand(i: 1),
                            /*NSW=*/true, /* NUW=*/false, Depth);
        // Vec reverse preserves zero/non-zero status from input vec.
      case Intrinsic::vector_reverse:
        return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts: DemandedElts.reverseBits(),
                              Q, Depth);
        // umax/umin/smax/smin/or of all non-zero elements is always non-zero.
      case Intrinsic::vector_reduce_or:
      case Intrinsic::vector_reduce_umax:
      case Intrinsic::vector_reduce_umin:
      case Intrinsic::vector_reduce_smax:
      case Intrinsic::vector_reduce_smin:
        return isKnownNonZero(V: II->getArgOperand(i: 0), Q, Depth);
      case Intrinsic::umax:
      case Intrinsic::uadd_sat:
        // umax(X, (X != 0)) is non zero
        // X +usat (X != 0) is non zero
        if (matchOpWithOpEqZero(Op0: II->getArgOperand(i: 0), Op1: II->getArgOperand(i: 1)))
          return true;

        // Both are monotone: non-zero if either operand is non-zero.
        return isKnownNonZero(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth) ||
               isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
      case Intrinsic::smax: {
        // If either arg is strictly positive the result is non-zero. Otherwise
        // the result is non-zero if both ops are non-zero.
        auto IsNonZero = [&](Value *Op, std::optional<bool> &OpNonZero,
                             const KnownBits &OpKnown) {
          if (!OpNonZero.has_value())
            OpNonZero = OpKnown.isNonZero() ||
                        isKnownNonZero(V: Op, DemandedElts, Q, Depth);
          return *OpNonZero;
        };
        // Avoid re-computing isKnownNonZero.
        std::optional<bool> Op0NonZero, Op1NonZero;
        KnownBits Op1Known =
            computeKnownBits(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth);
        if (Op1Known.isNonNegative() &&
            IsNonZero(II->getArgOperand(i: 1), Op1NonZero, Op1Known))
          return true;
        KnownBits Op0Known =
            computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
        if (Op0Known.isNonNegative() &&
            IsNonZero(II->getArgOperand(i: 0), Op0NonZero, Op0Known))
          return true;
        return IsNonZero(II->getArgOperand(i: 1), Op1NonZero, Op1Known) &&
               IsNonZero(II->getArgOperand(i: 0), Op0NonZero, Op0Known);
      }
      case Intrinsic::smin: {
        // If either arg is negative the result is non-zero. Otherwise
        // the result is non-zero if both ops are non-zero.
        KnownBits Op1Known =
            computeKnownBits(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth);
        if (Op1Known.isNegative())
          return true;
        KnownBits Op0Known =
            computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
        if (Op0Known.isNegative())
          return true;

        if (Op1Known.isNonZero() && Op0Known.isNonZero())
          return true;
      }
        [[fallthrough]];
      case Intrinsic::umin:
        return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth) &&
               isKnownNonZero(V: II->getArgOperand(i: 1), DemandedElts, Q, Depth);
      case Intrinsic::cttz:
        // cttz(x) >= 1 exactly when the lowest bit of x is known zero.
        return computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth)
            .Zero[0];
      case Intrinsic::ctlz:
        // ctlz(x) >= 1 exactly when the sign bit of x is known zero.
        return computeKnownBits(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth)
            .isNonNegative();
      case Intrinsic::fshr:
      case Intrinsic::fshl:
        // If Op0 == Op1, this is a rotate. rotate(x, y) != 0 iff x != 0.
        if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1))
          return isKnownNonZero(V: II->getArgOperand(i: 0), DemandedElts, Q, Depth);
        break;
      case Intrinsic::vscale:
        return true;
      case Intrinsic::experimental_get_vector_length:
        return isKnownNonZero(V: I->getOperand(i: 0), Q, Depth);
      default:
        break;
      }
      // Unhandled intrinsic: fall through to the generic computeKnownBits
      // query below.
      break;
    }

    // Non-intrinsic calls: range metadata was already consulted above, so the
    // generic known-bits fallback cannot add anything.
    return false;
  }
  }

  KnownBits Known(BitWidth);
  computeKnownBits(V: I, DemandedElts, Known, Q, Depth);
  return Known.One != 0;
}
3675
/// Return true if the given value is known to be non-zero when defined. For
/// vectors, return true if every demanded element is known to be non-zero when
/// defined. For pointers, if the context instruction and dominator tree are
/// specified, perform context-sensitive analysis and return true if the
/// pointer couldn't possibly be null at the specified instruction.
/// Supports values with integer or pointer type and vectors of integers.
bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
                    const SimplifyQuery &Q, unsigned Depth) {
  Type *Ty = V->getType();

#ifndef NDEBUG
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
    assert(
        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
        "DemandedElt width should equal the fixed vector number of elements");
  } else {
    assert(DemandedElts == APInt(1, 1) &&
           "DemandedElt width should be 1 for scalars");
  }
#endif

  // Constants are resolved directly, without recursion.
  if (auto *C = dyn_cast<Constant>(Val: V)) {
    if (C->isNullValue())
      return false;
    if (isa<ConstantInt>(Val: C))
      // Must be non-zero due to null test above.
      return true;

    // For constant vectors, check that all elements are poison or known
    // non-zero to determine that the whole vector is known non-zero.
    if (auto *VecTy = dyn_cast<FixedVectorType>(Val: Ty)) {
      for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
        if (!DemandedElts[i])
          continue;
        Constant *Elt = C->getAggregateElement(Elt: i);
        if (!Elt || Elt->isNullValue())
          return false;
        if (!isa<PoisonValue>(Val: Elt) && !isa<ConstantInt>(Val: Elt))
          return false;
      }
      return true;
    }

    // Constant ptrauth can be null, iff the base pointer can be.
    if (auto *CPA = dyn_cast<ConstantPtrAuth>(Val: V))
      return isKnownNonZero(V: CPA->getPointer(), DemandedElts, Q, Depth);

    // A global variable in address space 0 is non null unless extern weak
    // or an absolute symbol reference. Other address spaces may have null as a
    // valid address for a global, so we can't assume anything.
    if (const GlobalValue *GV = dyn_cast<GlobalValue>(Val: V)) {
      if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() &&
          GV->getType()->getAddressSpace() == 0)
        return true;
    }

    // For constant expressions, fall through to the Operator code below.
    if (!isa<ConstantExpr>(Val: V))
      return false;
  }

  // An argument may carry an explicit range attribute that excludes zero.
  if (const auto *A = dyn_cast<Argument>(Val: V))
    if (std::optional<ConstantRange> Range = A->getRange()) {
      const APInt ZeroValue(Range->getBitWidth(), 0);
      if (!Range->contains(Val: ZeroValue))
        return true;
    }

  // Non-constant values may be proven non-zero by a dominating llvm.assume.
  if (!isa<Constant>(Val: V) && isKnownNonZeroFromAssume(V, Q))
    return true;

  // Some of the tests below are recursive, so bail out if we hit the limit.
  if (Depth++ >= MaxAnalysisRecursionDepth)
    return false;

  // Check for pointer simplifications.

  if (PointerType *PtrTy = dyn_cast<PointerType>(Val: Ty)) {
    // A byval, inalloca may not be null in a non-default addres space. A
    // nonnull argument is assumed never 0.
    if (const Argument *A = dyn_cast<Argument>(Val: V)) {
      if (((A->hasPassPointeeByValueCopyAttr() &&
            !NullPointerIsDefined(F: A->getParent(), AS: PtrTy->getAddressSpace())) ||
           A->hasNonNullAttr()))
        return true;
    }
  }

  // Dispatch on the defining operator (instructions and constant exprs).
  if (const auto *I = dyn_cast<Operator>(Val: V))
    if (isKnownNonZeroFromOperator(I, DemandedElts, Q, Depth))
      return true;

  if (!isa<Constant>(Val: V) &&
      isKnownNonNullFromDominatingCondition(V, CtxI: Q.CxtI, DT: Q.DT))
    return true;

  // If V is wrapped in a construct that already tests for null, recurse on
  // the underlying value (see stripNullTest for what is peeled).
  if (const Value *Stripped = stripNullTest(V))
    return isKnownNonZero(V: Stripped, DemandedElts, Q, Depth);

  return false;
}
3779
3780bool llvm::isKnownNonZero(const Value *V, const SimplifyQuery &Q,
3781 unsigned Depth) {
3782 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
3783 APInt DemandedElts =
3784 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
3785 return ::isKnownNonZero(V, DemandedElts, Q, Depth);
3786}
3787
/// If the pair of operators are the same invertible function, return the
/// the operands of the function corresponding to each input. Otherwise,
/// return std::nullopt. An invertible function is one that is 1-to-1 and maps
/// every input value to exactly one output value. This is equivalent to
/// saying that Op1 and Op2 are equal exactly when the specified pair of
/// operands are equal, (except that Op1 and Op2 may be poison more often.)
static std::optional<std::pair<Value*, Value*>>
getInvertibleOperands(const Operator *Op1,
                      const Operator *Op2) {
  if (Op1->getOpcode() != Op2->getOpcode())
    return std::nullopt;

  // Helper to pair up the operand at the same index of both operators.
  auto getOperands = [&](unsigned OpNum) -> auto {
    return std::make_pair(x: Op1->getOperand(i: OpNum), y: Op2->getOperand(i: OpNum));
  };

  switch (Op1->getOpcode()) {
  default:
    break;
  case Instruction::Or:
    // `or` behaves like `add` only when the operands share no bits; without
    // the disjoint flag it is not invertible.
    if (!cast<PossiblyDisjointInst>(Val: Op1)->isDisjoint() ||
        !cast<PossiblyDisjointInst>(Val: Op2)->isDisjoint())
      break;
    [[fallthrough]];
  case Instruction::Xor:
  case Instruction::Add: {
    // If both operators share one operand, equality reduces to equality of
    // the remaining operands (add/xor with a fixed value is a bijection).
    Value *Other;
    if (match(V: Op2, P: m_c_BinOp(L: m_Specific(V: Op1->getOperand(i: 0)), R: m_Value(V&: Other))))
      return std::make_pair(x: Op1->getOperand(i: 1), y&: Other);
    if (match(V: Op2, P: m_c_BinOp(L: m_Specific(V: Op1->getOperand(i: 1)), R: m_Value(V&: Other))))
      return std::make_pair(x: Op1->getOperand(i: 0), y&: Other);
    break;
  }
  case Instruction::Sub:
    // (a - b) == (a - c) iff b == c, and (b - a) == (c - a) iff b == c.
    if (Op1->getOperand(i: 0) == Op2->getOperand(i: 0))
      return getOperands(1);
    if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1))
      return getOperands(0);
    break;
  case Instruction::Mul: {
    // invertible if A * B == (A * B) mod 2^N where A, and B are integers
    // and N is the bitwdith. The nsw case is non-obvious, but proven by
    // alive2: https://alive2.llvm.org/ce/z/Z6D5qK
    auto *OBO1 = cast<OverflowingBinaryOperator>(Val: Op1);
    auto *OBO2 = cast<OverflowingBinaryOperator>(Val: Op2);
    if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
        (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
      break;

    // Assume operand order has been canonicalized
    if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1) &&
        isa<ConstantInt>(Val: Op1->getOperand(i: 1)) &&
        !cast<ConstantInt>(Val: Op1->getOperand(i: 1))->isZero())
      return getOperands(0);
    break;
  }
  case Instruction::Shl: {
    // Same as multiplies, with the difference that we don't need to check
    // for a non-zero multiply. Shifts always multiply by non-zero.
    auto *OBO1 = cast<OverflowingBinaryOperator>(Val: Op1);
    auto *OBO2 = cast<OverflowingBinaryOperator>(Val: Op2);
    if ((!OBO1->hasNoUnsignedWrap() || !OBO2->hasNoUnsignedWrap()) &&
        (!OBO1->hasNoSignedWrap() || !OBO2->hasNoSignedWrap()))
      break;

    if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1))
      return getOperands(0);
    break;
  }
  case Instruction::AShr:
  case Instruction::LShr: {
    // An exact shift drops only zero bits, so it is invertible for a fixed
    // shift amount.
    auto *PEO1 = cast<PossiblyExactOperator>(Val: Op1);
    auto *PEO2 = cast<PossiblyExactOperator>(Val: Op2);
    if (!PEO1->isExact() || !PEO2->isExact())
      break;

    if (Op1->getOperand(i: 1) == Op2->getOperand(i: 1))
      return getOperands(0);
    break;
  }
  case Instruction::SExt:
  case Instruction::ZExt:
    // Extension of same-typed inputs is injective.
    if (Op1->getOperand(i: 0)->getType() == Op2->getOperand(i: 0)->getType())
      return getOperands(0);
    break;
  case Instruction::PHI: {
    const PHINode *PN1 = cast<PHINode>(Val: Op1);
    const PHINode *PN2 = cast<PHINode>(Val: Op2);

    // If PN1 and PN2 are both recurrences, can we prove the entire recurrences
    // are a single invertible function of the start values? Note that repeated
    // application of an invertible function is also invertible
    BinaryOperator *BO1 = nullptr;
    Value *Start1 = nullptr, *Step1 = nullptr;
    BinaryOperator *BO2 = nullptr;
    Value *Start2 = nullptr, *Step2 = nullptr;
    if (PN1->getParent() != PN2->getParent() ||
        !matchSimpleRecurrence(P: PN1, BO&: BO1, Start&: Start1, Step&: Step1) ||
        !matchSimpleRecurrence(P: PN2, BO&: BO2, Start&: Start2, Step&: Step2))
      break;

    auto Values = getInvertibleOperands(Op1: cast<Operator>(Val: BO1),
                                        Op2: cast<Operator>(Val: BO2));
    if (!Values)
      break;

    // We have to be careful of mutually defined recurrences here. Ex:
    // * X_i = X_(i-1) OP Y_(i-1), and Y_i = X_(i-1) OP V
    // * X_i = Y_i = X_(i-1) OP Y_(i-1)
    // The invertibility of these is complicated, and not worth reasoning
    // about (yet?).
    if (Values->first != PN1 || Values->second != PN2)
      break;

    return std::make_pair(x&: Start1, y&: Start2);
  }
  }
  return std::nullopt;
}
3907
3908/// Return true if V1 == (binop V2, X), where X is known non-zero.
3909/// Only handle a small subset of binops where (binop V2, X) with non-zero X
3910/// implies V2 != V1.
3911static bool isModifyingBinopOfNonZero(const Value *V1, const Value *V2,
3912 const APInt &DemandedElts,
3913 const SimplifyQuery &Q, unsigned Depth) {
3914 const BinaryOperator *BO = dyn_cast<BinaryOperator>(Val: V1);
3915 if (!BO)
3916 return false;
3917 switch (BO->getOpcode()) {
3918 default:
3919 break;
3920 case Instruction::Or:
3921 if (!cast<PossiblyDisjointInst>(Val: V1)->isDisjoint())
3922 break;
3923 [[fallthrough]];
3924 case Instruction::Xor:
3925 case Instruction::Add:
3926 Value *Op = nullptr;
3927 if (V2 == BO->getOperand(i_nocapture: 0))
3928 Op = BO->getOperand(i_nocapture: 1);
3929 else if (V2 == BO->getOperand(i_nocapture: 1))
3930 Op = BO->getOperand(i_nocapture: 0);
3931 else
3932 return false;
3933 return isKnownNonZero(V: Op, DemandedElts, Q, Depth: Depth + 1);
3934 }
3935 return false;
3936}
3937
3938/// Return true if V2 == V1 * C, where V1 is known non-zero, C is not 0/1 and
3939/// the multiplication is nuw or nsw.
3940static bool isNonEqualMul(const Value *V1, const Value *V2,
3941 const APInt &DemandedElts, const SimplifyQuery &Q,
3942 unsigned Depth) {
3943 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Val: V2)) {
3944 const APInt *C;
3945 return match(V: OBO, P: m_Mul(L: m_Specific(V: V1), R: m_APInt(Res&: C))) &&
3946 (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3947 !C->isZero() && !C->isOne() &&
3948 isKnownNonZero(V: V1, DemandedElts, Q, Depth: Depth + 1);
3949 }
3950 return false;
3951}
3952
3953/// Return true if V2 == V1 << C, where V1 is known non-zero, C is not 0 and
3954/// the shift is nuw or nsw.
3955static bool isNonEqualShl(const Value *V1, const Value *V2,
3956 const APInt &DemandedElts, const SimplifyQuery &Q,
3957 unsigned Depth) {
3958 if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Val: V2)) {
3959 const APInt *C;
3960 return match(V: OBO, P: m_Shl(L: m_Specific(V: V1), R: m_APInt(Res&: C))) &&
3961 (OBO->hasNoUnsignedWrap() || OBO->hasNoSignedWrap()) &&
3962 !C->isZero() && isKnownNonZero(V: V1, DemandedElts, Q, Depth: Depth + 1);
3963 }
3964 return false;
3965}
3966
/// Return true if two PHI nodes \p PN1 and \p PN2, located in the same basic
/// block, are known to take different values on every incoming edge.
/// Incoming pairs that are distinct constants are settled directly; at most
/// one incoming pair may be resolved by a full recursive isKnownNonEqual
/// query, to bound compile time.
static bool isNonEqualPHIs(const PHINode *PN1, const PHINode *PN2,
                           const APInt &DemandedElts, const SimplifyQuery &Q,
                           unsigned Depth) {
  // Check two PHIs are in same block.
  if (PN1->getParent() != PN2->getParent())
    return false;

  SmallPtrSet<const BasicBlock *, 8> VisitedBBs;
  bool UsedFullRecursion = false;
  for (const BasicBlock *IncomBB : PN1->blocks()) {
    if (!VisitedBBs.insert(Ptr: IncomBB).second)
      continue; // Don't reprocess blocks that we have dealt with already.
    const Value *IV1 = PN1->getIncomingValueForBlock(BB: IncomBB);
    const Value *IV2 = PN2->getIncomingValueForBlock(BB: IncomBB);
    const APInt *C1, *C2;
    // Two distinct constant incoming values settle this edge cheaply.
    if (match(V: IV1, P: m_APInt(Res&: C1)) && match(V: IV2, P: m_APInt(Res&: C2)) && *C1 != *C2)
      continue;

    // Only one pair of phi operands is allowed for full recursion.
    if (UsedFullRecursion)
      return false;

    // Recurse in the context of the edge's terminator; condition-based
    // context from the original query point may not hold on this edge.
    SimplifyQuery RecQ = Q.getWithoutCondContext();
    RecQ.CxtI = IncomBB->getTerminator();
    if (!isKnownNonEqual(V1: IV1, V2: IV2, DemandedElts, Q: RecQ, Depth: Depth + 1))
      return false;
    UsedFullRecursion = true;
  }
  // Every incoming edge produced provably different values.
  return true;
}
3997
3998static bool isNonEqualSelect(const Value *V1, const Value *V2,
3999 const APInt &DemandedElts, const SimplifyQuery &Q,
4000 unsigned Depth) {
4001 const SelectInst *SI1 = dyn_cast<SelectInst>(Val: V1);
4002 if (!SI1)
4003 return false;
4004
4005 if (const SelectInst *SI2 = dyn_cast<SelectInst>(Val: V2)) {
4006 const Value *Cond1 = SI1->getCondition();
4007 const Value *Cond2 = SI2->getCondition();
4008 if (Cond1 == Cond2)
4009 return isKnownNonEqual(V1: SI1->getTrueValue(), V2: SI2->getTrueValue(),
4010 DemandedElts, Q, Depth: Depth + 1) &&
4011 isKnownNonEqual(V1: SI1->getFalseValue(), V2: SI2->getFalseValue(),
4012 DemandedElts, Q, Depth: Depth + 1);
4013 }
4014 return isKnownNonEqual(V1: SI1->getTrueValue(), V2, DemandedElts, Q, Depth: Depth + 1) &&
4015 isKnownNonEqual(V1: SI1->getFalseValue(), V2, DemandedElts, Q, Depth: Depth + 1);
4016}
4017
4018// Check to see if A is both a GEP and is the incoming value for a PHI in the
4019// loop, and B is either a ptr or another GEP. If the PHI has 2 incoming values,
4020// one of them being the recursive GEP A and the other a ptr at same base and at
4021// the same/higher offset than B we are only incrementing the pointer further in
4022// loop if offset of recursive GEP is greater than 0.
static bool isNonEqualPointersWithRecursiveGEP(const Value *A, const Value *B,
                                               const SimplifyQuery &Q) {
  if (!A->getType()->isPointerTy() || !B->getType()->isPointerTy())
    return false;

  // A must be a single-index GEP with a constant offset.
  auto *GEPA = dyn_cast<GEPOperator>(Val: A);
  if (!GEPA || GEPA->getNumIndices() != 1 || !isa<Constant>(Val: GEPA->idx_begin()))
    return false;

  // Handle 2 incoming PHI values with one being a recursive GEP.
  auto *PN = dyn_cast<PHINode>(Val: GEPA->getPointerOperand());
  if (!PN || PN->getNumIncomingValues() != 2)
    return false;

  // Search for the recursive GEP as an incoming operand, and record that as
  // Step.
  Value *Start = nullptr;
  Value *Step = const_cast<Value *>(A);
  if (PN->getIncomingValue(i: 0) == Step)
    Start = PN->getIncomingValue(i: 1);
  else if (PN->getIncomingValue(i: 1) == Step)
    Start = PN->getIncomingValue(i: 0);
  else
    return false;

  // Other incoming node base should match the B base.
  // StartOffset >= OffsetB && StepOffset > 0?
  // StartOffset <= OffsetB && StepOffset < 0?
  // Is non-equal if above are true.
  // We use stripAndAccumulateInBoundsConstantOffsets to restrict the
  // optimisation to inbounds GEPs only.
  unsigned IndexWidth = Q.DL.getIndexTypeSizeInBits(Ty: Start->getType());
  APInt StartOffset(IndexWidth, 0);
  Start = Start->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: StartOffset);
  APInt StepOffset(IndexWidth, 0);
  Step = Step->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: StepOffset);

  // Check if Base Pointer of Step matches the PHI.
  if (Step != PN)
    return false;
  // Strip B down to the same base so the three offsets are comparable.
  APInt OffsetB(IndexWidth, 0);
  B = B->stripAndAccumulateInBoundsConstantOffsets(DL: Q.DL, Offset&: OffsetB);
  // The loop starts at/past B and strictly moves away from it (or the mirror
  // case with negative step), so the evolving pointer can never equal B.
  return Start == B &&
         ((StartOffset.sge(RHS: OffsetB) && StepOffset.isStrictlyPositive()) ||
          (StartOffset.sle(RHS: OffsetB) && StepOffset.isNegative()));
}
4069
/// Return true if V1 != V2 can be established from the context instruction
/// Q.CxtI, using either dominating branch conditions or llvm.assume calls.
static bool isKnownNonEqualFromContext(const Value *V1, const Value *V2,
                                       const SimplifyQuery &Q, unsigned Depth) {
  // Without a context instruction there is no program point to anchor
  // dominance or assumption validity.
  if (!Q.CxtI)
    return false;

  // Try to infer NonEqual based on information from dominating conditions.
  if (Q.DC && Q.DT) {
    auto IsKnownNonEqualFromDominatingCondition = [&](const Value *V) {
      // Walk every branch whose condition mentions V.
      for (BranchInst *BI : Q.DC->conditionsFor(V)) {
        Value *Cond = BI->getCondition();
        // Taken edge: the condition is true on entry to the context block.
        BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
        if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()) &&
            isImpliedCondition(LHS: Cond, RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
                               /*LHSIsTrue=*/true, Depth)
                .value_or(u: false))
          return true;

        // Not-taken edge: the condition is false on entry.
        BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
        if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()) &&
            isImpliedCondition(LHS: Cond, RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
                               /*LHSIsTrue=*/false, Depth)
                .value_or(u: false))
          return true;
      }

      return false;
    };

    if (IsKnownNonEqualFromDominatingCondition(V1) ||
        IsKnownNonEqualFromDominatingCondition(V2))
      return true;
  }

  if (!Q.AC)
    return false;

  // Try to infer NonEqual based on information from assumptions.
  for (auto &AssumeVH : Q.AC->assumptionsFor(V: V1)) {
    if (!AssumeVH)
      continue;
    CallInst *I = cast<CallInst>(Val&: AssumeVH);

    assert(I->getFunction() == Q.CxtI->getFunction() &&
           "Got assumption for the wrong function!");
    assert(I->getIntrinsicID() == Intrinsic::assume &&
           "must be an assume intrinsic");

    // The assumed condition must imply V1 != V2, and the assume itself must
    // be valid (executed before) at the context instruction.
    if (isImpliedCondition(LHS: I->getArgOperand(i: 0), RHSPred: ICmpInst::ICMP_NE, RHSOp0: V1, RHSOp1: V2, DL: Q.DL,
                           /*LHSIsTrue=*/true, Depth)
            .value_or(u: false) &&
        isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
      return true;
  }

  return false;
}
4126
4127/// Return true if it is known that V1 != V2.
4128static bool isKnownNonEqual(const Value *V1, const Value *V2,
4129 const APInt &DemandedElts, const SimplifyQuery &Q,
4130 unsigned Depth) {
4131 if (V1 == V2)
4132 return false;
4133 if (V1->getType() != V2->getType())
4134 // We can't look through casts yet.
4135 return false;
4136
4137 if (Depth >= MaxAnalysisRecursionDepth)
4138 return false;
4139
4140 // See if we can recurse through (exactly one of) our operands. This
4141 // requires our operation be 1-to-1 and map every input value to exactly
4142 // one output value. Such an operation is invertible.
4143 auto *O1 = dyn_cast<Operator>(Val: V1);
4144 auto *O2 = dyn_cast<Operator>(Val: V2);
4145 if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
4146 if (auto Values = getInvertibleOperands(Op1: O1, Op2: O2))
4147 return isKnownNonEqual(V1: Values->first, V2: Values->second, DemandedElts, Q,
4148 Depth: Depth + 1);
4149
4150 if (const PHINode *PN1 = dyn_cast<PHINode>(Val: V1)) {
4151 const PHINode *PN2 = cast<PHINode>(Val: V2);
4152 // FIXME: This is missing a generalization to handle the case where one is
4153 // a PHI and another one isn't.
4154 if (isNonEqualPHIs(PN1, PN2, DemandedElts, Q, Depth))
4155 return true;
4156 };
4157 }
4158
4159 if (isModifyingBinopOfNonZero(V1, V2, DemandedElts, Q, Depth) ||
4160 isModifyingBinopOfNonZero(V1: V2, V2: V1, DemandedElts, Q, Depth))
4161 return true;
4162
4163 if (isNonEqualMul(V1, V2, DemandedElts, Q, Depth) ||
4164 isNonEqualMul(V1: V2, V2: V1, DemandedElts, Q, Depth))
4165 return true;
4166
4167 if (isNonEqualShl(V1, V2, DemandedElts, Q, Depth) ||
4168 isNonEqualShl(V1: V2, V2: V1, DemandedElts, Q, Depth))
4169 return true;
4170
4171 if (V1->getType()->isIntOrIntVectorTy()) {
4172 // Are any known bits in V1 contradictory to known bits in V2? If V1
4173 // has a known zero where V2 has a known one, they must not be equal.
4174 KnownBits Known1 = computeKnownBits(V: V1, DemandedElts, Q, Depth);
4175 if (!Known1.isUnknown()) {
4176 KnownBits Known2 = computeKnownBits(V: V2, DemandedElts, Q, Depth);
4177 if (Known1.Zero.intersects(RHS: Known2.One) ||
4178 Known2.Zero.intersects(RHS: Known1.One))
4179 return true;
4180 }
4181 }
4182
4183 if (isNonEqualSelect(V1, V2, DemandedElts, Q, Depth) ||
4184 isNonEqualSelect(V1: V2, V2: V1, DemandedElts, Q, Depth))
4185 return true;
4186
4187 if (isNonEqualPointersWithRecursiveGEP(A: V1, B: V2, Q) ||
4188 isNonEqualPointersWithRecursiveGEP(A: V2, B: V1, Q))
4189 return true;
4190
4191 Value *A, *B;
4192 // PtrToInts are NonEqual if their Ptrs are NonEqual.
4193 // Check PtrToInt type matches the pointer size.
4194 if (match(V: V1, P: m_PtrToIntSameSize(DL: Q.DL, Op: m_Value(V&: A))) &&
4195 match(V: V2, P: m_PtrToIntSameSize(DL: Q.DL, Op: m_Value(V&: B))))
4196 return isKnownNonEqual(V1: A, V2: B, DemandedElts, Q, Depth: Depth + 1);
4197
4198 if (isKnownNonEqualFromContext(V1, V2, Q, Depth))
4199 return true;
4200
4201 return false;
4202}
4203
4204/// For vector constants, loop over the elements and find the constant with the
4205/// minimum number of sign bits. Return 0 if the value is not a vector constant
4206/// or if any element was not analyzed; otherwise, return the count for the
4207/// element with the minimum number of sign bits.
4208static unsigned computeNumSignBitsVectorConstant(const Value *V,
4209 const APInt &DemandedElts,
4210 unsigned TyBits) {
4211 const auto *CV = dyn_cast<Constant>(Val: V);
4212 if (!CV || !isa<FixedVectorType>(Val: CV->getType()))
4213 return 0;
4214
4215 unsigned MinSignBits = TyBits;
4216 unsigned NumElts = cast<FixedVectorType>(Val: CV->getType())->getNumElements();
4217 for (unsigned i = 0; i != NumElts; ++i) {
4218 if (!DemandedElts[i])
4219 continue;
4220 // If we find a non-ConstantInt, bail out.
4221 auto *Elt = dyn_cast_or_null<ConstantInt>(Val: CV->getAggregateElement(Elt: i));
4222 if (!Elt)
4223 return 0;
4224
4225 MinSignBits = std::min(a: MinSignBits, b: Elt->getValue().getNumSignBits());
4226 }
4227
4228 return MinSignBits;
4229}
4230
4231static unsigned ComputeNumSignBitsImpl(const Value *V,
4232 const APInt &DemandedElts,
4233 const SimplifyQuery &Q, unsigned Depth);
4234
/// Thin wrapper over ComputeNumSignBitsImpl that checks the invariant that
/// every value has at least one sign bit (the sign bit itself).
static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
                                   const SimplifyQuery &Q, unsigned Depth) {
  unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Q, Depth);
  assert(Result > 0 && "At least one sign bit needs to be present!");
  return Result;
}
4241
/// Return the number of times the sign bit of the register is replicated into
/// the other bits. We know that at least 1 bit is always equal to the sign bit
/// (itself), but other cases can give us information. For example, immediately
/// after an "ashr X, 2", we know that the top 3 bits are all equal to each
/// other, so we return 3. For vectors, return the number of sign bits for the
/// vector element with the minimum number of known sign bits of the demanded
/// elements in the vector specified by DemandedElts.
static unsigned ComputeNumSignBitsImpl(const Value *V,
                                       const APInt &DemandedElts,
                                       const SimplifyQuery &Q, unsigned Depth) {
  Type *Ty = V->getType();
#ifndef NDEBUG
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");

  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
    assert(
        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
        "DemandedElt width should equal the fixed vector number of elements");
  } else {
    assert(DemandedElts == APInt(1, 1) &&
           "DemandedElt width should be 1 for scalars");
  }
#endif

  // We return the minimum number of sign bits that are guaranteed to be present
  // in V, so for undef we have to conservatively return 1. We don't have the
  // same behavior for poison though -- that's a FIXME today.

  Type *ScalarTy = Ty->getScalarType();
  unsigned TyBits = ScalarTy->isPointerTy() ?
    Q.DL.getPointerTypeSizeInBits(ScalarTy) :
    Q.DL.getTypeSizeInBits(Ty: ScalarTy);

  // Scratch sign-bit counts reused by the per-opcode cases below.
  unsigned Tmp, Tmp2;
  unsigned FirstAnswer = 1;

  // Note that ConstantInt is handled by the general computeKnownBits case
  // below.

  if (Depth == MaxAnalysisRecursionDepth)
    return 1;

  if (auto *U = dyn_cast<Operator>(Val: V)) {
    switch (Operator::getOpcode(V)) {
    default: break;
    case Instruction::BitCast: {
      Value *Src = U->getOperand(i: 0);
      Type *SrcTy = Src->getType();

      // Skip if the source type is not an integer or integer vector type
      // This ensures we only process integer-like types
      if (!SrcTy->isIntOrIntVectorTy())
        break;

      unsigned SrcBits = SrcTy->getScalarSizeInBits();

      // Bitcast 'large element' scalar/vector to 'small element' vector.
      if ((SrcBits % TyBits) != 0)
        break;

      // Only proceed if the destination type is a fixed-size vector
      if (isa<FixedVectorType>(Val: Ty)) {
        // Fast case - sign splat can be simply split across the small elements.
        // This works for both vector and scalar sources
        Tmp = ComputeNumSignBits(V: Src, Q, Depth: Depth + 1);
        if (Tmp == SrcBits)
          return TyBits;
      }
      break;
    }
    case Instruction::SExt:
      // Every bit added by the extension duplicates the source sign bit.
      Tmp = TyBits - U->getOperand(i: 0)->getType()->getScalarSizeInBits();
      return ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1) +
             Tmp;

    case Instruction::SDiv: {
      const APInt *Denominator;
      // sdiv X, C -> adds log(C) sign bits.
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: Denominator))) {

        // Ignore non-positive denominator.
        if (!Denominator->isStrictlyPositive())
          break;

        // Calculate the incoming numerator bits.
        unsigned NumBits =
            ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);

        // Add floor(log(C)) bits to the numerator bits.
        return std::min(a: TyBits, b: NumBits + Denominator->logBase2());
      }
      break;
    }

    case Instruction::SRem: {
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);

      const APInt *Denominator;
      // srem X, C -> we know that the result is within [-C+1,C) when C is a
      // positive constant. This let us put a lower bound on the number of sign
      // bits.
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: Denominator))) {

        // Ignore non-positive denominator.
        if (Denominator->isStrictlyPositive()) {
          // Calculate the leading sign bit constraints by examining the
          // denominator. Given that the denominator is positive, there are two
          // cases:
          //
          //  1. The numerator is positive. The result range is [0,C) and
          //     [0,C) u< (1 << ceilLogBase2(C)).
          //
          //  2. The numerator is negative. Then the result range is (-C,0] and
          //     integers in (-C,0] are either 0 or >u (-1 << ceilLogBase2(C)).
          //
          // Thus a lower bound on the number of sign bits is `TyBits -
          // ceilLogBase2(C)`.

          unsigned ResBits = TyBits - Denominator->ceilLogBase2();
          Tmp = std::max(a: Tmp, b: ResBits);
        }
      }
      return Tmp;
    }

    case Instruction::AShr: {
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      // ashr X, C -> adds C sign bits. Vectors too.
      const APInt *ShAmt;
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: ShAmt))) {
        if (ShAmt->uge(RHS: TyBits))
          break; // Bad shift.
        unsigned ShAmtLimited = ShAmt->getZExtValue();
        Tmp += ShAmtLimited;
        if (Tmp > TyBits) Tmp = TyBits;
      }
      return Tmp;
    }
    case Instruction::Shl: {
      const APInt *ShAmt;
      Value *X = nullptr;
      if (match(V: U->getOperand(i: 1), P: m_APInt(Res&: ShAmt))) {
        // shl destroys sign bits.
        if (ShAmt->uge(RHS: TyBits))
          break; // Bad shift.
        // We can look through a zext (more or less treating it as a sext) if
        // all extended bits are shifted out.
        if (match(V: U->getOperand(i: 0), P: m_ZExt(Op: m_Value(V&: X))) &&
            ShAmt->uge(RHS: TyBits - X->getType()->getScalarSizeInBits())) {
          Tmp = ComputeNumSignBits(V: X, DemandedElts, Q, Depth: Depth + 1);
          Tmp += TyBits - X->getType()->getScalarSizeInBits();
        } else
          Tmp =
              ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
        if (ShAmt->uge(RHS: Tmp))
          break; // Shifted all sign bits out.
        Tmp2 = ShAmt->getZExtValue();
        return Tmp - Tmp2;
      }
      break;
    }
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor: // NOT is handled here.
      // Logical binary ops preserve the number of sign bits at the worst.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp != 1) {
        Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
        FirstAnswer = std::min(a: Tmp, b: Tmp2);
        // We computed what we know about the sign bits as our first
        // answer. Now proceed to the generic code that uses
        // computeKnownBits, and pick whichever answer is better.
      }
      break;

    case Instruction::Select: {
      // If we have a clamp pattern, we know that the number of sign bits will
      // be the minimum of the clamp min/max range.
      const Value *X;
      const APInt *CLow, *CHigh;
      if (isSignedMinMaxClamp(Select: U, In&: X, CLow, CHigh))
        return std::min(a: CLow->getNumSignBits(), b: CHigh->getNumSignBits());

      // Otherwise the result has the minimum sign bits of the two arms.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp == 1)
        break;
      Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 2), DemandedElts, Q, Depth: Depth + 1);
      return std::min(a: Tmp, b: Tmp2);
    }

    case Instruction::Add:
      // Add can have at most one carry bit. Thus we know that the output
      // is, at worst, one more bit than the inputs.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);
      if (Tmp == 1) break;

      // Special case decrementing a value (ADD X, -1):
      if (const auto *CRHS = dyn_cast<Constant>(Val: U->getOperand(i: 1)))
        if (CRHS->isAllOnesValue()) {
          KnownBits Known(TyBits);
          computeKnownBits(V: U->getOperand(i: 0), DemandedElts, Known, Q, Depth: Depth + 1);

          // If the input is known to be 0 or 1, the output is 0/-1, which is
          // all sign bits set.
          if ((Known.Zero | 1).isAllOnes())
            return TyBits;

          // If we are subtracting one from a positive number, there is no carry
          // out of the result.
          if (Known.isNonNegative())
            return Tmp;
        }

      Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp2 == 1)
        break;
      return std::min(a: Tmp, b: Tmp2) - 1;

    case Instruction::Sub:
      Tmp2 = ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp2 == 1)
        break;

      // Handle NEG.
      if (const auto *CLHS = dyn_cast<Constant>(Val: U->getOperand(i: 0)))
        if (CLHS->isNullValue()) {
          KnownBits Known(TyBits);
          computeKnownBits(V: U->getOperand(i: 1), DemandedElts, Known, Q, Depth: Depth + 1);
          // If the input is known to be 0 or 1, the output is 0/-1, which is
          // all sign bits set.
          if ((Known.Zero | 1).isAllOnes())
            return TyBits;

          // If the input is known to be positive (the sign bit is known clear),
          // the output of the NEG has the same number of sign bits as the
          // input.
          if (Known.isNonNegative())
            return Tmp2;

          // Otherwise, we treat this like a SUB.
        }

      // Sub can have at most one carry bit. Thus we know that the output
      // is, at worst, one more bit than the inputs.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      if (Tmp == 1)
        break;
      return std::min(a: Tmp, b: Tmp2) - 1;

    case Instruction::Mul: {
      // The output of the Mul can be at most twice the valid bits in the
      // inputs.
      unsigned SignBitsOp0 =
          ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
      if (SignBitsOp0 == 1)
        break;
      unsigned SignBitsOp1 =
          ComputeNumSignBits(V: U->getOperand(i: 1), DemandedElts, Q, Depth: Depth + 1);
      if (SignBitsOp1 == 1)
        break;
      unsigned OutValidBits =
          (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1);
      return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1;
    }

    case Instruction::PHI: {
      const PHINode *PN = cast<PHINode>(Val: U);
      unsigned NumIncomingValues = PN->getNumIncomingValues();
      // Don't analyze large in-degree PHIs.
      if (NumIncomingValues > 4) break;
      // Unreachable blocks may have zero-operand PHI nodes.
      if (NumIncomingValues == 0) break;

      // Take the minimum of all incoming values. This can't infinitely loop
      // because of our depth threshold.
      SimplifyQuery RecQ = Q.getWithoutCondContext();
      Tmp = TyBits;
      for (unsigned i = 0, e = NumIncomingValues; i != e; ++i) {
        if (Tmp == 1) return Tmp;
        // Evaluate each incoming value at its edge's terminator.
        RecQ.CxtI = PN->getIncomingBlock(i)->getTerminator();
        Tmp = std::min(a: Tmp, b: ComputeNumSignBits(V: PN->getIncomingValue(i),
                                                 DemandedElts, Q: RecQ, Depth: Depth + 1));
      }
      return Tmp;
    }

    case Instruction::Trunc: {
      // If the input contained enough sign bits that some remain after the
      // truncation, then we can make use of that. Otherwise we don't know
      // anything.
      Tmp = ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);
      unsigned OperandTyBits = U->getOperand(i: 0)->getType()->getScalarSizeInBits();
      if (Tmp > (OperandTyBits - TyBits))
        return Tmp - (OperandTyBits - TyBits);

      return 1;
    }

    case Instruction::ExtractElement:
      // Look through extract element. At the moment we keep this simple and
      // skip tracking the specific element. But at least we might find
      // information valid for all elements of the vector (for example if vector
      // is sign extended, shifted, etc).
      return ComputeNumSignBits(V: U->getOperand(i: 0), Q, Depth: Depth + 1);

    case Instruction::ShuffleVector: {
      // Collect the minimum number of sign bits that are shared by every vector
      // element referenced by the shuffle.
      auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: U);
      if (!Shuf) {
        // FIXME: Add support for shufflevector constant expressions.
        return 1;
      }
      APInt DemandedLHS, DemandedRHS;
      // For undef elements, we don't know anything about the common state of
      // the shuffle result.
      if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
        return 1;
      Tmp = std::numeric_limits<unsigned>::max();
      if (!!DemandedLHS) {
        const Value *LHS = Shuf->getOperand(i_nocapture: 0);
        Tmp = ComputeNumSignBits(V: LHS, DemandedElts: DemandedLHS, Q, Depth: Depth + 1);
      }
      // If we don't know anything, early out and try computeKnownBits
      // fall-back.
      if (Tmp == 1)
        break;
      if (!!DemandedRHS) {
        const Value *RHS = Shuf->getOperand(i_nocapture: 1);
        Tmp2 = ComputeNumSignBits(V: RHS, DemandedElts: DemandedRHS, Q, Depth: Depth + 1);
        Tmp = std::min(a: Tmp, b: Tmp2);
      }
      // If we don't know anything, early out and try computeKnownBits
      // fall-back.
      if (Tmp == 1)
        break;
      assert(Tmp <= TyBits && "Failed to determine minimum sign bits");
      return Tmp;
    }
    case Instruction::Call: {
      if (const auto *II = dyn_cast<IntrinsicInst>(Val: U)) {
        switch (II->getIntrinsicID()) {
        default:
          break;
        case Intrinsic::abs:
          Tmp =
              ComputeNumSignBits(V: U->getOperand(i: 0), DemandedElts, Q, Depth: Depth + 1);
          if (Tmp == 1)
            break;

          // Absolute value reduces number of sign bits by at most 1.
          return Tmp - 1;
        case Intrinsic::smin:
        case Intrinsic::smax: {
          // A signed clamp bounds the result by the two constants' ranges.
          const APInt *CLow, *CHigh;
          if (isSignedMinMaxIntrinsicClamp(II, CLow, CHigh))
            return std::min(a: CLow->getNumSignBits(), b: CHigh->getNumSignBits());
        }
        }
      }
    }
    }
  }

  // Finally, if we can prove that the top bits of the result are 0's or 1's,
  // use this information.

  // If we can examine all elements of a vector constant successfully, we're
  // done (we can't do any better than that). If not, keep trying.
  if (unsigned VecSignBits =
          computeNumSignBitsVectorConstant(V, DemandedElts, TyBits))
    return VecSignBits;

  KnownBits Known(TyBits);
  computeKnownBits(V, DemandedElts, Known, Q, Depth);

  // If we know that the sign bit is either zero or one, determine the number of
  // identical bits in the top of the input value.
  return std::max(a: FirstAnswer, b: Known.countMinSignBits());
}
4622
/// Map a call site to the LLVM intrinsic describing its semantics, if any.
/// Direct intrinsic calls are returned as-is; otherwise, recognized read-only
/// libm-style library calls are translated to the equivalent intrinsic.
/// Returns Intrinsic::not_intrinsic if no mapping applies.
Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
                                            const TargetLibraryInfo *TLI) {
  const Function *F = CB.getCalledFunction();
  // Indirect calls cannot be mapped.
  if (!F)
    return Intrinsic::not_intrinsic;

  if (F->isIntrinsic())
    return F->getIntrinsicID();

  // We are going to infer semantics of a library function based on mapping it
  // to an LLVM intrinsic. Check that the library function is available from
  // this callbase and in this environment.
  LibFunc Func;
  if (F->hasLocalLinkage() || !TLI || !TLI->getLibFunc(CB, F&: Func) ||
      !CB.onlyReadsMemory())
    return Intrinsic::not_intrinsic;

  // Each group covers the double/float/long-double variants of one function.
  switch (Func) {
  default:
    break;
  case LibFunc_sin:
  case LibFunc_sinf:
  case LibFunc_sinl:
    return Intrinsic::sin;
  case LibFunc_cos:
  case LibFunc_cosf:
  case LibFunc_cosl:
    return Intrinsic::cos;
  case LibFunc_tan:
  case LibFunc_tanf:
  case LibFunc_tanl:
    return Intrinsic::tan;
  case LibFunc_asin:
  case LibFunc_asinf:
  case LibFunc_asinl:
    return Intrinsic::asin;
  case LibFunc_acos:
  case LibFunc_acosf:
  case LibFunc_acosl:
    return Intrinsic::acos;
  case LibFunc_atan:
  case LibFunc_atanf:
  case LibFunc_atanl:
    return Intrinsic::atan;
  case LibFunc_atan2:
  case LibFunc_atan2f:
  case LibFunc_atan2l:
    return Intrinsic::atan2;
  case LibFunc_sinh:
  case LibFunc_sinhf:
  case LibFunc_sinhl:
    return Intrinsic::sinh;
  case LibFunc_cosh:
  case LibFunc_coshf:
  case LibFunc_coshl:
    return Intrinsic::cosh;
  case LibFunc_tanh:
  case LibFunc_tanhf:
  case LibFunc_tanhl:
    return Intrinsic::tanh;
  case LibFunc_exp:
  case LibFunc_expf:
  case LibFunc_expl:
    return Intrinsic::exp;
  case LibFunc_exp2:
  case LibFunc_exp2f:
  case LibFunc_exp2l:
    return Intrinsic::exp2;
  case LibFunc_exp10:
  case LibFunc_exp10f:
  case LibFunc_exp10l:
    return Intrinsic::exp10;
  case LibFunc_log:
  case LibFunc_logf:
  case LibFunc_logl:
    return Intrinsic::log;
  case LibFunc_log10:
  case LibFunc_log10f:
  case LibFunc_log10l:
    return Intrinsic::log10;
  case LibFunc_log2:
  case LibFunc_log2f:
  case LibFunc_log2l:
    return Intrinsic::log2;
  case LibFunc_fabs:
  case LibFunc_fabsf:
  case LibFunc_fabsl:
    return Intrinsic::fabs;
  case LibFunc_fmin:
  case LibFunc_fminf:
  case LibFunc_fminl:
    return Intrinsic::minnum;
  case LibFunc_fmax:
  case LibFunc_fmaxf:
  case LibFunc_fmaxl:
    return Intrinsic::maxnum;
  case LibFunc_copysign:
  case LibFunc_copysignf:
  case LibFunc_copysignl:
    return Intrinsic::copysign;
  case LibFunc_floor:
  case LibFunc_floorf:
  case LibFunc_floorl:
    return Intrinsic::floor;
  case LibFunc_ceil:
  case LibFunc_ceilf:
  case LibFunc_ceill:
    return Intrinsic::ceil;
  case LibFunc_trunc:
  case LibFunc_truncf:
  case LibFunc_truncl:
    return Intrinsic::trunc;
  case LibFunc_rint:
  case LibFunc_rintf:
  case LibFunc_rintl:
    return Intrinsic::rint;
  case LibFunc_nearbyint:
  case LibFunc_nearbyintf:
  case LibFunc_nearbyintl:
    return Intrinsic::nearbyint;
  case LibFunc_round:
  case LibFunc_roundf:
  case LibFunc_roundl:
    return Intrinsic::round;
  case LibFunc_roundeven:
  case LibFunc_roundevenf:
  case LibFunc_roundevenl:
    return Intrinsic::roundeven;
  case LibFunc_pow:
  case LibFunc_powf:
  case LibFunc_powl:
    return Intrinsic::pow;
  case LibFunc_sqrt:
  case LibFunc_sqrtf:
  case LibFunc_sqrtl:
    return Intrinsic::sqrt;
  }

  return Intrinsic::not_intrinsic;
}
4763
4764/// Given an exploded icmp instruction, return true if the comparison only
4765/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if
4766/// the result of the comparison is true when the input value is signed.
4767bool llvm::isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
4768 bool &TrueIfSigned) {
4769 switch (Pred) {
4770 case ICmpInst::ICMP_SLT: // True if LHS s< 0
4771 TrueIfSigned = true;
4772 return RHS.isZero();
4773 case ICmpInst::ICMP_SLE: // True if LHS s<= -1
4774 TrueIfSigned = true;
4775 return RHS.isAllOnes();
4776 case ICmpInst::ICMP_SGT: // True if LHS s> -1
4777 TrueIfSigned = false;
4778 return RHS.isAllOnes();
4779 case ICmpInst::ICMP_SGE: // True if LHS s>= 0
4780 TrueIfSigned = false;
4781 return RHS.isZero();
4782 case ICmpInst::ICMP_UGT:
4783 // True if LHS u> RHS and RHS == sign-bit-mask - 1
4784 TrueIfSigned = true;
4785 return RHS.isMaxSignedValue();
4786 case ICmpInst::ICMP_UGE:
4787 // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4788 TrueIfSigned = true;
4789 return RHS.isMinSignedValue();
4790 case ICmpInst::ICMP_ULT:
4791 // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc)
4792 TrueIfSigned = false;
4793 return RHS.isMinSignedValue();
4794 case ICmpInst::ICMP_ULE:
4795 // True if LHS u<= RHS and RHS == sign-bit-mask - 1
4796 TrueIfSigned = false;
4797 return RHS.isMaxSignedValue();
4798 default:
4799 return false;
4800 }
4801}
4802
/// Refine \p KnownFromContext with the floating-point class facts about \p V
/// that are implied by the condition \p Cond evaluating to \p CondIsTrue.
/// Recurses through logical and/or and 'not' up to MaxAnalysisRecursionDepth,
/// then recognizes fcmp-vs-constant, llvm.is.fpclass, and integer sign-bit
/// tests on the value's bit pattern.
static void computeKnownFPClassFromCond(const Value *V, Value *Cond,
                                        bool CondIsTrue,
                                        const Instruction *CxtI,
                                        KnownFPClass &KnownFromContext,
                                        unsigned Depth = 0) {
  Value *A, *B;
  // A true 'and' implies both conjuncts are true; a false 'or' implies both
  // disjuncts are false. Either way both operands can be used.
  if (Depth < MaxAnalysisRecursionDepth &&
      (CondIsTrue ? match(V: Cond, P: m_LogicalAnd(L: m_Value(V&: A), R: m_Value(V&: B)))
                  : match(V: Cond, P: m_LogicalOr(L: m_Value(V&: A), R: m_Value(V&: B))))) {
    computeKnownFPClassFromCond(V, Cond: A, CondIsTrue, CxtI, KnownFromContext,
                                Depth: Depth + 1);
    computeKnownFPClassFromCond(V, Cond: B, CondIsTrue, CxtI, KnownFromContext,
                                Depth: Depth + 1);
    return;
  }
  // not(C) being true is the same as C being false (and vice versa).
  if (Depth < MaxAnalysisRecursionDepth && match(V: Cond, P: m_Not(V: m_Value(V&: A)))) {
    computeKnownFPClassFromCond(V, Cond: A, CondIsTrue: !CondIsTrue, CxtI, KnownFromContext,
                                Depth: Depth + 1);
    return;
  }
  CmpPredicate Pred;
  Value *LHS;
  uint64_t ClassVal = 0;
  const APFloat *CRHS;
  const APInt *RHS;
  if (match(V: Cond, P: m_FCmp(Pred, L: m_Value(V&: LHS), R: m_APFloat(Res&: CRHS)))) {
    // An fcmp against a constant rules out class bits of the compared value.
    auto [CmpVal, MaskIfTrue, MaskIfFalse] = fcmpImpliesClass(
        Pred, F: *cast<Instruction>(Val: Cond)->getParent()->getParent(), LHS, ConstRHS: *CRHS,
        LookThroughSrc: LHS != V);
    if (CmpVal == V)
      KnownFromContext.knownNot(RuleOut: ~(CondIsTrue ? MaskIfTrue : MaskIfFalse));
  } else if (match(V: Cond, P: m_Intrinsic<Intrinsic::is_fpclass>(
                       Op0: m_Specific(V), Op1: m_ConstantInt(V&: ClassVal)))) {
    // llvm.is.fpclass(V, Mask) directly asserts (or refutes) a class mask.
    FPClassTest Mask = static_cast<FPClassTest>(ClassVal);
    KnownFromContext.knownNot(RuleOut: CondIsTrue ? ~Mask : Mask);
  } else if (match(V: Cond, P: m_ICmp(Pred, L: m_ElementWiseBitCast(Op: m_Specific(V)),
                            R: m_APInt(Res&: RHS)))) {
    // An integer sign-bit test on the bitcast value pins the FP sign bit.
    bool TrueIfSigned;
    if (!isSignBitCheck(Pred, RHS: *RHS, TrueIfSigned))
      return;
    if (TrueIfSigned == CondIsTrue)
      KnownFromContext.signBitMustBeOne();
    else
      KnownFromContext.signBitMustBeZero();
  }
}
4849
/// Collect floating-point class information about \p V that is implied by
/// the surrounding context at Q.CxtI. Three sources are consulted, each
/// funneled through computeKnownFPClassFromCond:
///  1. The cached condition in Q.CC, when V is among its affected values.
///  2. Dominating branch conditions recorded in Q.DC (both edge directions,
///     yielding a condition known true or known false respectively).
///  3. llvm.assume calls registered for V in Q.AC that are valid at the
///     context instruction.
static KnownFPClass computeKnownFPClassFromContext(const Value *V,
                                                   const SimplifyQuery &Q) {
  KnownFPClass KnownFromContext;

  // Cached condition: Invert indicates whether the cached condition is known
  // false rather than true.
  if (Q.CC && Q.CC->AffectedValues.contains(Ptr: V))
    computeKnownFPClassFromCond(V, Cond: Q.CC->Cond, CondIsTrue: !Q.CC->Invert, CxtI: Q.CxtI,
                                KnownFromContext);

  // The remaining sources need a context instruction to test dominance /
  // assume validity against.
  if (!Q.CxtI)
    return KnownFromContext;

  if (Q.DC && Q.DT) {
    // Handle dominating conditions.
    for (BranchInst *BI : Q.DC->conditionsFor(V)) {
      Value *Cond = BI->getCondition();

      // If the true edge dominates the context, the condition holds there.
      BasicBlockEdge Edge0(BI->getParent(), BI->getSuccessor(i: 0));
      if (Q.DT->dominates(BBE: Edge0, BB: Q.CxtI->getParent()))
        computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/true, CxtI: Q.CxtI,
                                    KnownFromContext);

      // If the false edge dominates the context, the condition is false.
      BasicBlockEdge Edge1(BI->getParent(), BI->getSuccessor(i: 1));
      if (Q.DT->dominates(BBE: Edge1, BB: Q.CxtI->getParent()))
        computeKnownFPClassFromCond(V, Cond, /*CondIsTrue=*/false, CxtI: Q.CxtI,
                                    KnownFromContext);
    }
  }

  if (!Q.AC)
    return KnownFromContext;

  // Try to restrict the floating-point classes based on information from
  // assumptions.
  for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
    // Assumption handles may have been invalidated (deleted instructions).
    if (!AssumeVH)
      continue;
    CallInst *I = cast<CallInst>(Val&: AssumeVH);

    assert(I->getFunction() == Q.CxtI->getParent()->getParent() &&
           "Got assumption for the wrong function!");
    assert(I->getIntrinsicID() == Intrinsic::assume &&
           "must be an assume intrinsic");

    // The assume must be reachable/valid at the query point to apply.
    if (!isValidAssumeForContext(Inv: I, CxtI: Q.CxtI, DT: Q.DT))
      continue;

    computeKnownFPClassFromCond(V, Cond: I->getArgOperand(i: 0),
                                /*CondIsTrue=*/true, CxtI: Q.CxtI, KnownFromContext);
  }

  return KnownFromContext;
}
4902
4903void llvm::adjustKnownFPClassForSelectArm(KnownFPClass &Known, Value *Cond,
4904 Value *Arm, bool Invert,
4905 const SimplifyQuery &SQ,
4906 unsigned Depth) {
4907
4908 KnownFPClass KnownSrc;
4909 computeKnownFPClassFromCond(V: Arm, Cond,
4910 /*CondIsTrue=*/!Invert, CxtI: SQ.CxtI, KnownFromContext&: KnownSrc,
4911 Depth: Depth + 1);
4912 KnownSrc = KnownSrc.unionWith(RHS: Known);
4913 if (KnownSrc.isUnknown())
4914 return;
4915
4916 if (isGuaranteedNotToBeUndef(V: Arm, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT, Depth: Depth + 1))
4917 Known = KnownSrc;
4918}
4919
4920void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
4921 FPClassTest InterestedClasses, KnownFPClass &Known,
4922 const SimplifyQuery &Q, unsigned Depth);
4923
4924static void computeKnownFPClass(const Value *V, KnownFPClass &Known,
4925 FPClassTest InterestedClasses,
4926 const SimplifyQuery &Q, unsigned Depth) {
4927 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
4928 APInt DemandedElts =
4929 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
4930 computeKnownFPClass(V, DemandedElts, InterestedClasses, Known, Q, Depth);
4931}
4932
4933static void computeKnownFPClassForFPTrunc(const Operator *Op,
4934 const APInt &DemandedElts,
4935 FPClassTest InterestedClasses,
4936 KnownFPClass &Known,
4937 const SimplifyQuery &Q,
4938 unsigned Depth) {
4939 if ((InterestedClasses &
4940 (KnownFPClass::OrderedLessThanZeroMask | fcNan)) == fcNone)
4941 return;
4942
4943 KnownFPClass KnownSrc;
4944 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
4945 Known&: KnownSrc, Q, Depth: Depth + 1);
4946 Known = KnownFPClass::fptrunc(KnownSrc);
4947}
4948
4949static constexpr KnownFPClass::MinMaxKind getMinMaxKind(Intrinsic::ID IID) {
4950 switch (IID) {
4951 case Intrinsic::minimum:
4952 return KnownFPClass::MinMaxKind::minimum;
4953 case Intrinsic::maximum:
4954 return KnownFPClass::MinMaxKind::maximum;
4955 case Intrinsic::minimumnum:
4956 return KnownFPClass::MinMaxKind::minimumnum;
4957 case Intrinsic::maximumnum:
4958 return KnownFPClass::MinMaxKind::maximumnum;
4959 case Intrinsic::minnum:
4960 return KnownFPClass::MinMaxKind::minnum;
4961 case Intrinsic::maxnum:
4962 return KnownFPClass::MinMaxKind::maxnum;
4963 default:
4964 llvm_unreachable("not a floating-point min-max intrinsic");
4965 }
4966}
4967
4968void computeKnownFPClass(const Value *V, const APInt &DemandedElts,
4969 FPClassTest InterestedClasses, KnownFPClass &Known,
4970 const SimplifyQuery &Q, unsigned Depth) {
4971 assert(Known.isUnknown() && "should not be called with known information");
4972
4973 if (!DemandedElts) {
4974 // No demanded elts, better to assume we don't know anything.
4975 Known.resetAll();
4976 return;
4977 }
4978
4979 assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
4980
4981 if (auto *CFP = dyn_cast<ConstantFP>(Val: V)) {
4982 Known = KnownFPClass(CFP->getValueAPF());
4983 return;
4984 }
4985
4986 if (isa<ConstantAggregateZero>(Val: V)) {
4987 Known.KnownFPClasses = fcPosZero;
4988 Known.SignBit = false;
4989 return;
4990 }
4991
4992 if (isa<PoisonValue>(Val: V)) {
4993 Known.KnownFPClasses = fcNone;
4994 Known.SignBit = false;
4995 return;
4996 }
4997
4998 // Try to handle fixed width vector constants
4999 auto *VFVTy = dyn_cast<FixedVectorType>(Val: V->getType());
5000 const Constant *CV = dyn_cast<Constant>(Val: V);
5001 if (VFVTy && CV) {
5002 Known.KnownFPClasses = fcNone;
5003 bool SignBitAllZero = true;
5004 bool SignBitAllOne = true;
5005
5006 // For vectors, verify that each element is not NaN.
5007 unsigned NumElts = VFVTy->getNumElements();
5008 for (unsigned i = 0; i != NumElts; ++i) {
5009 if (!DemandedElts[i])
5010 continue;
5011
5012 Constant *Elt = CV->getAggregateElement(Elt: i);
5013 if (!Elt) {
5014 Known = KnownFPClass();
5015 return;
5016 }
5017 if (isa<PoisonValue>(Val: Elt))
5018 continue;
5019 auto *CElt = dyn_cast<ConstantFP>(Val: Elt);
5020 if (!CElt) {
5021 Known = KnownFPClass();
5022 return;
5023 }
5024
5025 const APFloat &C = CElt->getValueAPF();
5026 Known.KnownFPClasses |= C.classify();
5027 if (C.isNegative())
5028 SignBitAllZero = false;
5029 else
5030 SignBitAllOne = false;
5031 }
5032 if (SignBitAllOne != SignBitAllZero)
5033 Known.SignBit = SignBitAllOne;
5034 return;
5035 }
5036
5037 FPClassTest KnownNotFromFlags = fcNone;
5038 if (const auto *CB = dyn_cast<CallBase>(Val: V))
5039 KnownNotFromFlags |= CB->getRetNoFPClass();
5040 else if (const auto *Arg = dyn_cast<Argument>(Val: V))
5041 KnownNotFromFlags |= Arg->getNoFPClass();
5042
5043 const Operator *Op = dyn_cast<Operator>(Val: V);
5044 if (const FPMathOperator *FPOp = dyn_cast_or_null<FPMathOperator>(Val: Op)) {
5045 if (FPOp->hasNoNaNs())
5046 KnownNotFromFlags |= fcNan;
5047 if (FPOp->hasNoInfs())
5048 KnownNotFromFlags |= fcInf;
5049 }
5050
5051 KnownFPClass AssumedClasses = computeKnownFPClassFromContext(V, Q);
5052 KnownNotFromFlags |= ~AssumedClasses.KnownFPClasses;
5053
5054 // We no longer need to find out about these bits from inputs if we can
5055 // assume this from flags/attributes.
5056 InterestedClasses &= ~KnownNotFromFlags;
5057
5058 llvm::scope_exit ClearClassesFromFlags([=, &Known] {
5059 Known.knownNot(RuleOut: KnownNotFromFlags);
5060 if (!Known.SignBit && AssumedClasses.SignBit) {
5061 if (*AssumedClasses.SignBit)
5062 Known.signBitMustBeOne();
5063 else
5064 Known.signBitMustBeZero();
5065 }
5066 });
5067
5068 if (!Op)
5069 return;
5070
5071 // All recursive calls that increase depth must come after this.
5072 if (Depth == MaxAnalysisRecursionDepth)
5073 return;
5074
5075 const unsigned Opc = Op->getOpcode();
5076 switch (Opc) {
5077 case Instruction::FNeg: {
5078 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
5079 Known, Q, Depth: Depth + 1);
5080 Known.fneg();
5081 break;
5082 }
5083 case Instruction::Select: {
5084 auto ComputeForArm = [&](Value *Arm, bool Invert) {
5085 KnownFPClass Res;
5086 computeKnownFPClass(V: Arm, DemandedElts, InterestedClasses, Known&: Res, Q,
5087 Depth: Depth + 1);
5088 adjustKnownFPClassForSelectArm(Known&: Res, Cond: Op->getOperand(i: 0), Arm, Invert, SQ: Q,
5089 Depth);
5090 return Res;
5091 };
5092 // Only known if known in both the LHS and RHS.
5093 Known =
5094 ComputeForArm(Op->getOperand(i: 1), /*Invert=*/false)
5095 .intersectWith(RHS: ComputeForArm(Op->getOperand(i: 2), /*Invert=*/true));
5096 break;
5097 }
5098 case Instruction::Load: {
5099 const MDNode *NoFPClass =
5100 cast<LoadInst>(Val: Op)->getMetadata(KindID: LLVMContext::MD_nofpclass);
5101 if (!NoFPClass)
5102 break;
5103
5104 ConstantInt *MaskVal =
5105 mdconst::extract<ConstantInt>(MD: NoFPClass->getOperand(I: 0));
5106 Known.knownNot(RuleOut: static_cast<FPClassTest>(MaskVal->getZExtValue()));
5107 break;
5108 }
5109 case Instruction::Call: {
5110 const CallInst *II = cast<CallInst>(Val: Op);
5111 const Intrinsic::ID IID = II->getIntrinsicID();
5112 switch (IID) {
5113 case Intrinsic::fabs: {
5114 if ((InterestedClasses & (fcNan | fcPositive)) != fcNone) {
5115 // If we only care about the sign bit we don't need to inspect the
5116 // operand.
5117 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5118 InterestedClasses, Known, Q, Depth: Depth + 1);
5119 }
5120
5121 Known.fabs();
5122 break;
5123 }
5124 case Intrinsic::copysign: {
5125 KnownFPClass KnownSign;
5126
5127 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5128 Known, Q, Depth: Depth + 1);
5129 computeKnownFPClass(V: II->getArgOperand(i: 1), DemandedElts, InterestedClasses,
5130 Known&: KnownSign, Q, Depth: Depth + 1);
5131 Known.copysign(Sign: KnownSign);
5132 break;
5133 }
5134 case Intrinsic::fma:
5135 case Intrinsic::fmuladd: {
5136 if ((InterestedClasses & fcNegative) == fcNone)
5137 break;
5138
5139 // FIXME: This should check isGuaranteedNotToBeUndef
5140 if (II->getArgOperand(i: 0) == II->getArgOperand(i: 1)) {
5141 KnownFPClass KnownSrc, KnownAddend;
5142 computeKnownFPClass(V: II->getArgOperand(i: 2), DemandedElts,
5143 InterestedClasses, Known&: KnownAddend, Q, Depth: Depth + 1);
5144 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5145 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5146
5147 const Function *F = II->getFunction();
5148 const fltSemantics &FltSem =
5149 II->getType()->getScalarType()->getFltSemantics();
5150 DenormalMode Mode =
5151 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5152
5153 if (KnownNotFromFlags & fcNan) {
5154 KnownSrc.knownNot(RuleOut: fcNan);
5155 KnownAddend.knownNot(RuleOut: fcNan);
5156 }
5157
5158 if (KnownNotFromFlags & fcInf) {
5159 KnownSrc.knownNot(RuleOut: fcInf);
5160 KnownAddend.knownNot(RuleOut: fcInf);
5161 }
5162
5163 Known = KnownFPClass::fma_square(Squared: KnownSrc, Addend: KnownAddend, Mode);
5164 break;
5165 }
5166
5167 KnownFPClass KnownSrc[3];
5168 for (int I = 0; I != 3; ++I) {
5169 computeKnownFPClass(V: II->getArgOperand(i: I), DemandedElts,
5170 InterestedClasses, Known&: KnownSrc[I], Q, Depth: Depth + 1);
5171 if (KnownSrc[I].isUnknown())
5172 return;
5173
5174 if (KnownNotFromFlags & fcNan)
5175 KnownSrc[I].knownNot(RuleOut: fcNan);
5176 if (KnownNotFromFlags & fcInf)
5177 KnownSrc[I].knownNot(RuleOut: fcInf);
5178 }
5179
5180 const Function *F = II->getFunction();
5181 const fltSemantics &FltSem =
5182 II->getType()->getScalarType()->getFltSemantics();
5183 DenormalMode Mode =
5184 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5185 Known = KnownFPClass::fma(LHS: KnownSrc[0], RHS: KnownSrc[1], Addend: KnownSrc[2], Mode);
5186 break;
5187 }
5188 case Intrinsic::sqrt:
5189 case Intrinsic::experimental_constrained_sqrt: {
5190 KnownFPClass KnownSrc;
5191 FPClassTest InterestedSrcs = InterestedClasses;
5192 if (InterestedClasses & fcNan)
5193 InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5194
5195 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5196 Known&: KnownSrc, Q, Depth: Depth + 1);
5197
5198 DenormalMode Mode = DenormalMode::getDynamic();
5199
5200 bool HasNSZ = Q.IIQ.hasNoSignedZeros(Op: II);
5201 if (!HasNSZ) {
5202 const Function *F = II->getFunction();
5203 const fltSemantics &FltSem =
5204 II->getType()->getScalarType()->getFltSemantics();
5205 Mode = F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5206 }
5207
5208 Known = KnownFPClass::sqrt(Src: KnownSrc, Mode);
5209 if (HasNSZ)
5210 Known.knownNot(RuleOut: fcNegZero);
5211
5212 break;
5213 }
5214 case Intrinsic::sin:
5215 case Intrinsic::cos: {
5216 // Return NaN on infinite inputs.
5217 KnownFPClass KnownSrc;
5218 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5219 Known&: KnownSrc, Q, Depth: Depth + 1);
5220 Known = IID == Intrinsic::sin ? KnownFPClass::sin(Src: KnownSrc)
5221 : KnownFPClass::cos(Src: KnownSrc);
5222 break;
5223 }
5224 case Intrinsic::maxnum:
5225 case Intrinsic::minnum:
5226 case Intrinsic::minimum:
5227 case Intrinsic::maximum:
5228 case Intrinsic::minimumnum:
5229 case Intrinsic::maximumnum: {
5230 KnownFPClass KnownLHS, KnownRHS;
5231 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5232 Known&: KnownLHS, Q, Depth: Depth + 1);
5233 computeKnownFPClass(V: II->getArgOperand(i: 1), DemandedElts, InterestedClasses,
5234 Known&: KnownRHS, Q, Depth: Depth + 1);
5235
5236 const Function *F = II->getFunction();
5237
5238 DenormalMode Mode =
5239 F ? F->getDenormalMode(
5240 FPType: II->getType()->getScalarType()->getFltSemantics())
5241 : DenormalMode::getDynamic();
5242
5243 Known = KnownFPClass::minMaxLike(LHS: KnownLHS, RHS: KnownRHS, Kind: getMinMaxKind(IID),
5244 DenormMode: Mode);
5245 break;
5246 }
5247 case Intrinsic::canonicalize: {
5248 KnownFPClass KnownSrc;
5249 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5250 Known&: KnownSrc, Q, Depth: Depth + 1);
5251
5252 const Function *F = II->getFunction();
5253 DenormalMode DenormMode =
5254 F ? F->getDenormalMode(
5255 FPType: II->getType()->getScalarType()->getFltSemantics())
5256 : DenormalMode::getDynamic();
5257 Known = KnownFPClass::canonicalize(Src: KnownSrc, DenormMode);
5258 break;
5259 }
5260 case Intrinsic::vector_reduce_fmax:
5261 case Intrinsic::vector_reduce_fmin:
5262 case Intrinsic::vector_reduce_fmaximum:
5263 case Intrinsic::vector_reduce_fminimum: {
5264 // reduce min/max will choose an element from one of the vector elements,
5265 // so we can infer and class information that is common to all elements.
5266 Known = computeKnownFPClass(V: II->getArgOperand(i: 0), FMF: II->getFastMathFlags(),
5267 InterestedClasses, SQ: Q, Depth: Depth + 1);
5268 // Can only propagate sign if output is never NaN.
5269 if (!Known.isKnownNeverNaN())
5270 Known.SignBit.reset();
5271 break;
5272 }
5273 // reverse preserves all characteristics of the input vec's element.
5274 case Intrinsic::vector_reverse:
5275 Known = computeKnownFPClass(
5276 V: II->getArgOperand(i: 0), DemandedElts: DemandedElts.reverseBits(),
5277 FMF: II->getFastMathFlags(), InterestedClasses, SQ: Q, Depth: Depth + 1);
5278 break;
5279 case Intrinsic::trunc:
5280 case Intrinsic::floor:
5281 case Intrinsic::ceil:
5282 case Intrinsic::rint:
5283 case Intrinsic::nearbyint:
5284 case Intrinsic::round:
5285 case Intrinsic::roundeven: {
5286 KnownFPClass KnownSrc;
5287 FPClassTest InterestedSrcs = InterestedClasses;
5288 if (InterestedSrcs & fcPosFinite)
5289 InterestedSrcs |= fcPosFinite;
5290 if (InterestedSrcs & fcNegFinite)
5291 InterestedSrcs |= fcNegFinite;
5292 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5293 Known&: KnownSrc, Q, Depth: Depth + 1);
5294
5295 Known = KnownFPClass::roundToIntegral(
5296 Src: KnownSrc, IsTrunc: IID == Intrinsic::trunc,
5297 IsMultiUnitFPType: V->getType()->getScalarType()->isMultiUnitFPType());
5298 break;
5299 }
5300 case Intrinsic::exp:
5301 case Intrinsic::exp2:
5302 case Intrinsic::exp10:
5303 case Intrinsic::amdgcn_exp2: {
5304 KnownFPClass KnownSrc;
5305 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5306 Known&: KnownSrc, Q, Depth: Depth + 1);
5307
5308 Known = KnownFPClass::exp(Src: KnownSrc);
5309
5310 Type *EltTy = II->getType()->getScalarType();
5311 if (IID == Intrinsic::amdgcn_exp2 && EltTy->isFloatTy())
5312 Known.knownNot(RuleOut: fcSubnormal);
5313
5314 break;
5315 }
5316 case Intrinsic::fptrunc_round: {
5317 computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known,
5318 Q, Depth);
5319 break;
5320 }
5321 case Intrinsic::log:
5322 case Intrinsic::log10:
5323 case Intrinsic::log2:
5324 case Intrinsic::experimental_constrained_log:
5325 case Intrinsic::experimental_constrained_log10:
5326 case Intrinsic::experimental_constrained_log2:
5327 case Intrinsic::amdgcn_log: {
5328 Type *EltTy = II->getType()->getScalarType();
5329
5330 // log(+inf) -> +inf
5331 // log([+-]0.0) -> -inf
5332 // log(-inf) -> nan
5333 // log(-x) -> nan
5334 if ((InterestedClasses & (fcNan | fcInf)) != fcNone) {
5335 FPClassTest InterestedSrcs = InterestedClasses;
5336 if ((InterestedClasses & fcNegInf) != fcNone)
5337 InterestedSrcs |= fcZero | fcSubnormal;
5338 if ((InterestedClasses & fcNan) != fcNone)
5339 InterestedSrcs |= fcNan | fcNegative;
5340
5341 KnownFPClass KnownSrc;
5342 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5343 Known&: KnownSrc, Q, Depth: Depth + 1);
5344
5345 const Function *F = II->getFunction();
5346 DenormalMode Mode = F ? F->getDenormalMode(FPType: EltTy->getFltSemantics())
5347 : DenormalMode::getDynamic();
5348 Known = KnownFPClass::log(Src: KnownSrc, Mode);
5349 }
5350
5351 break;
5352 }
5353 case Intrinsic::powi: {
5354 if ((InterestedClasses & fcNegative) == fcNone)
5355 break;
5356
5357 const Value *Exp = II->getArgOperand(i: 1);
5358 Type *ExpTy = Exp->getType();
5359 unsigned BitWidth = ExpTy->getScalarType()->getIntegerBitWidth();
5360 KnownBits ExponentKnownBits(BitWidth);
5361 computeKnownBits(V: Exp, DemandedElts: isa<VectorType>(Val: ExpTy) ? DemandedElts : APInt(1, 1),
5362 Known&: ExponentKnownBits, Q, Depth: Depth + 1);
5363
5364 KnownFPClass KnownSrc;
5365 if (ExponentKnownBits.isZero() || !ExponentKnownBits.isEven()) {
5366 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses: fcNegative,
5367 Known&: KnownSrc, Q, Depth: Depth + 1);
5368 }
5369
5370 Known = KnownFPClass::powi(Src: KnownSrc, N: ExponentKnownBits);
5371 break;
5372 }
5373 case Intrinsic::ldexp: {
5374 KnownFPClass KnownSrc;
5375 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5376 Known&: KnownSrc, Q, Depth: Depth + 1);
5377 // Can refine inf/zero handling based on the exponent operand.
5378 const FPClassTest ExpInfoMask = fcZero | fcSubnormal | fcInf;
5379
5380 KnownBits ExpBits;
5381 if ((KnownSrc.KnownFPClasses & ExpInfoMask) != fcNone) {
5382 const Value *ExpArg = II->getArgOperand(i: 1);
5383 ExpBits = computeKnownBits(V: ExpArg, DemandedElts, Q, Depth: Depth + 1);
5384 }
5385
5386 const fltSemantics &Flt =
5387 II->getType()->getScalarType()->getFltSemantics();
5388
5389 const Function *F = II->getFunction();
5390 DenormalMode Mode =
5391 F ? F->getDenormalMode(FPType: Flt) : DenormalMode::getDynamic();
5392
5393 Known = KnownFPClass::ldexp(Src: KnownSrc, N: ExpBits, Flt, Mode);
5394 break;
5395 }
5396 case Intrinsic::arithmetic_fence: {
5397 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5398 Known, Q, Depth: Depth + 1);
5399 break;
5400 }
5401 case Intrinsic::experimental_constrained_sitofp:
5402 case Intrinsic::experimental_constrained_uitofp:
5403 // Cannot produce nan
5404 Known.knownNot(RuleOut: fcNan);
5405
5406 // sitofp and uitofp turn into +0.0 for zero.
5407 Known.knownNot(RuleOut: fcNegZero);
5408
5409 // Integers cannot be subnormal
5410 Known.knownNot(RuleOut: fcSubnormal);
5411
5412 if (IID == Intrinsic::experimental_constrained_uitofp)
5413 Known.signBitMustBeZero();
5414
5415 // TODO: Copy inf handling from instructions
5416 break;
5417
5418 case Intrinsic::amdgcn_fract: {
5419 Known.knownNot(RuleOut: fcInf);
5420
5421 if (InterestedClasses & fcNan) {
5422 KnownFPClass KnownSrc;
5423 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5424 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5425
5426 if (KnownSrc.isKnownNeverInfOrNaN())
5427 Known.knownNot(RuleOut: fcNan);
5428 else if (KnownSrc.isKnownNever(Mask: fcSNan))
5429 Known.knownNot(RuleOut: fcSNan);
5430 }
5431
5432 break;
5433 }
5434 case Intrinsic::amdgcn_rcp: {
5435 KnownFPClass KnownSrc;
5436 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5437 Known&: KnownSrc, Q, Depth: Depth + 1);
5438
5439 Known.propagateNaN(Src: KnownSrc);
5440
5441 Type *EltTy = II->getType()->getScalarType();
5442
5443 // f32 denormal always flushed.
5444 if (EltTy->isFloatTy()) {
5445 Known.knownNot(RuleOut: fcSubnormal);
5446 KnownSrc.knownNot(RuleOut: fcSubnormal);
5447 }
5448
5449 if (KnownSrc.isKnownNever(Mask: fcNegative))
5450 Known.knownNot(RuleOut: fcNegative);
5451 if (KnownSrc.isKnownNever(Mask: fcPositive))
5452 Known.knownNot(RuleOut: fcPositive);
5453
5454 if (const Function *F = II->getFunction()) {
5455 DenormalMode Mode = F->getDenormalMode(FPType: EltTy->getFltSemantics());
5456 if (KnownSrc.isKnownNeverLogicalPosZero(Mode))
5457 Known.knownNot(RuleOut: fcPosInf);
5458 if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
5459 Known.knownNot(RuleOut: fcNegInf);
5460 }
5461
5462 break;
5463 }
5464 case Intrinsic::amdgcn_rsq: {
5465 KnownFPClass KnownSrc;
5466 // The only negative value that can be returned is -inf for -0 inputs.
5467 Known.knownNot(RuleOut: fcNegZero | fcNegSubnormal | fcNegNormal);
5468
5469 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts, InterestedClasses,
5470 Known&: KnownSrc, Q, Depth: Depth + 1);
5471
5472 // Negative -> nan
5473 if (KnownSrc.isKnownNeverNaN() && KnownSrc.cannotBeOrderedLessThanZero())
5474 Known.knownNot(RuleOut: fcNan);
5475 else if (KnownSrc.isKnownNever(Mask: fcSNan))
5476 Known.knownNot(RuleOut: fcSNan);
5477
5478 // +inf -> +0
5479 if (KnownSrc.isKnownNeverPosInfinity())
5480 Known.knownNot(RuleOut: fcPosZero);
5481
5482 Type *EltTy = II->getType()->getScalarType();
5483
5484 // f32 denormal always flushed.
5485 if (EltTy->isFloatTy())
5486 Known.knownNot(RuleOut: fcPosSubnormal);
5487
5488 if (const Function *F = II->getFunction()) {
5489 DenormalMode Mode = F->getDenormalMode(FPType: EltTy->getFltSemantics());
5490
5491 // -0 -> -inf
5492 if (KnownSrc.isKnownNeverLogicalNegZero(Mode))
5493 Known.knownNot(RuleOut: fcNegInf);
5494
5495 // +0 -> +inf
5496 if (KnownSrc.isKnownNeverLogicalPosZero(Mode))
5497 Known.knownNot(RuleOut: fcPosInf);
5498 }
5499
5500 break;
5501 }
5502 case Intrinsic::amdgcn_trig_preop: {
5503 Known.knownNot(RuleOut: fcNan | fcInf);
5504 break;
5505 }
5506 default:
5507 break;
5508 }
5509
5510 break;
5511 }
5512 case Instruction::FAdd:
5513 case Instruction::FSub: {
5514 KnownFPClass KnownLHS, KnownRHS;
5515 bool WantNegative =
5516 Op->getOpcode() == Instruction::FAdd &&
5517 (InterestedClasses & KnownFPClass::OrderedLessThanZeroMask) != fcNone;
5518 bool WantNaN = (InterestedClasses & fcNan) != fcNone;
5519 bool WantNegZero = (InterestedClasses & fcNegZero) != fcNone;
5520
5521 if (!WantNaN && !WantNegative && !WantNegZero)
5522 break;
5523
5524 FPClassTest InterestedSrcs = InterestedClasses;
5525 if (WantNegative)
5526 InterestedSrcs |= KnownFPClass::OrderedLessThanZeroMask;
5527 if (InterestedClasses & fcNan)
5528 InterestedSrcs |= fcInf;
5529 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts, InterestedClasses: InterestedSrcs,
5530 Known&: KnownRHS, Q, Depth: Depth + 1);
5531
5532 // Special case fadd x, x, which is the canonical form of fmul x, 2.
5533 bool Self = Op->getOperand(i: 0) == Op->getOperand(i: 1) &&
5534 isGuaranteedNotToBeUndef(V: Op->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT,
5535 Depth: Depth + 1);
5536 if (Self)
5537 KnownLHS = KnownRHS;
5538
5539 if ((WantNaN && KnownRHS.isKnownNeverNaN()) ||
5540 (WantNegative && KnownRHS.cannotBeOrderedLessThanZero()) ||
5541 WantNegZero || Opc == Instruction::FSub) {
5542
5543 // FIXME: Context function should always be passed in separately
5544 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5545 const fltSemantics &FltSem =
5546 Op->getType()->getScalarType()->getFltSemantics();
5547 DenormalMode Mode =
5548 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5549
5550 if (Self && Opc == Instruction::FAdd) {
5551 Known = KnownFPClass::fadd_self(Src: KnownLHS, Mode);
5552 } else {
5553 // RHS is canonically cheaper to compute. Skip inspecting the LHS if
5554 // there's no point.
5555
5556 if (!Self) {
5557 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: InterestedSrcs,
5558 Known&: KnownLHS, Q, Depth: Depth + 1);
5559 }
5560
5561 Known = Opc == Instruction::FAdd
5562 ? KnownFPClass::fadd(LHS: KnownLHS, RHS: KnownRHS, Mode)
5563 : KnownFPClass::fsub(LHS: KnownLHS, RHS: KnownRHS, Mode);
5564 }
5565 }
5566
5567 break;
5568 }
5569 case Instruction::FMul: {
5570 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5571 DenormalMode Mode =
5572 F ? F->getDenormalMode(
5573 FPType: Op->getType()->getScalarType()->getFltSemantics())
5574 : DenormalMode::getDynamic();
5575
5576 // X * X is always non-negative or a NaN.
5577 // FIXME: Should check isGuaranteedNotToBeUndef
5578 if (Op->getOperand(i: 0) == Op->getOperand(i: 1)) {
5579 KnownFPClass KnownSrc;
5580 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownSrc,
5581 Q, Depth: Depth + 1);
5582 Known = KnownFPClass::square(Src: KnownSrc, Mode);
5583 break;
5584 }
5585
5586 KnownFPClass KnownLHS, KnownRHS;
5587
5588 bool CannotBeSubnormal = false;
5589 const APFloat *CRHS;
5590 if (match(V: Op->getOperand(i: 1), P: m_APFloat(Res&: CRHS))) {
5591 // Match denormal scaling pattern, similar to the case in ldexp. If the
5592 // constant's exponent is sufficiently large, the result cannot be
5593 // subnormal.
5594
5595 // TODO: Should do general ConstantFPRange analysis.
5596 const fltSemantics &Flt =
5597 Op->getType()->getScalarType()->getFltSemantics();
5598 unsigned Precision = APFloat::semanticsPrecision(Flt);
5599 const int MantissaBits = Precision - 1;
5600
5601 int MinKnownExponent = ilogb(Arg: *CRHS);
5602 if (MinKnownExponent >= MantissaBits)
5603 CannotBeSubnormal = true;
5604
5605 KnownRHS = KnownFPClass(*CRHS);
5606 } else {
5607 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownRHS,
5608 Q, Depth: Depth + 1);
5609 }
5610
5611 // TODO: Improve accuracy in unfused FMA pattern. We can prove an additional
5612 // not-nan if the addend is known-not negative infinity if the multiply is
5613 // known-not infinity.
5614
5615 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownLHS,
5616 Q, Depth: Depth + 1);
5617
5618 Known = KnownFPClass::fmul(LHS: KnownLHS, RHS: KnownRHS, Mode);
5619 if (CannotBeSubnormal)
5620 Known.knownNot(RuleOut: fcSubnormal);
5621 break;
5622 }
5623 case Instruction::FDiv:
5624 case Instruction::FRem: {
5625 const bool WantNan = (InterestedClasses & fcNan) != fcNone;
5626
5627 if (Op->getOperand(i: 0) == Op->getOperand(i: 1) &&
5628 isGuaranteedNotToBeUndef(V: Op->getOperand(i: 0), AC: Q.AC, CtxI: Q.CxtI, DT: Q.DT)) {
5629 if (Op->getOpcode() == Instruction::FDiv) {
5630 // X / X is always exactly 1.0 or a NaN.
5631 Known.KnownFPClasses = fcNan | fcPosNormal;
5632 } else {
5633 // X % X is always exactly [+-]0.0 or a NaN.
5634 Known.KnownFPClasses = fcNan | fcZero;
5635 }
5636
5637 if (!WantNan)
5638 break;
5639
5640 KnownFPClass KnownSrc;
5641 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts,
5642 InterestedClasses: fcNan | fcInf | fcZero | fcSubnormal, Known&: KnownSrc, Q,
5643 Depth: Depth + 1);
5644 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5645 const fltSemantics &FltSem =
5646 Op->getType()->getScalarType()->getFltSemantics();
5647
5648 DenormalMode Mode =
5649 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5650
5651 Known = Op->getOpcode() == Instruction::FDiv
5652 ? KnownFPClass::fdiv_self(Src: KnownSrc, Mode)
5653 : KnownFPClass::frem_self(Src: KnownSrc, Mode);
5654 break;
5655 }
5656
5657 const bool WantNegative = (InterestedClasses & fcNegative) != fcNone;
5658 const bool WantPositive =
5659 Opc == Instruction::FRem && (InterestedClasses & fcPositive) != fcNone;
5660 if (!WantNan && !WantNegative && !WantPositive)
5661 break;
5662
5663 KnownFPClass KnownLHS, KnownRHS;
5664
5665 computeKnownFPClass(V: Op->getOperand(i: 1), DemandedElts,
5666 InterestedClasses: fcNan | fcInf | fcZero | fcNegative, Known&: KnownRHS, Q,
5667 Depth: Depth + 1);
5668
5669 bool KnowSomethingUseful = KnownRHS.isKnownNeverNaN() ||
5670 KnownRHS.isKnownNever(Mask: fcNegative) ||
5671 KnownRHS.isKnownNever(Mask: fcPositive);
5672
5673 if (KnowSomethingUseful || WantPositive) {
5674 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses: fcAllFlags, Known&: KnownLHS,
5675 Q, Depth: Depth + 1);
5676 }
5677
5678 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5679 const fltSemantics &FltSem =
5680 Op->getType()->getScalarType()->getFltSemantics();
5681
5682 if (Op->getOpcode() == Instruction::FDiv) {
5683 DenormalMode Mode =
5684 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5685 Known = KnownFPClass::fdiv(LHS: KnownLHS, RHS: KnownRHS, Mode);
5686 } else {
5687 // Inf REM x and x REM 0 produce NaN.
5688 if (KnownLHS.isKnownNeverNaN() && KnownRHS.isKnownNeverNaN() &&
5689 KnownLHS.isKnownNeverInfinity() && F &&
5690 KnownRHS.isKnownNeverLogicalZero(Mode: F->getDenormalMode(FPType: FltSem))) {
5691 Known.knownNot(RuleOut: fcNan);
5692 }
5693
5694 // The sign for frem is the same as the first operand.
5695 if (KnownLHS.cannotBeOrderedLessThanZero())
5696 Known.knownNot(RuleOut: KnownFPClass::OrderedLessThanZeroMask);
5697 if (KnownLHS.cannotBeOrderedGreaterThanZero())
5698 Known.knownNot(RuleOut: KnownFPClass::OrderedGreaterThanZeroMask);
5699
5700 // See if we can be more aggressive about the sign of 0.
5701 if (KnownLHS.isKnownNever(Mask: fcNegative))
5702 Known.knownNot(RuleOut: fcNegative);
5703 if (KnownLHS.isKnownNever(Mask: fcPositive))
5704 Known.knownNot(RuleOut: fcPositive);
5705 }
5706
5707 break;
5708 }
5709 case Instruction::FPExt: {
5710 KnownFPClass KnownSrc;
5711 computeKnownFPClass(V: Op->getOperand(i: 0), DemandedElts, InterestedClasses,
5712 Known&: KnownSrc, Q, Depth: Depth + 1);
5713
5714 const fltSemantics &DstTy =
5715 Op->getType()->getScalarType()->getFltSemantics();
5716 const fltSemantics &SrcTy =
5717 Op->getOperand(i: 0)->getType()->getScalarType()->getFltSemantics();
5718
5719 Known = KnownFPClass::fpext(KnownSrc, DstTy, SrcTy);
5720 break;
5721 }
5722 case Instruction::FPTrunc: {
5723 computeKnownFPClassForFPTrunc(Op, DemandedElts, InterestedClasses, Known, Q,
5724 Depth);
5725 break;
5726 }
5727 case Instruction::SIToFP:
5728 case Instruction::UIToFP: {
5729 // Cannot produce nan
5730 Known.knownNot(RuleOut: fcNan);
5731
5732 // Integers cannot be subnormal
5733 Known.knownNot(RuleOut: fcSubnormal);
5734
5735 // sitofp and uitofp turn into +0.0 for zero.
5736 Known.knownNot(RuleOut: fcNegZero);
5737 if (Op->getOpcode() == Instruction::UIToFP)
5738 Known.signBitMustBeZero();
5739
5740 if (InterestedClasses & fcInf) {
5741 // Get width of largest magnitude integer (remove a bit if signed).
5742 // This still works for a signed minimum value because the largest FP
5743 // value is scaled by some fraction close to 2.0 (1.0 + 0.xxxx).
5744 int IntSize = Op->getOperand(i: 0)->getType()->getScalarSizeInBits();
5745 if (Op->getOpcode() == Instruction::SIToFP)
5746 --IntSize;
5747
5748 // If the exponent of the largest finite FP value can hold the largest
5749 // integer, the result of the cast must be finite.
5750 Type *FPTy = Op->getType()->getScalarType();
5751 if (ilogb(Arg: APFloat::getLargest(Sem: FPTy->getFltSemantics())) >= IntSize)
5752 Known.knownNot(RuleOut: fcInf);
5753 }
5754
5755 break;
5756 }
5757 case Instruction::ExtractElement: {
5758 // Look through extract element. If the index is non-constant or
5759 // out-of-range demand all elements, otherwise just the extracted element.
5760 const Value *Vec = Op->getOperand(i: 0);
5761
5762 APInt DemandedVecElts;
5763 if (auto *VecTy = dyn_cast<FixedVectorType>(Val: Vec->getType())) {
5764 unsigned NumElts = VecTy->getNumElements();
5765 DemandedVecElts = APInt::getAllOnes(numBits: NumElts);
5766 auto *CIdx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: 1));
5767 if (CIdx && CIdx->getValue().ult(RHS: NumElts))
5768 DemandedVecElts = APInt::getOneBitSet(numBits: NumElts, BitNo: CIdx->getZExtValue());
5769 } else {
5770 DemandedVecElts = APInt(1, 1);
5771 }
5772
5773 return computeKnownFPClass(V: Vec, DemandedElts: DemandedVecElts, InterestedClasses, Known,
5774 Q, Depth: Depth + 1);
5775 }
5776 case Instruction::InsertElement: {
5777 if (isa<ScalableVectorType>(Val: Op->getType()))
5778 return;
5779
5780 const Value *Vec = Op->getOperand(i: 0);
5781 const Value *Elt = Op->getOperand(i: 1);
5782 auto *CIdx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: 2));
5783 unsigned NumElts = DemandedElts.getBitWidth();
5784 APInt DemandedVecElts = DemandedElts;
5785 bool NeedsElt = true;
5786 // If we know the index we are inserting to, clear it from Vec check.
5787 if (CIdx && CIdx->getValue().ult(RHS: NumElts)) {
5788 DemandedVecElts.clearBit(BitPosition: CIdx->getZExtValue());
5789 NeedsElt = DemandedElts[CIdx->getZExtValue()];
5790 }
5791
5792 // Do we demand the inserted element?
5793 if (NeedsElt) {
5794 computeKnownFPClass(V: Elt, Known, InterestedClasses, Q, Depth: Depth + 1);
5795 // If we don't know any bits, early out.
5796 if (Known.isUnknown())
5797 break;
5798 } else {
5799 Known.KnownFPClasses = fcNone;
5800 }
5801
5802 // Do we need anymore elements from Vec?
5803 if (!DemandedVecElts.isZero()) {
5804 KnownFPClass Known2;
5805 computeKnownFPClass(V: Vec, DemandedElts: DemandedVecElts, InterestedClasses, Known&: Known2, Q,
5806 Depth: Depth + 1);
5807 Known |= Known2;
5808 }
5809
5810 break;
5811 }
5812 case Instruction::ShuffleVector: {
5813 // Handle vector splat idiom
5814 if (Value *Splat = getSplatValue(V)) {
5815 computeKnownFPClass(V: Splat, Known, InterestedClasses, Q, Depth: Depth + 1);
5816 break;
5817 }
5818
5819 // For undef elements, we don't know anything about the common state of
5820 // the shuffle result.
5821 APInt DemandedLHS, DemandedRHS;
5822 auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: Op);
5823 if (!Shuf || !getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
5824 return;
5825
5826 if (!!DemandedLHS) {
5827 const Value *LHS = Shuf->getOperand(i_nocapture: 0);
5828 computeKnownFPClass(V: LHS, DemandedElts: DemandedLHS, InterestedClasses, Known, Q,
5829 Depth: Depth + 1);
5830
5831 // If we don't know any bits, early out.
5832 if (Known.isUnknown())
5833 break;
5834 } else {
5835 Known.KnownFPClasses = fcNone;
5836 }
5837
5838 if (!!DemandedRHS) {
5839 KnownFPClass Known2;
5840 const Value *RHS = Shuf->getOperand(i_nocapture: 1);
5841 computeKnownFPClass(V: RHS, DemandedElts: DemandedRHS, InterestedClasses, Known&: Known2, Q,
5842 Depth: Depth + 1);
5843 Known |= Known2;
5844 }
5845
5846 break;
5847 }
5848 case Instruction::ExtractValue: {
5849 const ExtractValueInst *Extract = cast<ExtractValueInst>(Val: Op);
5850 ArrayRef<unsigned> Indices = Extract->getIndices();
5851 const Value *Src = Extract->getAggregateOperand();
5852 if (isa<StructType>(Val: Src->getType()) && Indices.size() == 1 &&
5853 Indices[0] == 0) {
5854 if (const auto *II = dyn_cast<IntrinsicInst>(Val: Src)) {
5855 switch (II->getIntrinsicID()) {
5856 case Intrinsic::frexp: {
5857 Known.knownNot(RuleOut: fcSubnormal);
5858
5859 KnownFPClass KnownSrc;
5860 computeKnownFPClass(V: II->getArgOperand(i: 0), DemandedElts,
5861 InterestedClasses, Known&: KnownSrc, Q, Depth: Depth + 1);
5862
5863 const Function *F = cast<Instruction>(Val: Op)->getFunction();
5864 const fltSemantics &FltSem =
5865 Op->getType()->getScalarType()->getFltSemantics();
5866
5867 DenormalMode Mode =
5868 F ? F->getDenormalMode(FPType: FltSem) : DenormalMode::getDynamic();
5869 Known = KnownFPClass::frexp_mant(Src: KnownSrc, Mode);
5870 return;
5871 }
5872 default:
5873 break;
5874 }
5875 }
5876 }
5877
5878 computeKnownFPClass(V: Src, DemandedElts, InterestedClasses, Known, Q,
5879 Depth: Depth + 1);
5880 break;
5881 }
5882 case Instruction::PHI: {
5883 const PHINode *P = cast<PHINode>(Val: Op);
5884 // Unreachable blocks may have zero-operand PHI nodes.
5885 if (P->getNumIncomingValues() == 0)
5886 break;
5887
5888 // Otherwise take the unions of the known bit sets of the operands,
5889 // taking conservative care to avoid excessive recursion.
5890 const unsigned PhiRecursionLimit = MaxAnalysisRecursionDepth - 2;
5891
5892 if (Depth < PhiRecursionLimit) {
5893 // Skip if every incoming value references to ourself.
5894 if (isa_and_nonnull<UndefValue>(Val: P->hasConstantValue()))
5895 break;
5896
5897 bool First = true;
5898
5899 for (const Use &U : P->operands()) {
5900 Value *IncValue;
5901 Instruction *CxtI;
5902 breakSelfRecursivePHI(U: &U, PHI: P, ValOut&: IncValue, CtxIOut&: CxtI);
5903 // Skip direct self references.
5904 if (IncValue == P)
5905 continue;
5906
5907 KnownFPClass KnownSrc;
5908 // Recurse, but cap the recursion to two levels, because we don't want
5909 // to waste time spinning around in loops. We need at least depth 2 to
5910 // detect known sign bits.
5911 computeKnownFPClass(V: IncValue, DemandedElts, InterestedClasses, Known&: KnownSrc,
5912 Q: Q.getWithoutCondContext().getWithInstruction(I: CxtI),
5913 Depth: PhiRecursionLimit);
5914
5915 if (First) {
5916 Known = KnownSrc;
5917 First = false;
5918 } else {
5919 Known |= KnownSrc;
5920 }
5921
5922 if (Known.KnownFPClasses == fcAllFlags)
5923 break;
5924 }
5925 }
5926
5927 break;
5928 }
5929 case Instruction::BitCast: {
5930 const Value *Src;
5931 if (!match(V: Op, P: m_ElementWiseBitCast(Op: m_Value(V&: Src))) ||
5932 !Src->getType()->isIntOrIntVectorTy())
5933 break;
5934
5935 const Type *Ty = Op->getType()->getScalarType();
5936 KnownBits Bits(Ty->getScalarSizeInBits());
5937 computeKnownBits(V: Src, DemandedElts, Known&: Bits, Q, Depth: Depth + 1);
5938
5939 // Transfer information from the sign bit.
5940 if (Bits.isNonNegative())
5941 Known.signBitMustBeZero();
5942 else if (Bits.isNegative())
5943 Known.signBitMustBeOne();
5944
5945 if (Ty->isIEEELikeFPTy()) {
5946 // IEEE floats are NaN when all bits of the exponent plus at least one of
5947 // the fraction bits are 1. This means:
5948 // - If we assume unknown bits are 0 and the value is NaN, it will
5949 // always be NaN
5950 // - If we assume unknown bits are 1 and the value is not NaN, it can
5951 // never be NaN
5952 // Note: They do not hold for x86_fp80 format.
5953 if (APFloat(Ty->getFltSemantics(), Bits.One).isNaN())
5954 Known.KnownFPClasses = fcNan;
5955 else if (!APFloat(Ty->getFltSemantics(), ~Bits.Zero).isNaN())
5956 Known.knownNot(RuleOut: fcNan);
5957
5958 // Build KnownBits representing Inf and check if it must be equal or
5959 // unequal to this value.
5960 auto InfKB = KnownBits::makeConstant(
5961 C: APFloat::getInf(Sem: Ty->getFltSemantics()).bitcastToAPInt());
5962 InfKB.Zero.clearSignBit();
5963 if (const auto InfResult = KnownBits::eq(LHS: Bits, RHS: InfKB)) {
5964 assert(!InfResult.value());
5965 Known.knownNot(RuleOut: fcInf);
5966 } else if (Bits == InfKB) {
5967 Known.KnownFPClasses = fcInf;
5968 }
5969
5970 // Build KnownBits representing Zero and check if it must be equal or
5971 // unequal to this value.
5972 auto ZeroKB = KnownBits::makeConstant(
5973 C: APFloat::getZero(Sem: Ty->getFltSemantics()).bitcastToAPInt());
5974 ZeroKB.Zero.clearSignBit();
5975 if (const auto ZeroResult = KnownBits::eq(LHS: Bits, RHS: ZeroKB)) {
5976 assert(!ZeroResult.value());
5977 Known.knownNot(RuleOut: fcZero);
5978 } else if (Bits == ZeroKB) {
5979 Known.KnownFPClasses = fcZero;
5980 }
5981 }
5982
5983 break;
5984 }
5985 default:
5986 break;
5987 }
5988}
5989
5990KnownFPClass llvm::computeKnownFPClass(const Value *V,
5991 const APInt &DemandedElts,
5992 FPClassTest InterestedClasses,
5993 const SimplifyQuery &SQ,
5994 unsigned Depth) {
5995 KnownFPClass KnownClasses;
5996 ::computeKnownFPClass(V, DemandedElts, InterestedClasses, Known&: KnownClasses, Q: SQ,
5997 Depth);
5998 return KnownClasses;
5999}
6000
6001KnownFPClass llvm::computeKnownFPClass(const Value *V,
6002 FPClassTest InterestedClasses,
6003 const SimplifyQuery &SQ,
6004 unsigned Depth) {
6005 KnownFPClass Known;
6006 ::computeKnownFPClass(V, Known, InterestedClasses, Q: SQ, Depth);
6007 return Known;
6008}
6009
6010KnownFPClass llvm::computeKnownFPClass(
6011 const Value *V, const DataLayout &DL, FPClassTest InterestedClasses,
6012 const TargetLibraryInfo *TLI, AssumptionCache *AC, const Instruction *CxtI,
6013 const DominatorTree *DT, bool UseInstrInfo, unsigned Depth) {
6014 return computeKnownFPClass(V, InterestedClasses,
6015 SQ: SimplifyQuery(DL, TLI, DT, AC, CxtI, UseInstrInfo),
6016 Depth);
6017}
6018
6019KnownFPClass
6020llvm::computeKnownFPClass(const Value *V, const APInt &DemandedElts,
6021 FastMathFlags FMF, FPClassTest InterestedClasses,
6022 const SimplifyQuery &SQ, unsigned Depth) {
6023 if (FMF.noNaNs())
6024 InterestedClasses &= ~fcNan;
6025 if (FMF.noInfs())
6026 InterestedClasses &= ~fcInf;
6027
6028 KnownFPClass Result =
6029 computeKnownFPClass(V, DemandedElts, InterestedClasses, SQ, Depth);
6030
6031 if (FMF.noNaNs())
6032 Result.KnownFPClasses &= ~fcNan;
6033 if (FMF.noInfs())
6034 Result.KnownFPClasses &= ~fcInf;
6035 return Result;
6036}
6037
6038KnownFPClass llvm::computeKnownFPClass(const Value *V, FastMathFlags FMF,
6039 FPClassTest InterestedClasses,
6040 const SimplifyQuery &SQ,
6041 unsigned Depth) {
6042 auto *FVTy = dyn_cast<FixedVectorType>(Val: V->getType());
6043 APInt DemandedElts =
6044 FVTy ? APInt::getAllOnes(numBits: FVTy->getNumElements()) : APInt(1, 1);
6045 return computeKnownFPClass(V, DemandedElts, FMF, InterestedClasses, SQ,
6046 Depth);
6047}
6048
6049bool llvm::cannotBeNegativeZero(const Value *V, const SimplifyQuery &SQ,
6050 unsigned Depth) {
6051 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcNegZero, SQ, Depth);
6052 return Known.isKnownNeverNegZero();
6053}
6054
6055bool llvm::cannotBeOrderedLessThanZero(const Value *V, const SimplifyQuery &SQ,
6056 unsigned Depth) {
6057 KnownFPClass Known =
6058 computeKnownFPClass(V, InterestedClasses: KnownFPClass::OrderedLessThanZeroMask, SQ, Depth);
6059 return Known.cannotBeOrderedLessThanZero();
6060}
6061
6062bool llvm::isKnownNeverInfinity(const Value *V, const SimplifyQuery &SQ,
6063 unsigned Depth) {
6064 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcInf, SQ, Depth);
6065 return Known.isKnownNeverInfinity();
6066}
6067
6068/// Return true if the floating-point value can never contain a NaN or infinity.
6069bool llvm::isKnownNeverInfOrNaN(const Value *V, const SimplifyQuery &SQ,
6070 unsigned Depth) {
6071 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcInf | fcNan, SQ, Depth);
6072 return Known.isKnownNeverNaN() && Known.isKnownNeverInfinity();
6073}
6074
6075/// Return true if the floating-point scalar value is not a NaN or if the
6076/// floating-point vector value has no NaN elements. Return false if a value
6077/// could ever be NaN.
6078bool llvm::isKnownNeverNaN(const Value *V, const SimplifyQuery &SQ,
6079 unsigned Depth) {
6080 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcNan, SQ, Depth);
6081 return Known.isKnownNeverNaN();
6082}
6083
6084/// Return false if we can prove that the specified FP value's sign bit is 0.
6085/// Return true if we can prove that the specified FP value's sign bit is 1.
6086/// Otherwise return std::nullopt.
6087std::optional<bool> llvm::computeKnownFPSignBit(const Value *V,
6088 const SimplifyQuery &SQ,
6089 unsigned Depth) {
6090 KnownFPClass Known = computeKnownFPClass(V, InterestedClasses: fcAllFlags, SQ, Depth);
6091 return Known.SignBit;
6092}
6093
6094bool llvm::canIgnoreSignBitOfZero(const Use &U) {
6095 auto *User = cast<Instruction>(Val: U.getUser());
6096 if (auto *FPOp = dyn_cast<FPMathOperator>(Val: User)) {
6097 if (FPOp->hasNoSignedZeros())
6098 return true;
6099 }
6100
6101 switch (User->getOpcode()) {
6102 case Instruction::FPToSI:
6103 case Instruction::FPToUI:
6104 return true;
6105 case Instruction::FCmp:
6106 // fcmp treats both positive and negative zero as equal.
6107 return true;
6108 case Instruction::Call:
6109 if (auto *II = dyn_cast<IntrinsicInst>(Val: User)) {
6110 switch (II->getIntrinsicID()) {
6111 case Intrinsic::fabs:
6112 return true;
6113 case Intrinsic::copysign:
6114 return U.getOperandNo() == 0;
6115 case Intrinsic::is_fpclass:
6116 case Intrinsic::vp_is_fpclass: {
6117 auto Test =
6118 static_cast<FPClassTest>(
6119 cast<ConstantInt>(Val: II->getArgOperand(i: 1))->getZExtValue()) &
6120 FPClassTest::fcZero;
6121 return Test == FPClassTest::fcZero || Test == FPClassTest::fcNone;
6122 }
6123 default:
6124 return false;
6125 }
6126 }
6127 return false;
6128 default:
6129 return false;
6130 }
6131}
6132
6133bool llvm::canIgnoreSignBitOfNaN(const Use &U) {
6134 auto *User = cast<Instruction>(Val: U.getUser());
6135 if (auto *FPOp = dyn_cast<FPMathOperator>(Val: User)) {
6136 if (FPOp->hasNoNaNs())
6137 return true;
6138 }
6139
6140 switch (User->getOpcode()) {
6141 case Instruction::FPToSI:
6142 case Instruction::FPToUI:
6143 return true;
6144 // Proper FP math operations ignore the sign bit of NaN.
6145 case Instruction::FAdd:
6146 case Instruction::FSub:
6147 case Instruction::FMul:
6148 case Instruction::FDiv:
6149 case Instruction::FRem:
6150 case Instruction::FPTrunc:
6151 case Instruction::FPExt:
6152 case Instruction::FCmp:
6153 return true;
6154 // Bitwise FP operations should preserve the sign bit of NaN.
6155 case Instruction::FNeg:
6156 case Instruction::Select:
6157 case Instruction::PHI:
6158 return false;
6159 case Instruction::Ret:
6160 return User->getFunction()->getAttributes().getRetNoFPClass() &
6161 FPClassTest::fcNan;
6162 case Instruction::Call:
6163 case Instruction::Invoke: {
6164 if (auto *II = dyn_cast<IntrinsicInst>(Val: User)) {
6165 switch (II->getIntrinsicID()) {
6166 case Intrinsic::fabs:
6167 return true;
6168 case Intrinsic::copysign:
6169 return U.getOperandNo() == 0;
6170 // Other proper FP math intrinsics ignore the sign bit of NaN.
6171 case Intrinsic::maxnum:
6172 case Intrinsic::minnum:
6173 case Intrinsic::maximum:
6174 case Intrinsic::minimum:
6175 case Intrinsic::maximumnum:
6176 case Intrinsic::minimumnum:
6177 case Intrinsic::canonicalize:
6178 case Intrinsic::fma:
6179 case Intrinsic::fmuladd:
6180 case Intrinsic::sqrt:
6181 case Intrinsic::pow:
6182 case Intrinsic::powi:
6183 case Intrinsic::fptoui_sat:
6184 case Intrinsic::fptosi_sat:
6185 case Intrinsic::is_fpclass:
6186 case Intrinsic::vp_is_fpclass:
6187 return true;
6188 default:
6189 return false;
6190 }
6191 }
6192
6193 FPClassTest NoFPClass =
6194 cast<CallBase>(Val: User)->getParamNoFPClass(i: U.getOperandNo());
6195 return NoFPClass & FPClassTest::fcNan;
6196 }
6197 default:
6198 return false;
6199 }
6200}
6201
6202bool llvm::isKnownIntegral(const Value *V, const SimplifyQuery &SQ,
6203 FastMathFlags FMF) {
6204 if (isa<PoisonValue>(Val: V))
6205 return true;
6206 if (isa<UndefValue>(Val: V))
6207 return false;
6208
6209 if (match(V, P: m_CheckedFp(CheckFn: [](const APFloat &Val) { return Val.isInteger(); })))
6210 return true;
6211
6212 const Instruction *I = dyn_cast<Instruction>(Val: V);
6213 if (!I)
6214 return false;
6215
6216 switch (I->getOpcode()) {
6217 case Instruction::SIToFP:
6218 case Instruction::UIToFP:
6219 // TODO: Could check nofpclass(inf) on incoming argument
6220 if (FMF.noInfs())
6221 return true;
6222
6223 // Need to check int size cannot produce infinity, which computeKnownFPClass
6224 // knows how to do already.
6225 return isKnownNeverInfinity(V: I, SQ);
6226 case Instruction::Call: {
6227 const CallInst *CI = cast<CallInst>(Val: I);
6228 switch (CI->getIntrinsicID()) {
6229 case Intrinsic::trunc:
6230 case Intrinsic::floor:
6231 case Intrinsic::ceil:
6232 case Intrinsic::rint:
6233 case Intrinsic::nearbyint:
6234 case Intrinsic::round:
6235 case Intrinsic::roundeven:
6236 return (FMF.noInfs() && FMF.noNaNs()) || isKnownNeverInfOrNaN(V: I, SQ);
6237 default:
6238 break;
6239 }
6240
6241 break;
6242 }
6243 default:
6244 break;
6245 }
6246
6247 return false;
6248}
6249
/// If V can be materialized by repeating one byte across its storage (as a
/// memset could), return that byte as an i8 Value. An i8 undef is returned
/// for "don't care" bytes; nullptr means no single splat byte exists.
Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) {

  // All byte-wide stores are splatable, even of arbitrary variables.
  if (V->getType()->isIntegerTy(Bitwidth: 8))
    return V;

  LLVMContext &Ctx = V->getContext();

  // Undef bytes don't care: Merge (below) lets them combine with anything.
  auto *UndefInt8 = UndefValue::get(T: Type::getInt8Ty(C&: Ctx));
  if (isa<UndefValue>(Val: V))
    return UndefInt8;

  // Return poison for zero-sized type.
  if (DL.getTypeStoreSize(Ty: V->getType()).isZero())
    return PoisonValue::get(T: Type::getInt8Ty(C&: Ctx));

  // Everything below only handles constants.
  Constant *C = dyn_cast<Constant>(Val: V);
  if (!C) {
    // Conceptually, we could handle things like:
    // %a = zext i8 %X to i16
    // %b = shl i16 %a, 8
    // %c = or i16 %a, %b
    // but until there is an example that actually needs this, it doesn't seem
    // worth worrying about.
    return nullptr;
  }

  // Handle 'null' ConstantArrayZero etc.
  if (C->isNullValue())
    return Constant::getNullValue(Ty: Type::getInt8Ty(C&: Ctx));

  // Constant floating-point values can be handled as integer values if the
  // corresponding integer value is "byteable". An important case is 0.0.
  if (ConstantFP *CFP = dyn_cast<ConstantFP>(Val: C)) {
    Type *ScalarTy = CFP->getType()->getScalarType();
    if (ScalarTy->isHalfTy() || ScalarTy->isFloatTy() || ScalarTy->isDoubleTy())
      // Recurse on the bit pattern reinterpreted as an integer.
      return isBytewiseValue(
          V: ConstantInt::get(Context&: Ctx, V: CFP->getValue().bitcastToAPInt()), DL);

    // Don't handle long double formats, which have strange constraints.
    return nullptr;
  }

  // We can handle constant integers that are multiple of 8 bits.
  if (ConstantInt *CI = dyn_cast<ConstantInt>(Val: C)) {
    if (CI->getBitWidth() % 8 == 0) {
      // Every 8-bit chunk of the value must be identical to be a splat.
      if (!CI->getValue().isSplat(SplatSizeInBits: 8))
        return nullptr;
      return ConstantInt::get(Context&: Ctx, V: CI->getValue().trunc(width: 8));
    }
  }

  // Look through an inttoptr constant expression by constant-folding its
  // source to a pointer-sized integer and recursing on that.
  if (auto *CE = dyn_cast<ConstantExpr>(Val: C)) {
    if (CE->getOpcode() == Instruction::IntToPtr) {
      if (auto *PtrTy = dyn_cast<PointerType>(Val: CE->getType())) {
        unsigned BitWidth = DL.getPointerSizeInBits(AS: PtrTy->getAddressSpace());
        if (Constant *Op = ConstantFoldIntegerCast(
                C: CE->getOperand(i_nocapture: 0), DestTy: Type::getIntNTy(C&: Ctx, N: BitWidth), IsSigned: false, DL))
          return isBytewiseValue(V: Op, DL);
      }
    }
  }

  // Combine the splat bytes of two sub-values: equal bytes combine to
  // themselves, undef yields to the other side, anything else conflicts
  // (nullptr).
  auto Merge = [&](Value *LHS, Value *RHS) -> Value * {
    if (LHS == RHS)
      return LHS;
    if (!LHS || !RHS)
      return nullptr;
    if (LHS == UndefInt8)
      return RHS;
    if (RHS == UndefInt8)
      return LHS;
    return nullptr;
  };

  // Sequential constant data: all elements must merge to one splat byte.
  if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(Val: C)) {
    Value *Val = UndefInt8;
    for (uint64_t I = 0, E = CA->getNumElements(); I != E; ++I)
      if (!(Val = Merge(Val, isBytewiseValue(V: CA->getElementAsConstant(i: I), DL))))
        return nullptr;
    return Val;
  }

  // Aggregates (structs/arrays/vectors of constants): merge every operand.
  if (isa<ConstantAggregate>(Val: C)) {
    Value *Val = UndefInt8;
    for (Value *Op : C->operands())
      if (!(Val = Merge(Val, isBytewiseValue(V: Op, DL))))
        return nullptr;
    return Val;
  }

  // Don't try to handle the handful of other constants.
  return nullptr;
}
6345
6346// This is the recursive version of BuildSubAggregate. It takes a few different
6347// arguments. Idxs is the index within the nested struct From that we are
6348// looking at now (which is of type IndexedType). IdxSkip is the number of
6349// indices from Idxs that should be left out when inserting into the resulting
6350// struct. To is the result struct built so far, new insertvalue instructions
6351// build on that.
static Value *BuildSubAggregate(Value *From, Value *To, Type *IndexedType,
                                SmallVectorImpl<unsigned> &Idxs,
                                unsigned IdxSkip,
                                BasicBlock::iterator InsertBefore) {
  StructType *STy = dyn_cast<StructType>(Val: IndexedType);
  if (STy) {
    // Save the original To argument so we can modify it
    Value *OrigTo = To;
    // General case, the type indexed by Idxs is a struct
    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
      // Process each struct element recursively
      Idxs.push_back(Elt: i);
      Value *PrevTo = To;
      To = BuildSubAggregate(From, To, IndexedType: STy->getElementType(N: i), Idxs, IdxSkip,
                             InsertBefore);
      Idxs.pop_back();
      if (!To) {
        // Couldn't find any inserted value for this index? Cleanup
        // Walk the insertvalue chain created since OrigTo backwards, erasing
        // each instruction we speculatively built.
        while (PrevTo != OrigTo) {
          InsertValueInst* Del = cast<InsertValueInst>(Val: PrevTo);
          PrevTo = Del->getAggregateOperand();
          Del->eraseFromParent();
        }
        // Stop processing elements
        break;
      }
    }
    // If we successfully found a value for each of our subaggregates
    if (To)
      return To;
  }
  // Base case, the type indexed by SourceIdxs is not a struct, or not all of
  // the struct's elements had a value that was inserted directly. In the latter
  // case, perhaps we can't determine each of the subelements individually, but
  // we might be able to find the complete struct somewhere.

  // Find the value that is at that particular spot
  Value *V = FindInsertedValue(V: From, idx_range: Idxs);

  if (!V)
    return nullptr;

  // Insert the value in the new (sub) aggregate, dropping the first IdxSkip
  // indices so the result is indexed relative to the extracted substruct.
  return InsertValueInst::Create(Agg: To, Val: V, Idxs: ArrayRef(Idxs).slice(N: IdxSkip), NameStr: "tmp",
                                 InsertBefore);
}
6398
6399// This helper takes a nested struct and extracts a part of it (which is again a
6400// struct) into a new value. For example, given the struct:
6401// { a, { b, { c, d }, e } }
6402// and the indices "1, 1" this returns
6403// { c, d }.
6404//
6405// It does this by inserting an insertvalue for each element in the resulting
6406// struct, as opposed to just inserting a single struct. This will only work if
6407// each of the elements of the substruct are known (ie, inserted into From by an
6408// insertvalue instruction somewhere).
6409//
6410// All inserted insertvalue instructions are inserted before InsertBefore
6411static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range,
6412 BasicBlock::iterator InsertBefore) {
6413 Type *IndexedType = ExtractValueInst::getIndexedType(Agg: From->getType(),
6414 Idxs: idx_range);
6415 Value *To = PoisonValue::get(T: IndexedType);
6416 SmallVector<unsigned, 10> Idxs(idx_range);
6417 unsigned IdxSkip = Idxs.size();
6418
6419 return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore);
6420}
6421
/// Given an aggregate and a sequence of indices, see if the scalar value
/// indexed is already around as a register, for example if it was inserted
/// directly into the aggregate.
///
/// If InsertBefore is not null, this function will duplicate (modified)
/// insertvalues when a part of a nested struct is extracted.
Value *
llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range,
                        std::optional<BasicBlock::iterator> InsertBefore) {
  // Nothing to index? Just return V then (this is useful at the end of our
  // recursion).
  if (idx_range.empty())
    return V;
  // We have indices, so V should have an indexable type.
  assert((V->getType()->isStructTy() || V->getType()->isArrayTy()) &&
         "Not looking at a struct or array?");
  assert(ExtractValueInst::getIndexedType(V->getType(), idx_range) &&
         "Invalid indices for type?");

  // Constants: peel off one index at a time via getAggregateElement.
  if (Constant *C = dyn_cast<Constant>(Val: V)) {
    C = C->getAggregateElement(Elt: idx_range[0]);
    if (!C) return nullptr;
    return FindInsertedValue(V: C, idx_range: idx_range.slice(N: 1), InsertBefore);
  }

  if (InsertValueInst *I = dyn_cast<InsertValueInst>(Val: V)) {
    // Loop the indices for the insertvalue instruction in parallel with the
    // requested indices
    const unsigned *req_idx = idx_range.begin();
    for (const unsigned *i = I->idx_begin(), *e = I->idx_end();
         i != e; ++i, ++req_idx) {
      if (req_idx == idx_range.end()) {
        // We can't handle this without inserting insertvalues
        if (!InsertBefore)
          return nullptr;

        // The requested index identifies a part of a nested aggregate. Handle
        // this specially. For example,
        // %A = insertvalue { i32, {i32, i32 } } undef, i32 10, 1, 0
        // %B = insertvalue { i32, {i32, i32 } } %A, i32 11, 1, 1
        // %C = extractvalue {i32, { i32, i32 } } %B, 1
        // This can be changed into
        // %A = insertvalue {i32, i32 } undef, i32 10, 0
        // %C = insertvalue {i32, i32 } %A, i32 11, 1
        // which allows the unused 0,0 element from the nested struct to be
        // removed.
        return BuildSubAggregate(From: V, idx_range: ArrayRef(idx_range.begin(), req_idx),
                                 InsertBefore: *InsertBefore);
      }

      // This insert value inserts something else than what we are looking for.
      // See if the (aggregate) value inserted into has the value we are
      // looking for, then.
      if (*req_idx != *i)
        return FindInsertedValue(V: I->getAggregateOperand(), idx_range,
                                 InsertBefore);
    }
    // If we end up here, the indices of the insertvalue match with those
    // requested (though possibly only partially). Now we recursively look at
    // the inserted value, passing any remaining indices.
    return FindInsertedValue(V: I->getInsertedValueOperand(),
                             idx_range: ArrayRef(req_idx, idx_range.end()), InsertBefore);
  }

  if (ExtractValueInst *I = dyn_cast<ExtractValueInst>(Val: V)) {
    // If we're extracting a value from an aggregate that was extracted from
    // something else, we can extract from that something else directly instead.
    // However, we will need to chain I's indices with the requested indices.

    // Calculate the number of indices required
    unsigned size = I->getNumIndices() + idx_range.size();
    // Allocate some space to put the new indices in
    SmallVector<unsigned, 5> Idxs;
    Idxs.reserve(N: size);
    // Add indices from the extract value instruction
    Idxs.append(in_start: I->idx_begin(), in_end: I->idx_end());

    // Add requested indices
    Idxs.append(in_start: idx_range.begin(), in_end: idx_range.end());

    assert(Idxs.size() == size
           && "Number of indices added not correct?");

    return FindInsertedValue(V: I->getAggregateOperand(), idx_range: Idxs, InsertBefore);
  }
  // Otherwise, we don't know (such as, extracting from a function return value
  // or load instruction)
  return nullptr;
}
6511
6512// If V refers to an initialized global constant, set Slice either to
6513// its initializer if the size of its elements equals ElementSize, or,
// for ElementSize == 8, to its representation as an array of unsigned
6515// char. Return true on success.
6516// Offset is in the unit "nr of ElementSize sized elements".
bool llvm::getConstantDataArrayInfo(const Value *V,
                                    ConstantDataArraySlice &Slice,
                                    unsigned ElementSize, uint64_t Offset) {
  assert(V && "V should not be null.");
  assert((ElementSize % 8) == 0 &&
         "ElementSize expected to be a multiple of the size of a byte.");
  unsigned ElementSizeInBytes = ElementSize / 8;

  // Drill down into the pointer expression V, ignoring any intervening
  // casts, and determine the identity of the object it references along
  // with the cumulative byte offset into it.
  const GlobalVariable *GV =
    dyn_cast<GlobalVariable>(Val: getUnderlyingObject(V));
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    // Fail if V is not based on constant global object.
    return false;

  const DataLayout &DL = GV->getDataLayout();
  APInt Off(DL.getIndexTypeSizeInBits(Ty: V->getType()), 0);

  if (GV != V->stripAndAccumulateConstantOffsets(DL, Offset&: Off,
                                                 /*AllowNonInbounds*/ true))
    // Fail if a constant offset could not be determined.
    return false;

  uint64_t StartIdx = Off.getLimitedValue();
  if (StartIdx == UINT64_MAX)
    // Fail if the constant offset is excessive.
    return false;

  // Off/StartIdx is in the unit of bytes. So we need to convert to number of
  // elements. Simply bail out if that isn't possible.
  if ((StartIdx % ElementSizeInBytes) != 0)
    return false;

  // From here on, Offset counts elements of ElementSize bits.
  Offset += StartIdx / ElementSizeInBytes;
  ConstantDataArray *Array = nullptr;
  ArrayType *ArrayTy = nullptr;

  // Zero initializers have no backing data array; report only a length.
  if (GV->getInitializer()->isNullValue()) {
    Type *GVTy = GV->getValueType();
    uint64_t SizeInBytes = DL.getTypeStoreSize(Ty: GVTy).getFixedValue();
    uint64_t Length = SizeInBytes / ElementSizeInBytes;

    Slice.Array = nullptr;
    Slice.Offset = 0;
    // Return an empty Slice for undersized constants to let callers
    // transform even undefined library calls into simpler, well-defined
    // expressions. This is preferable to making the calls although it
    // prevents sanitizers from detecting such calls.
    Slice.Length = Length < Offset ? 0 : Length - Offset;
    return true;
  }

  auto *Init = const_cast<Constant *>(GV->getInitializer());
  if (auto *ArrayInit = dyn_cast<ConstantDataArray>(Val: Init)) {
    Type *InitElTy = ArrayInit->getElementType();
    if (InitElTy->isIntegerTy(Bitwidth: ElementSize)) {
      // If Init is an initializer for an array of the expected type
      // and size, use it as is.
      Array = ArrayInit;
      ArrayTy = ArrayInit->getType();
    }
  }

  if (!Array) {
    if (ElementSize != 8)
      // TODO: Handle conversions to larger integral types.
      return false;

    // Otherwise extract the portion of the initializer starting
    // at Offset as an array of bytes, and reset Offset.
    Init = ReadByteArrayFromGlobal(GV, Offset);
    if (!Init)
      return false;

    Offset = 0;
    Array = dyn_cast<ConstantDataArray>(Val: Init);
    ArrayTy = dyn_cast<ArrayType>(Val: Init->getType());
  }

  // Reject offsets that point past the end of the array.
  uint64_t NumElts = ArrayTy->getArrayNumElements();
  if (Offset > NumElts)
    return false;

  Slice.Array = Array;
  Slice.Offset = Offset;
  Slice.Length = NumElts - Offset;
  return true;
}
6607
6608/// Extract bytes from the initializer of the constant array V, which need
6609/// not be a nul-terminated string. On success, store the bytes in Str and
6610/// return true. When TrimAtNul is set, Str will contain only the bytes up
6611/// to but not including the first nul. Return false on failure.
6612bool llvm::getConstantStringInfo(const Value *V, StringRef &Str,
6613 bool TrimAtNul) {
6614 ConstantDataArraySlice Slice;
6615 if (!getConstantDataArrayInfo(V, Slice, ElementSize: 8))
6616 return false;
6617
6618 if (Slice.Array == nullptr) {
6619 if (TrimAtNul) {
6620 // Return a nul-terminated string even for an empty Slice. This is
6621 // safe because all existing SimplifyLibcalls callers require string
6622 // arguments and the behavior of the functions they fold is undefined
6623 // otherwise. Folding the calls this way is preferable to making
6624 // the undefined library calls, even though it prevents sanitizers
6625 // from reporting such calls.
6626 Str = StringRef();
6627 return true;
6628 }
6629 if (Slice.Length == 1) {
6630 Str = StringRef("", 1);
6631 return true;
6632 }
6633 // We cannot instantiate a StringRef as we do not have an appropriate string
6634 // of 0s at hand.
6635 return false;
6636 }
6637
6638 // Start out with the entire array in the StringRef.
6639 Str = Slice.Array->getAsString();
6640 // Skip over 'offset' bytes.
6641 Str = Str.substr(Start: Slice.Offset);
6642
6643 if (TrimAtNul) {
6644 // Trim off the \0 and anything after it. If the array is not nul
6645 // terminated, we just return the whole end of string. The client may know
6646 // some other way that the string is length-bound.
6647 Str = Str.substr(Start: 0, N: Str.find(C: '\0'));
6648 }
6649 return true;
6650}
6651
// These next two are very similar to the above, but also look through PHI
// nodes.
// TODO: See if we can integrate these two together.

/// If we can compute the length of the string pointed to by
/// the specified pointer, return 'len+1'. If we can't, return 0.
/// Internally, ~0ULL is the "no constraint yet" sentinel: it is what a
/// recursive visit of an already-seen PHI returns, so cyclic PHI edges do
/// not constrain the length.
static uint64_t GetStringLengthH(const Value *V,
                                 SmallPtrSetImpl<const PHINode*> &PHIs,
                                 unsigned CharSize) {
  // Look through noop bitcast instructions.
  V = V->stripPointerCasts();

  // If this is a PHI node, there are two cases: either we have already seen it
  // or we haven't.
  if (const PHINode *PN = dyn_cast<PHINode>(Val: V)) {
    if (!PHIs.insert(Ptr: PN).second)
      return ~0ULL; // already in the set.

    // If it was new, see if all the input strings are the same length.
    uint64_t LenSoFar = ~0ULL;
    for (Value *IncValue : PN->incoming_values()) {
      uint64_t Len = GetStringLengthH(V: IncValue, PHIs, CharSize);
      if (Len == 0) return 0; // Unknown length -> unknown.

      // A cycle back into this PHI imposes no constraint; skip it.
      if (Len == ~0ULL) continue;

      if (Len != LenSoFar && LenSoFar != ~0ULL)
        return 0; // Disagree -> unknown.
      LenSoFar = Len;
    }

    // Success, all agree.
    return LenSoFar;
  }

  // strlen(select(c,x,y)) -> strlen(x) ^ strlen(y)
  if (const SelectInst *SI = dyn_cast<SelectInst>(Val: V)) {
    uint64_t Len1 = GetStringLengthH(V: SI->getTrueValue(), PHIs, CharSize);
    if (Len1 == 0) return 0;
    uint64_t Len2 = GetStringLengthH(V: SI->getFalseValue(), PHIs, CharSize);
    if (Len2 == 0) return 0;
    if (Len1 == ~0ULL) return Len2;
    if (Len2 == ~0ULL) return Len1;
    if (Len1 != Len2) return 0;
    return Len1;
  }

  // Otherwise, see if we can read the string.
  ConstantDataArraySlice Slice;
  if (!getConstantDataArrayInfo(V, Slice, ElementSize: CharSize))
    return 0;

  if (Slice.Array == nullptr)
    // Zeroinitializer (including an empty one).
    return 1;

  // Search for the first nul character. Return a conservative result even
  // when there is no nul. This is safe since otherwise the string function
  // being folded such as strlen is undefined, and can be preferable to
  // making the undefined library call.
  // NOTE(review): Slice.Length (uint64_t) is narrowed to 'unsigned' here —
  // assumes the slice length fits in 32 bits; confirm for huge initializers.
  unsigned NullIndex = 0;
  for (unsigned E = Slice.Length; NullIndex < E; ++NullIndex) {
    if (Slice.Array->getElementAsInteger(i: Slice.Offset + NullIndex) == 0)
      break;
  }

  return NullIndex + 1;
}
6720
6721/// If we can compute the length of the string pointed to by
6722/// the specified pointer, return 'len+1'. If we can't, return 0.
6723uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) {
6724 if (!V->getType()->isPointerTy())
6725 return 0;
6726
6727 SmallPtrSet<const PHINode*, 32> PHIs;
6728 uint64_t Len = GetStringLengthH(V, PHIs, CharSize);
6729 // If Len is ~0ULL, we had an infinite phi cycle: this is dead code, so return
6730 // an empty string as a length.
6731 return Len == ~0ULL ? 1 : Len;
6732}
6733
6734const Value *
6735llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call,
6736 bool MustPreserveNullness) {
6737 assert(Call &&
6738 "getArgumentAliasingToReturnedPointer only works on nonnull calls");
6739 if (const Value *RV = Call->getReturnedArgOperand())
6740 return RV;
6741 // This can be used only as a aliasing property.
6742 if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
6743 Call, MustPreserveNullness))
6744 return Call->getArgOperand(i: 0);
6745 return nullptr;
6746}
6747
/// Return true if \p Call is an intrinsic whose result aliases its first
/// pointer argument without capturing it. With \p MustPreserveNullness set,
/// additionally require that the intrinsic maps null to null.
bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(
    const CallBase *Call, bool MustPreserveNullness) {
  switch (Call->getIntrinsicID()) {
  case Intrinsic::launder_invariant_group:
  case Intrinsic::strip_invariant_group:
  case Intrinsic::aarch64_irg:
  case Intrinsic::aarch64_tagp:
  // The amdgcn_make_buffer_rsrc function does not alter the address of the
  // input pointer (and thus preserve null-ness for the purposes of escape
  // analysis, which is where the MustPreserveNullness flag comes in to play).
  // However, it will not necessarily map ptr addrspace(N) null to ptr
  // addrspace(8) null, aka the "null descriptor", which has "all loads return
  // 0, all stores are dropped" semantics. Given the context of this intrinsic
  // list, no one should be relying on such a strict interpretation of
  // MustPreserveNullness (and, at time of writing, they are not), but we
  // document this fact out of an abundance of caution.
  case Intrinsic::amdgcn_make_buffer_rsrc:
    return true;
  case Intrinsic::ptrmask:
    // ptrmask preserves aliasing but does not necessarily preserve null-ness.
    return !MustPreserveNullness;
  case Intrinsic::threadlocal_address:
    // The underlying variable changes with thread ID. The Thread ID may change
    // at coroutine suspend points.
    return !Call->getParent()->getParent()->isPresplitCoroutine();
  default:
    return false;
  }
}
6776
/// \p PN defines a loop-variant pointer to an object. Check if the
/// previous iteration of the loop was referring to the same object as \p PN.
/// Conservatively returns true ("same object") whenever it cannot analyze
/// the PHI.
static bool isSameUnderlyingObjectInLoop(const PHINode *PN,
                                         const LoopInfo *LI) {
  // Find the loop-defined value.
  Loop *L = LI->getLoopFor(BB: PN->getParent());
  // Only two-operand PHIs are analyzed; anything else is assumed safe.
  if (PN->getNumIncomingValues() != 2)
    return true;

  // Find the value from previous iteration: the incoming instruction that is
  // defined inside this same loop.
  auto *PrevValue = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 0));
  if (!PrevValue || LI->getLoopFor(BB: PrevValue->getParent()) != L)
    PrevValue = dyn_cast<Instruction>(Val: PN->getIncomingValue(i: 1));
  if (!PrevValue || LI->getLoopFor(BB: PrevValue->getParent()) != L)
    return true;

  // If a new pointer is loaded in the loop, the pointer references a different
  // object in every iteration. E.g.:
  //    for (i)
  //       int *p = a[i];
  //       ...
  if (auto *Load = dyn_cast<LoadInst>(Val: PrevValue))
    if (!L->isLoopInvariant(V: Load->getPointerOperand()))
      return false;
  return true;
}
6803
/// Strip GEPs, pointer casts, non-interposable aliases, single-argument
/// (LCSSA) PHIs and returned-argument calls to find the value \p V is
/// ultimately based on. \p MaxLookup bounds the number of stripping steps;
/// 0 means no limit. Returns the deepest value reached.
const Value *llvm::getUnderlyingObject(const Value *V, unsigned MaxLookup) {
  for (unsigned Count = 0; MaxLookup == 0 || Count < MaxLookup; ++Count) {
    if (auto *GEP = dyn_cast<GEPOperator>(Val: V)) {
      const Value *PtrOp = GEP->getPointerOperand();
      if (!PtrOp->getType()->isPointerTy()) // Only handle scalar pointer base.
        return V;
      V = PtrOp;
    } else if (Operator::getOpcode(V) == Instruction::BitCast ||
               Operator::getOpcode(V) == Instruction::AddrSpaceCast) {
      // Casts must still produce a pointer to be looked through.
      Value *NewV = cast<Operator>(Val: V)->getOperand(i: 0);
      if (!NewV->getType()->isPointerTy())
        return V;
      V = NewV;
    } else if (auto *GA = dyn_cast<GlobalAlias>(Val: V)) {
      // An interposable alias may be replaced at link time; stop here.
      if (GA->isInterposable())
        return V;
      V = GA->getAliasee();
    } else {
      if (auto *PHI = dyn_cast<PHINode>(Val: V)) {
        // Look through single-arg phi nodes created by LCSSA.
        if (PHI->getNumIncomingValues() == 1) {
          V = PHI->getIncomingValue(i: 0);
          continue;
        }
      } else if (auto *Call = dyn_cast<CallBase>(Val: V)) {
        // CaptureTracking can know about special capturing properties of some
        // intrinsics like launder.invariant.group, that can't be expressed with
        // the attributes, but have properties like returning aliasing pointer.
        // Because some analysis may assume that nocaptured pointer is not
        // returned from some special intrinsic (because function would have to
        // be marked with returns attribute), it is crucial to use this function
        // because it should be in sync with CaptureTracking. Not using it may
        // cause weird miscompilations where 2 aliasing pointers are assumed to
        // noalias.
        if (auto *RP = getArgumentAliasingToReturnedPointer(Call, MustPreserveNullness: false)) {
          V = RP;
          continue;
        }
      }

      return V;
    }
    assert(V->getType()->isPointerTy() && "Unexpected operand type!");
  }
  return V;
}
6850
/// Like getUnderlyingObject, but collects a set of objects by additionally
/// fanning out through selects and (when safe with respect to loop-carried
/// values) PHI nodes. Results are appended to \p Objects.
void llvm::getUnderlyingObjects(const Value *V,
                                SmallVectorImpl<const Value *> &Objects,
                                const LoopInfo *LI, unsigned MaxLookup) {
  SmallPtrSet<const Value *, 4> Visited;
  SmallVector<const Value *, 4> Worklist;
  Worklist.push_back(Elt: V);
  do {
    const Value *P = Worklist.pop_back_val();
    P = getUnderlyingObject(V: P, MaxLookup);

    // Skip values we have already expanded.
    if (!Visited.insert(Ptr: P).second)
      continue;

    if (auto *SI = dyn_cast<SelectInst>(Val: P)) {
      // Either arm of the select may be the underlying object.
      Worklist.push_back(Elt: SI->getTrueValue());
      Worklist.push_back(Elt: SI->getFalseValue());
      continue;
    }

    if (auto *PN = dyn_cast<PHINode>(Val: P)) {
      // If this PHI changes the underlying object in every iteration of the
      // loop, don't look through it. Consider:
      //   int **A;
      //   for (i) {
      //     Prev = Curr;     // Prev = PHI (Prev_0, Curr)
      //     Curr = A[i];
      //     *Prev, *Curr;
      //
      // Prev is tracking Curr one iteration behind so they refer to different
      // underlying objects.
      if (!LI || !LI->isLoopHeader(BB: PN->getParent()) ||
          isSameUnderlyingObjectInLoop(PN, LI))
        append_range(C&: Worklist, R: PN->incoming_values());
      else
        Objects.push_back(Elt: P);
      continue;
    }

    Objects.push_back(Elt: P);
  } while (!Worklist.empty());
}
6892
/// Look through selects and PHIs (without a loop-safety check) trying to find
/// a single common underlying object. Falls back to getUnderlyingObject(V)
/// when the search visits too many values or finds disagreeing objects.
const Value *llvm::getUnderlyingObjectAggressive(const Value *V) {
  const unsigned MaxVisited = 8;

  SmallPtrSet<const Value *, 8> Visited;
  SmallVector<const Value *, 8> Worklist;
  Worklist.push_back(Elt: V);
  const Value *Object = nullptr;
  // Used as fallback if we can't find a common underlying object through
  // recursion.
  bool First = true;
  const Value *FirstObject = getUnderlyingObject(V);
  do {
    const Value *P = Worklist.pop_back_val();
    // The first pop is V itself; reuse the precomputed FirstObject for it.
    P = First ? FirstObject : getUnderlyingObject(V: P);
    First = false;

    if (!Visited.insert(Ptr: P).second)
      continue;

    // Give up once the walk gets too wide.
    if (Visited.size() == MaxVisited)
      return FirstObject;

    if (auto *SI = dyn_cast<SelectInst>(Val: P)) {
      Worklist.push_back(Elt: SI->getTrueValue());
      Worklist.push_back(Elt: SI->getFalseValue());
      continue;
    }

    if (auto *PN = dyn_cast<PHINode>(Val: P)) {
      append_range(C&: Worklist, R: PN->incoming_values());
      continue;
    }

    // All leaves must agree on a single object; otherwise fall back.
    if (!Object)
      Object = P;
    else if (Object != P)
      return FirstObject;
  } while (!Worklist.empty());

  return Object ? Object : FirstObject;
}
6934
/// This is the function that does the work of looking through basic
/// ptrtoint+arithmetic+inttoptr sequences.
/// Returns either the operand of a ptrtoint (a pointer) or the deepest
/// integer value reached.
static const Value *getUnderlyingObjectFromInt(const Value *V) {
  do {
    if (const Operator *U = dyn_cast<Operator>(Val: V)) {
      // If we find a ptrtoint, we can transfer control back to the
      // regular getUnderlyingObjectFromInt.
      if (U->getOpcode() == Instruction::PtrToInt)
        return U->getOperand(i: 0);
      // If we find an add of a constant, a multiplied value, or a phi, it's
      // likely that the other operand will lead us to the base
      // object. We don't have to worry about the case where the
      // object address is somehow being computed by the multiply,
      // because our callers only care when the result is an
      // identifiable object.
      if (U->getOpcode() != Instruction::Add ||
          (!isa<ConstantInt>(Val: U->getOperand(i: 1)) &&
           Operator::getOpcode(V: U->getOperand(i: 1)) != Instruction::Mul &&
           !isa<PHINode>(Val: U->getOperand(i: 1))))
        return V;
      // Keep walking down the left-hand side of the add.
      V = U->getOperand(i: 0);
    } else {
      return V;
    }
    assert(V->getType()->isIntegerTy() && "Unexpected operand type!");
  } while (true);
}
6962
/// This is a wrapper around getUnderlyingObjects and adds support for basic
/// ptrtoint+arithmetic+inttoptr sequences.
/// It returns false if unidentified object is found in getUnderlyingObjects;
/// in that case \p Objects is cleared.
bool llvm::getUnderlyingObjectsForCodeGen(const Value *V,
                                          SmallVectorImpl<Value *> &Objects) {
  SmallPtrSet<const Value *, 16> Visited;
  SmallVector<const Value *, 4> Working(1, V);
  do {
    V = Working.pop_back_val();

    SmallVector<const Value *, 4> Objs;
    getUnderlyingObjects(V, Objects&: Objs);

    for (const Value *V : Objs) {
      if (!Visited.insert(Ptr: V).second)
        continue;
      // Look through inttoptr of a ptrtoint+arithmetic sequence; if it
      // bottoms out at a pointer, keep chasing that pointer.
      if (Operator::getOpcode(V) == Instruction::IntToPtr) {
        const Value *O =
            getUnderlyingObjectFromInt(V: cast<User>(Val: V)->getOperand(i: 0));
        if (O->getType()->isPointerTy()) {
          Working.push_back(Elt: O);
          continue;
        }
      }
      // If getUnderlyingObjects fails to find an identifiable object,
      // getUnderlyingObjectsForCodeGen also fails for safety.
      if (!isIdentifiedObject(V)) {
        Objects.clear();
        return false;
      }
      Objects.push_back(Elt: const_cast<Value *>(V));
    }
  } while (!Working.empty());
  return true;
}
6998
/// If \p V is derived (via casts, PHIs, selects, GEPs or returned-argument
/// calls) from exactly one AllocaInst, return it; otherwise return nullptr.
/// With \p OffsetZero set, additionally require every GEP on the way to have
/// all-zero indices, i.e. \p V must point at the start of the alloca.
AllocaInst *llvm::findAllocaForValue(Value *V, bool OffsetZero) {
  AllocaInst *Result = nullptr;
  SmallPtrSet<Value *, 4> Visited;
  SmallVector<Value *, 4> Worklist;

  // Enqueue a value only the first time we see it.
  auto AddWork = [&](Value *V) {
    if (Visited.insert(Ptr: V).second)
      Worklist.push_back(Elt: V);
  };

  AddWork(V);
  do {
    V = Worklist.pop_back_val();
    assert(Visited.count(V));

    if (AllocaInst *AI = dyn_cast<AllocaInst>(Val: V)) {
      // Two different allocas reachable -> no unique answer.
      if (Result && Result != AI)
        return nullptr;
      Result = AI;
    } else if (CastInst *CI = dyn_cast<CastInst>(Val: V)) {
      AddWork(CI->getOperand(i_nocapture: 0));
    } else if (PHINode *PN = dyn_cast<PHINode>(Val: V)) {
      for (Value *IncValue : PN->incoming_values())
        AddWork(IncValue);
    } else if (auto *SI = dyn_cast<SelectInst>(Val: V)) {
      AddWork(SI->getTrueValue());
      AddWork(SI->getFalseValue());
    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Val: V)) {
      if (OffsetZero && !GEP->hasAllZeroIndices())
        return nullptr;
      AddWork(GEP->getPointerOperand());
    } else if (CallBase *CB = dyn_cast<CallBase>(Val: V)) {
      // Only calls whose result is known to be one of their arguments
      // ("returned" attribute) can be looked through.
      Value *Returned = CB->getReturnedArgOperand();
      if (Returned)
        AddWork(Returned);
      else
        return nullptr;
    } else {
      return nullptr;
    }
  } while (!Worklist.empty());

  return Result;
}
7043
7044static bool onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
7045 const Value *V, bool AllowLifetime, bool AllowDroppable) {
7046 for (const User *U : V->users()) {
7047 const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: U);
7048 if (!II)
7049 return false;
7050
7051 if (AllowLifetime && II->isLifetimeStartOrEnd())
7052 continue;
7053
7054 if (AllowDroppable && II->isDroppable())
7055 continue;
7056
7057 return false;
7058 }
7059 return true;
7060}
7061
/// Return true if the only users of \p V are lifetime start/end intrinsics.
bool llvm::onlyUsedByLifetimeMarkers(const Value *V) {
  return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
      V, /* AllowLifetime */ true, /* AllowDroppable */ false);
}
/// Return true if the only users of \p V are lifetime start/end intrinsics
/// or droppable intrinsics.
bool llvm::onlyUsedByLifetimeMarkersOrDroppableInsts(const Value *V) {
  return onlyUsedByLifetimeMarkersOrDroppableInstsHelper(
      V, /* AllowLifetime */ true, /* AllowDroppable */ true);
}
7070
7071bool llvm::isNotCrossLaneOperation(const Instruction *I) {
7072 if (auto *II = dyn_cast<IntrinsicInst>(Val: I))
7073 return isTriviallyVectorizable(ID: II->getIntrinsicID());
7074 auto *Shuffle = dyn_cast<ShuffleVectorInst>(Val: I);
7075 return (!Shuffle || Shuffle->isSelect()) &&
7076 !isa<CallBase, BitCastInst, ExtractElementInst>(Val: I);
7077}
7078
/// Convenience wrapper: decide speculation safety using the instruction's
/// own opcode. See isSafeToSpeculativelyExecuteWithOpcode.
bool llvm::isSafeToSpeculativelyExecute(
    const Instruction *Inst, const Instruction *CtxI, AssumptionCache *AC,
    const DominatorTree *DT, const TargetLibraryInfo *TLI, bool UseVariableInfo,
    bool IgnoreUBImplyingAttrs) {
  return isSafeToSpeculativelyExecuteWithOpcode(Opcode: Inst->getOpcode(), Inst, CtxI,
                                                AC, DT, TLI, UseVariableInfo,
                                                IgnoreUBImplyingAttrs);
}
7087
/// Decide whether \p Inst could be executed speculatively, treating it as if
/// it had opcode \p Opcode (which may differ from its actual opcode; debug
/// builds sanity-check operand compatibility below).
bool llvm::isSafeToSpeculativelyExecuteWithOpcode(
    unsigned Opcode, const Instruction *Inst, const Instruction *CtxI,
    AssumptionCache *AC, const DominatorTree *DT, const TargetLibraryInfo *TLI,
    bool UseVariableInfo, bool IgnoreUBImplyingAttrs) {
#ifndef NDEBUG
  if (Inst->getOpcode() != Opcode) {
    // Check that the operands are actually compatible with the Opcode override.
    auto hasEqualReturnAndLeadingOperandTypes =
        [](const Instruction *Inst, unsigned NumLeadingOperands) {
          if (Inst->getNumOperands() < NumLeadingOperands)
            return false;
          const Type *ExpectedType = Inst->getType();
          for (unsigned ItOp = 0; ItOp < NumLeadingOperands; ++ItOp)
            if (Inst->getOperand(ItOp)->getType() != ExpectedType)
              return false;
          return true;
        };
    assert(!Instruction::isBinaryOp(Opcode) ||
           hasEqualReturnAndLeadingOperandTypes(Inst, 2));
    assert(!Instruction::isUnaryOp(Opcode) ||
           hasEqualReturnAndLeadingOperandTypes(Inst, 1));
  }
#endif

  switch (Opcode) {
  default:
    // Opcodes not handled below are safe to speculate.
    return true;
  case Instruction::UDiv:
  case Instruction::URem: {
    // x / y is undefined if y == 0.
    const APInt *V;
    if (match(V: Inst->getOperand(i: 1), P: m_APInt(Res&: V)))
      return *V != 0;
    return false;
  }
  case Instruction::SDiv:
  case Instruction::SRem: {
    // x / y is undefined if y == 0 or x == INT_MIN and y == -1
    const APInt *Numerator, *Denominator;
    if (!match(V: Inst->getOperand(i: 1), P: m_APInt(Res&: Denominator)))
      return false;
    // We cannot hoist this division if the denominator is 0.
    if (*Denominator == 0)
      return false;
    // It's safe to hoist if the denominator is not 0 or -1.
    if (!Denominator->isAllOnes())
      return true;
    // At this point we know that the denominator is -1. It is safe to hoist as
    // long we know that the numerator is not INT_MIN.
    if (match(V: Inst->getOperand(i: 0), P: m_APInt(Res&: Numerator)))
      return !Numerator->isMinSignedValue();
    // The numerator *might* be MinSignedValue.
    return false;
  }
  case Instruction::Load: {
    // Speculating a load requires proving the pointer is dereferenceable and
    // sufficiently aligned at the context point.
    if (!UseVariableInfo)
      return false;

    const LoadInst *LI = dyn_cast<LoadInst>(Val: Inst);
    if (!LI)
      return false;
    if (mustSuppressSpeculation(LI: *LI))
      return false;
    const DataLayout &DL = LI->getDataLayout();
    return isDereferenceableAndAlignedPointer(V: LI->getPointerOperand(),
                                              Ty: LI->getType(), Alignment: LI->getAlign(), DL,
                                              CtxI, AC, DT, TLI);
  }
  case Instruction::Call: {
    auto *CI = dyn_cast<const CallInst>(Val: Inst);
    if (!CI)
      return false;
    const Function *Callee = CI->getCalledFunction();

    // The called function could have undefined behavior or side-effects, even
    // if marked readnone nounwind.
    if (!Callee || !Callee->isSpeculatable())
      return false;
    // Since the operands may be changed after hoisting, undefined behavior may
    // be triggered by some UB-implying attributes.
    return IgnoreUBImplyingAttrs || !CI->hasUBImplyingAttrs();
  }
  case Instruction::VAArg:
  case Instruction::Alloca:
  case Instruction::Invoke:
  case Instruction::CallBr:
  case Instruction::PHI:
  case Instruction::Store:
  case Instruction::Ret:
  case Instruction::Br:
  case Instruction::IndirectBr:
  case Instruction::Switch:
  case Instruction::Unreachable:
  case Instruction::Fence:
  case Instruction::AtomicRMW:
  case Instruction::AtomicCmpXchg:
  case Instruction::LandingPad:
  case Instruction::Resume:
  case Instruction::CatchSwitch:
  case Instruction::CatchPad:
  case Instruction::CatchRet:
  case Instruction::CleanupPad:
  case Instruction::CleanupRet:
    return false; // Misc instructions which have effects
  }
}
7194
/// Return true if reordering \p I with surrounding instructions could change
/// behavior for reasons not visible through def-use edges (memory accesses,
/// possible traps, or non-terminating execution).
bool llvm::mayHaveNonDefUseDependency(const Instruction &I) {
  if (I.mayReadOrWriteMemory())
    // Memory dependency possible
    return true;
  if (!isSafeToSpeculativelyExecute(Inst: &I))
    // Can't move above a maythrow call or infinite loop.  Or if an
    // inalloca alloca, above a stacksave call.
    return true;
  if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
    // 1) Can't reorder two inf-loop calls, even if readonly
    // 2) Also can't reorder an inf-loop call below a instruction which isn't
    //    safe to speculative execute.  (Inverse of above)
    return true;
  return false;
}
7210
7211/// Convert ConstantRange OverflowResult into ValueTracking OverflowResult.
7212static OverflowResult mapOverflowResult(ConstantRange::OverflowResult OR) {
7213 switch (OR) {
7214 case ConstantRange::OverflowResult::MayOverflow:
7215 return OverflowResult::MayOverflow;
7216 case ConstantRange::OverflowResult::AlwaysOverflowsLow:
7217 return OverflowResult::AlwaysOverflowsLow;
7218 case ConstantRange::OverflowResult::AlwaysOverflowsHigh:
7219 return OverflowResult::AlwaysOverflowsHigh;
7220 case ConstantRange::OverflowResult::NeverOverflows:
7221 return OverflowResult::NeverOverflows;
7222 }
7223 llvm_unreachable("Unknown OverflowResult");
7224}
7225
7226/// Combine constant ranges from computeConstantRange() and computeKnownBits().
7227ConstantRange
7228llvm::computeConstantRangeIncludingKnownBits(const WithCache<const Value *> &V,
7229 bool ForSigned,
7230 const SimplifyQuery &SQ) {
7231 ConstantRange CR1 =
7232 ConstantRange::fromKnownBits(Known: V.getKnownBits(Q: SQ), IsSigned: ForSigned);
7233 ConstantRange CR2 = computeConstantRange(V, ForSigned, UseInstrInfo: SQ.IIQ.UseInstrInfo);
7234 ConstantRange::PreferredRangeType RangeType =
7235 ForSigned ? ConstantRange::Signed : ConstantRange::Unsigned;
7236 return CR1.intersectWith(CR: CR2, Type: RangeType);
7237}
7238
7239OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS,
7240 const Value *RHS,
7241 const SimplifyQuery &SQ,
7242 bool IsNSW) {
7243 ConstantRange LHSRange =
7244 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
7245 ConstantRange RHSRange =
7246 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
7247
7248 // mul nsw of two non-negative numbers is also nuw.
7249 if (IsNSW && LHSRange.isAllNonNegative() && RHSRange.isAllNonNegative())
7250 return OverflowResult::NeverOverflows;
7251
7252 return mapOverflowResult(OR: LHSRange.unsignedMulMayOverflow(Other: RHSRange));
7253}
7254
/// Determine whether a signed multiply of \p LHS and \p RHS can overflow,
/// using sign-bit counting (see the reference below).
OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS,
                                                 const Value *RHS,
                                                 const SimplifyQuery &SQ) {
  // Multiplying n * m significant bits yields a result of n + m significant
  // bits. If the total number of significant bits does not exceed the
  // result bit width (minus 1), there is no overflow.
  // This means if we have enough leading sign bits in the operands
  // we can guarantee that the result does not overflow.
  // Ref: "Hacker's Delight" by Henry Warren
  unsigned BitWidth = LHS->getType()->getScalarSizeInBits();

  // Note that underestimating the number of sign bits gives a more
  // conservative answer.
  unsigned SignBits =
      ::ComputeNumSignBits(V: LHS, Q: SQ) + ::ComputeNumSignBits(V: RHS, Q: SQ);

  // First handle the easy case: if we have enough sign bits there's
  // definitely no overflow.
  if (SignBits > BitWidth + 1)
    return OverflowResult::NeverOverflows;

  // There are two ambiguous cases where there can be no overflow:
  //   SignBits == BitWidth + 1    and
  //   SignBits == BitWidth
  // The second case is difficult to check, therefore we only handle the
  // first case.
  if (SignBits == BitWidth + 1) {
    // It overflows only when both arguments are negative and the true
    // product is exactly the minimum negative number.
    // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000
    // For simplicity we just check if at least one side is not negative.
    KnownBits LHSKnown = computeKnownBits(V: LHS, Q: SQ);
    KnownBits RHSKnown = computeKnownBits(V: RHS, Q: SQ);
    if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative())
      return OverflowResult::NeverOverflows;
  }
  return OverflowResult::MayOverflow;
}
7293
7294OverflowResult
7295llvm::computeOverflowForUnsignedAdd(const WithCache<const Value *> &LHS,
7296 const WithCache<const Value *> &RHS,
7297 const SimplifyQuery &SQ) {
7298 ConstantRange LHSRange =
7299 computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
7300 ConstantRange RHSRange =
7301 computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
7302 return mapOverflowResult(OR: LHSRange.unsignedAddMayOverflow(Other: RHSRange));
7303}
7304
/// Determine whether a signed add of \p LHS and \p RHS can overflow.
/// \p Add, when non-null, is the add instruction itself and enables extra
/// reasoning from its flags and from context assumptions.
static OverflowResult
computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                            const WithCache<const Value *> &RHS,
                            const AddOperator *Add, const SimplifyQuery &SQ) {
  // An existing nsw flag already guarantees no signed wrap.
  if (Add && Add->hasNoSignedWrap()) {
    return OverflowResult::NeverOverflows;
  }

  // If LHS and RHS each have at least two sign bits, the addition will look
  // like
  //
  //   XX..... +
  //   YY.....
  //
  // If the carry into the most significant position is 0, X and Y can't both
  // be 1 and therefore the carry out of the addition is also 0.
  //
  // If the carry into the most significant position is 1, X and Y can't both
  // be 0 and therefore the carry out of the addition is also 1.
  //
  // Since the carry into the most significant position is always equal to
  // the carry out of the addition, there is no signed overflow.
  if (::ComputeNumSignBits(V: LHS, Q: SQ) > 1 && ::ComputeNumSignBits(V: RHS, Q: SQ) > 1)
    return OverflowResult::NeverOverflows;

  ConstantRange LHSRange =
      computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/true, SQ);
  ConstantRange RHSRange =
      computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/true, SQ);
  OverflowResult OR =
      mapOverflowResult(OR: LHSRange.signedAddMayOverflow(Other: RHSRange));
  if (OR != OverflowResult::MayOverflow)
    return OR;

  // The remaining code needs Add to be available. Early returns if not so.
  if (!Add)
    return OverflowResult::MayOverflow;

  // If the sign of Add is the same as at least one of the operands, this add
  // CANNOT overflow. If this can be determined from the known bits of the
  // operands the above signedAddMayOverflow() check will have already done so.
  // The only other way to improve on the known bits is from an assumption, so
  // call computeKnownBitsFromContext() directly.
  bool LHSOrRHSKnownNonNegative =
      (LHSRange.isAllNonNegative() || RHSRange.isAllNonNegative());
  bool LHSOrRHSKnownNegative =
      (LHSRange.isAllNegative() || RHSRange.isAllNegative());
  if (LHSOrRHSKnownNonNegative || LHSOrRHSKnownNegative) {
    KnownBits AddKnown(LHSRange.getBitWidth());
    computeKnownBitsFromContext(V: Add, Known&: AddKnown, Q: SQ);
    if ((AddKnown.isNonNegative() && LHSOrRHSKnownNonNegative) ||
        (AddKnown.isNegative() && LHSOrRHSKnownNegative))
      return OverflowResult::NeverOverflows;
  }

  return OverflowResult::MayOverflow;
}
7362
/// Determine whether an unsigned subtract LHS - RHS can overflow (wrap below
/// zero), using pattern matching, dominating conditions, and constant ranges.
OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS,
                                                   const Value *RHS,
                                                   const SimplifyQuery &SQ) {
  // X - (X % ?)
  // The remainder of a value can't have greater magnitude than itself,
  // so the subtraction can't overflow.

  // X - (X -nuw ?)
  // In the minimal case, this would simplify to "?", so there's no subtract
  // at all. But if this analysis is used to peek through casts, for example,
  // then determining no-overflow may allow other transforms.

  // TODO: There are other patterns like this.
  //       See simplifyICmpWithBinOpOnLHS() for candidates.
  if (match(V: RHS, P: m_URem(L: m_Specific(V: LHS), R: m_Value())) ||
      match(V: RHS, P: m_NUWSub(L: m_Specific(V: LHS), R: m_Value())))
    if (isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
      return OverflowResult::NeverOverflows;

  // A dominating LHS >= RHS (or its negation) settles the question outright.
  if (auto C = isImpliedByDomCondition(Pred: CmpInst::ICMP_UGE, LHS, RHS, ContextI: SQ.CxtI,
                                       DL: SQ.DL)) {
    if (*C)
      return OverflowResult::NeverOverflows;
    return OverflowResult::AlwaysOverflowsLow;
  }

  ConstantRange LHSRange =
      computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/false, SQ);
  ConstantRange RHSRange =
      computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/false, SQ);
  return mapOverflowResult(OR: LHSRange.unsignedSubMayOverflow(Other: RHSRange));
}
7395
/// Determine whether a signed subtract LHS - RHS can overflow, using pattern
/// matching, sign-bit counting, and constant ranges.
OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS,
                                                 const Value *RHS,
                                                 const SimplifyQuery &SQ) {
  // X - (X % ?)
  // The remainder of a value can't have greater magnitude than itself,
  // so the subtraction can't overflow.

  // X - (X -nsw ?)
  // In the minimal case, this would simplify to "?", so there's no subtract
  // at all. But if this analysis is used to peek through casts, for example,
  // then determining no-overflow may allow other transforms.
  if (match(V: RHS, P: m_SRem(L: m_Specific(V: LHS), R: m_Value())) ||
      match(V: RHS, P: m_NSWSub(L: m_Specific(V: LHS), R: m_Value())))
    if (isGuaranteedNotToBeUndef(V: LHS, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT))
      return OverflowResult::NeverOverflows;

  // If LHS and RHS each have at least two sign bits, the subtraction
  // cannot overflow.
  if (::ComputeNumSignBits(V: LHS, Q: SQ) > 1 && ::ComputeNumSignBits(V: RHS, Q: SQ) > 1)
    return OverflowResult::NeverOverflows;

  ConstantRange LHSRange =
      computeConstantRangeIncludingKnownBits(V: LHS, /*ForSigned=*/true, SQ);
  ConstantRange RHSRange =
      computeConstantRangeIncludingKnownBits(V: RHS, /*ForSigned=*/true, SQ);
  return mapOverflowResult(OR: LHSRange.signedSubMayOverflow(Other: RHSRange));
}
7423
/// Return true if the arithmetic result of the with.overflow intrinsic \p WO
/// is only used on paths where the overflow bit was branched on and found to
/// be clear, i.e. every use of the value is dominated by the no-wrap edge of
/// some branch on the overflow flag.
bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO,
                                     const DominatorTree &DT) {
  SmallVector<const BranchInst *, 2> GuardingBranches;
  SmallVector<const ExtractValueInst *, 2> Results;

  // Partition the users: index 0 extracts the arithmetic result, index 1
  // extracts the overflow flag (whose branch users guard the result).
  for (const User *U : WO->users()) {
    if (const auto *EVI = dyn_cast<ExtractValueInst>(Val: U)) {
      assert(EVI->getNumIndices() == 1 && "Obvious from CI's type");

      if (EVI->getIndices()[0] == 0)
        Results.push_back(Elt: EVI);
      else {
        assert(EVI->getIndices()[0] == 1 && "Obvious from CI's type");

        for (const auto *U : EVI->users())
          if (const auto *B = dyn_cast<BranchInst>(Val: U)) {
            assert(B->isConditional() && "How else is it using an i1?");
            GuardingBranches.push_back(Elt: B);
          }
      }
    } else {
      // We are using the aggregate directly in a way we don't want to analyze
      // here (storing it to a global, say).
      return false;
    }
  }

  auto AllUsesGuardedByBranch = [&](const BranchInst *BI) {
    // Successor 1 is the edge taken when the overflow bit is false.
    BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(i: 1));
    if (!NoWrapEdge.isSingleEdge())
      return false;

    // Check if all users of the add are provably no-wrap.
    for (const auto *Result : Results) {
      // If the extractvalue itself is not executed on overflow, the we don't
      // need to check each use separately, since domination is transitive.
      if (DT.dominates(BBE: NoWrapEdge, BB: Result->getParent()))
        continue;

      for (const auto &RU : Result->uses())
        if (!DT.dominates(BBE: NoWrapEdge, U: RU))
          return false;
    }

    return true;
  };

  return llvm::any_of(Range&: GuardingBranches, P: AllUsesGuardedByBranch);
}
7473
7474/// Shifts return poison if shiftwidth is larger than the bitwidth.
7475static bool shiftAmountKnownInRange(const Value *ShiftAmount) {
7476 auto *C = dyn_cast<Constant>(Val: ShiftAmount);
7477 if (!C)
7478 return false;
7479
7480 // Shifts return poison if shiftwidth is larger than the bitwidth.
7481 SmallVector<const Constant *, 4> ShiftAmounts;
7482 if (auto *FVTy = dyn_cast<FixedVectorType>(Val: C->getType())) {
7483 unsigned NumElts = FVTy->getNumElements();
7484 for (unsigned i = 0; i < NumElts; ++i)
7485 ShiftAmounts.push_back(Elt: C->getAggregateElement(Elt: i));
7486 } else if (isa<ScalableVectorType>(Val: C->getType()))
7487 return false; // Can't tell, just return false to be safe
7488 else
7489 ShiftAmounts.push_back(Elt: C);
7490
7491 bool Safe = llvm::all_of(Range&: ShiftAmounts, P: [](const Constant *C) {
7492 auto *CI = dyn_cast_or_null<ConstantInt>(Val: C);
7493 return CI && CI->getValue().ult(RHS: C->getType()->getIntegerBitWidth());
7494 });
7495
7496 return Safe;
7497}
7498
/// Bitmask selecting which of undef and/or poison a query cares about.
enum class UndefPoisonKind {
  PoisonOnly = (1 << 0),
  UndefOnly = (1 << 1),
  UndefOrPoison = PoisonOnly | UndefOnly,
};
7504
7505static bool includesPoison(UndefPoisonKind Kind) {
7506 return (unsigned(Kind) & unsigned(UndefPoisonKind::PoisonOnly)) != 0;
7507}
7508
7509static bool includesUndef(UndefPoisonKind Kind) {
7510 return (unsigned(Kind) & unsigned(UndefPoisonKind::UndefOnly)) != 0;
7511}
7512
/// Return true if \p Op may itself introduce undef and/or poison (per
/// \p Kind) even when all of its operands are well-defined. If
/// \p ConsiderFlagsAndMetadata is set, poison-generating flags and metadata
/// are also taken into account.
static bool canCreateUndefOrPoison(const Operator *Op, UndefPoisonKind Kind,
                                   bool ConsiderFlagsAndMetadata) {

  if (ConsiderFlagsAndMetadata && includesPoison(Kind) &&
      Op->hasPoisonGeneratingAnnotations())
    return true;

  unsigned Opcode = Op->getOpcode();

  // Check whether opcode is a poison/undef-generating operation
  switch (Opcode) {
  case Instruction::Shl:
  case Instruction::AShr:
  case Instruction::LShr:
    // Shifts yield poison when the shift amount is not known in range.
    return includesPoison(Kind) && !shiftAmountKnownInRange(ShiftAmount: Op->getOperand(i: 1));
  case Instruction::FPToSI:
  case Instruction::FPToUI:
    // fptosi/ui yields poison if the resulting value does not fit in the
    // destination type.
    return true;
  case Instruction::Call:
    if (auto *II = dyn_cast<IntrinsicInst>(Val: Op)) {
      switch (II->getIntrinsicID()) {
      // TODO: Add more intrinsics.
      case Intrinsic::ctlz:
      case Intrinsic::cttz:
      case Intrinsic::abs:
        // We're not considering flags so it is safe to just return false.
        return false;
      case Intrinsic::sshl_sat:
      case Intrinsic::ushl_sat:
        // Saturating shifts are safe when the shift amount is in range;
        // otherwise fall through to the generic call handling below.
        if (!includesPoison(Kind) ||
            shiftAmountKnownInRange(ShiftAmount: II->getArgOperand(i: 1)))
          return false;
        break;
      }
    }
    [[fallthrough]];
  case Instruction::CallBr:
  case Instruction::Invoke: {
    // A call may return undef/poison unless the callee promises otherwise via
    // attributes.
    const auto *CB = cast<CallBase>(Val: Op);
    return !CB->hasRetAttr(Kind: Attribute::NoUndef) &&
           !CB->hasFnAttr(Kind: Attribute::NoCreateUndefOrPoison);
  }
  case Instruction::InsertElement:
  case Instruction::ExtractElement: {
    // If index exceeds the length of the vector, it returns poison
    auto *VTy = cast<VectorType>(Val: Op->getOperand(i: 0)->getType());
    unsigned IdxOp = Op->getOpcode() == Instruction::InsertElement ? 2 : 1;
    auto *Idx = dyn_cast<ConstantInt>(Val: Op->getOperand(i: IdxOp));
    if (includesPoison(Kind))
      return !Idx ||
             Idx->getValue().uge(RHS: VTy->getElementCount().getKnownMinValue());
    return false;
  }
  case Instruction::ShuffleVector: {
    // A poison mask element makes the corresponding result lane poison.
    ArrayRef<int> Mask = isa<ConstantExpr>(Val: Op)
                             ? cast<ConstantExpr>(Val: Op)->getShuffleMask()
                             : cast<ShuffleVectorInst>(Val: Op)->getShuffleMask();
    return includesPoison(Kind) && is_contained(Range&: Mask, Element: PoisonMaskElem);
  }
  case Instruction::FNeg:
  case Instruction::PHI:
  case Instruction::Select:
  case Instruction::ExtractValue:
  case Instruction::InsertValue:
  case Instruction::Freeze:
  case Instruction::ICmp:
  case Instruction::FCmp:
  case Instruction::GetElementPtr:
    // These never introduce undef/poison themselves (poison-generating flags
    // were already handled above).
    return false;
  case Instruction::AddrSpaceCast:
    // Conservatively treat addrspacecast as possibly poison-generating.
    return true;
  default: {
    const auto *CE = dyn_cast<ConstantExpr>(Val: Op);
    if (isa<CastInst>(Val: Op) || (CE && CE->isCast()))
      return false;
    else if (Instruction::isBinaryOp(Opcode))
      return false;
    // Be conservative and return true.
    return true;
  }
  }
}
7597
// Public wrapper: ask about both undef and poison.
bool llvm::canCreateUndefOrPoison(const Operator *Op,
                                  bool ConsiderFlagsAndMetadata) {
  return ::canCreateUndefOrPoison(Op, Kind: UndefPoisonKind::UndefOrPoison,
                                  ConsiderFlagsAndMetadata);
}
7603
// Public wrapper: ask about poison only.
bool llvm::canCreatePoison(const Operator *Op, bool ConsiderFlagsAndMetadata) {
  return ::canCreateUndefOrPoison(Op, Kind: UndefPoisonKind::PoisonOnly,
                                  ConsiderFlagsAndMetadata);
}
7608
/// Return true if, assuming \p ValAssumedPoison is poison, \p V must also be
/// poison. Only looks through poison-propagating operands, with a small
/// recursion bound.
static bool directlyImpliesPoison(const Value *ValAssumedPoison, const Value *V,
                                  unsigned Depth) {
  if (ValAssumedPoison == V)
    return true;

  const unsigned MaxDepth = 2;
  if (Depth >= MaxDepth)
    return false;

  if (const auto *I = dyn_cast<Instruction>(Val: V)) {
    // If any poison-propagating operand of V is itself implied poison, then V
    // is poison too.
    if (any_of(Range: I->operands(), P: [=](const Use &Op) {
          return propagatesPoison(PoisonOp: Op) &&
                 directlyImpliesPoison(ValAssumedPoison, V: Op, Depth: Depth + 1);
        }))
      return true;

    // V = extractvalue V0, idx
    // V2 = extractvalue V0, idx2
    // V0's elements are all poison or not. (e.g., add_with_overflow)
    const WithOverflowInst *II;
    if (match(V: I, P: m_ExtractValue(V: m_WithOverflowInst(I&: II))) &&
        (match(V: ValAssumedPoison, P: m_ExtractValue(V: m_Specific(V: II))) ||
         llvm::is_contained(Range: II->args(), Element: ValAssumedPoison)))
      return true;
  }
  return false;
}
7636
/// Return true if \p V is guaranteed to be poison whenever
/// \p ValAssumedPoison is poison.
static bool impliesPoison(const Value *ValAssumedPoison, const Value *V,
                          unsigned Depth) {
  // If ValAssumedPoison can never be poison, the implication holds vacuously.
  if (isGuaranteedNotToBePoison(V: ValAssumedPoison))
    return true;

  if (directlyImpliesPoison(ValAssumedPoison, V, /* Depth */ 0))
    return true;

  const unsigned MaxDepth = 2;
  if (Depth >= MaxDepth)
    return false;

  const auto *I = dyn_cast<Instruction>(Val: ValAssumedPoison);
  if (I && !canCreatePoison(Op: cast<Operator>(Val: I))) {
    // I cannot introduce poison itself, so if I is poison some operand must
    // be; the implication holds if it holds for every operand.
    return all_of(Range: I->operands(), P: [=](const Value *Op) {
      return impliesPoison(ValAssumedPoison: Op, V, Depth: Depth + 1);
    });
  }
  return false;
}
7657
// Public wrapper around the recursive implementation above.
bool llvm::impliesPoison(const Value *ValAssumedPoison, const Value *V) {
  return ::impliesPoison(ValAssumedPoison, V, /* Depth */ 0);
}
7661
7662static bool programUndefinedIfUndefOrPoison(const Value *V, bool PoisonOnly);
7663
/// Shared implementation of the isGuaranteedNotToBe{UndefOrPoison,Poison,
/// Undef} family: return true if \p V cannot be undef and/or poison (as
/// selected by \p Kind) at context instruction \p CtxI.
static bool isGuaranteedNotToBeUndefOrPoison(
    const Value *V, AssumptionCache *AC, const Instruction *CtxI,
    const DominatorTree *DT, unsigned Depth, UndefPoisonKind Kind) {
  if (Depth >= MaxAnalysisRecursionDepth)
    return false;

  if (isa<MetadataAsValue>(Val: V))
    return false;

  // Arguments carrying noundef (or a dereferenceable attribute, which implies
  // noundef) are well-defined by contract.
  if (const auto *A = dyn_cast<Argument>(Val: V)) {
    if (A->hasAttribute(Kind: Attribute::NoUndef) ||
        A->hasAttribute(Kind: Attribute::Dereferenceable) ||
        A->hasAttribute(Kind: Attribute::DereferenceableOrNull))
      return true;
  }

  if (auto *C = dyn_cast<Constant>(Val: V)) {
    // Poison/undef constants are fine only when the query does not care about
    // that particular kind.
    if (isa<PoisonValue>(Val: C))
      return !includesPoison(Kind);

    if (isa<UndefValue>(Val: C))
      return !includesUndef(Kind);

    if (isa<ConstantInt>(Val: C) || isa<GlobalVariable>(Val: C) || isa<ConstantFP>(Val: C) ||
        isa<ConstantPointerNull>(Val: C) || isa<Function>(Val: C))
      return true;

    if (C->getType()->isVectorTy()) {
      if (isa<ConstantExpr>(Val: C)) {
        // Scalable vectors can use a ConstantExpr to build a splat.
        if (Constant *SplatC = C->getSplatValue())
          if (isa<ConstantInt>(Val: SplatC) || isa<ConstantFP>(Val: SplatC))
            return true;
      } else {
        // Check each lane of a fixed vector constant.
        if (includesUndef(Kind) && C->containsUndefElement())
          return false;
        if (includesPoison(Kind) && C->containsPoisonElement())
          return false;
        return !C->containsConstantExpression();
      }
    }
  }

  // Strip cast operations from a pointer value.
  // Note that stripPointerCastsSameRepresentation can strip off getelementptr
  // inbounds with zero offset. To guarantee that the result isn't poison, the
  // stripped pointer is checked as it has to be pointing into an allocated
  // object or be null to ensure `inbounds` getelement pointers with a
  // zero offset could not produce poison.
  // It can strip off addrspacecast that do not change bit representation as
  // well. We believe that such addrspacecast is equivalent to no-op.
  auto *StrippedV = V->stripPointerCastsSameRepresentation();
  if (isa<AllocaInst>(Val: StrippedV) || isa<GlobalVariable>(Val: StrippedV) ||
      isa<Function>(Val: StrippedV) || isa<ConstantPointerNull>(Val: StrippedV))
    return true;

  auto OpCheck = [&](const Value *V) {
    return isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth: Depth + 1, Kind);
  };

  if (auto *Opr = dyn_cast<Operator>(Val: V)) {
    // If the value is a freeze instruction, then it can never
    // be undef or poison.
    if (isa<FreezeInst>(Val: V))
      return true;

    // Calls with noundef/dereferenceable return attributes are well-defined.
    if (const auto *CB = dyn_cast<CallBase>(Val: V)) {
      if (CB->hasRetAttr(Kind: Attribute::NoUndef) ||
          CB->hasRetAttr(Kind: Attribute::Dereferenceable) ||
          CB->hasRetAttr(Kind: Attribute::DereferenceableOrNull))
        return true;
    }

    // If the operation cannot introduce undef/poison itself, it is
    // well-defined whenever all relevant operands are.
    if (!::canCreateUndefOrPoison(Op: Opr, Kind,
                                  /*ConsiderFlagsAndMetadata=*/true)) {
      if (const auto *PN = dyn_cast<PHINode>(Val: V)) {
        unsigned Num = PN->getNumIncomingValues();
        bool IsWellDefined = true;
        for (unsigned i = 0; i < Num; ++i) {
          // A self-reference cannot be the sole source of undef/poison.
          if (PN == PN->getIncomingValue(i))
            continue;
          // Use the incoming block's terminator as context so recursive
          // queries are evaluated on the right edge.
          auto *TI = PN->getIncomingBlock(i)->getTerminator();
          if (!isGuaranteedNotToBeUndefOrPoison(V: PN->getIncomingValue(i), AC, CtxI: TI,
                                                DT, Depth: Depth + 1, Kind)) {
            IsWellDefined = false;
            break;
          }
        }
        if (IsWellDefined)
          return true;
      } else if (auto *Splat = isa<ShuffleVectorInst>(Val: Opr) ? getSplatValue(V: Opr)
                                                             : nullptr) {
        // For splats we only need to check the value being splatted.
        if (OpCheck(Splat))
          return true;
      } else if (all_of(Range: Opr->operands(), P: OpCheck))
        return true;
    }
  }

  // Loads annotated with !noundef (or dereferenceability metadata, which
  // implies it) produce well-defined values.
  if (auto *I = dyn_cast<LoadInst>(Val: V))
    if (I->hasMetadata(KindID: LLVMContext::MD_noundef) ||
        I->hasMetadata(KindID: LLVMContext::MD_dereferenceable) ||
        I->hasMetadata(KindID: LLVMContext::MD_dereferenceable_or_null))
      return true;

  // If the program would be UB were V undef/poison, V must be well-defined.
  if (programUndefinedIfUndefOrPoison(V, PoisonOnly: !includesUndef(Kind)))
    return true;

  // CxtI may be null or a cloned instruction.
  if (!CtxI || !CtxI->getParent() || !DT)
    return false;

  auto *DNode = DT->getNode(BB: CtxI->getParent());
  if (!DNode)
    // Unreachable block
    return false;

  // If V is used as a branch condition before reaching CtxI, V cannot be
  // undef or poison.
  //   br V, BB1, BB2
  // BB1:
  //   CtxI ; V cannot be undef or poison here
  auto *Dominator = DNode->getIDom();
  // This check is purely for compile time reasons: we can skip the IDom walk
  // if what we are checking for includes undef and the value is not an integer.
  if (!includesUndef(Kind) || V->getType()->isIntegerTy())
    while (Dominator) {
      auto *TI = Dominator->getBlock()->getTerminator();

      Value *Cond = nullptr;
      if (auto BI = dyn_cast_or_null<BranchInst>(Val: TI)) {
        if (BI->isConditional())
          Cond = BI->getCondition();
      } else if (auto SI = dyn_cast_or_null<SwitchInst>(Val: TI)) {
        Cond = SI->getCondition();
      }

      if (Cond) {
        if (Cond == V)
          return true;
        else if (!includesUndef(Kind) && isa<Operator>(Val: Cond)) {
          // For poison, we can analyze further
          auto *Opr = cast<Operator>(Val: Cond);
          if (any_of(Range: Opr->operands(), P: [V](const Use &U) {
                return V == U && propagatesPoison(PoisonOp: U);
              }))
            return true;
        }
      }

      Dominator = Dominator->getIDom();
    }

  // Finally, consult llvm.assume knowledge (noundef operand bundles).
  if (AC && getKnowledgeValidInContext(V, AttrKinds: {Attribute::NoUndef}, AC&: *AC, CtxI, DT))
    return true;

  return false;
}
7823
// Public wrapper: neither undef nor poison.
bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V, AssumptionCache *AC,
                                            const Instruction *CtxI,
                                            const DominatorTree *DT,
                                            unsigned Depth) {
  return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
                                            Kind: UndefPoisonKind::UndefOrPoison);
}
7831
// Public wrapper: not poison (undef is permitted).
bool llvm::isGuaranteedNotToBePoison(const Value *V, AssumptionCache *AC,
                                     const Instruction *CtxI,
                                     const DominatorTree *DT, unsigned Depth) {
  return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
                                            Kind: UndefPoisonKind::PoisonOnly);
}
7838
// Public wrapper: not undef (poison is permitted).
bool llvm::isGuaranteedNotToBeUndef(const Value *V, AssumptionCache *AC,
                                    const Instruction *CtxI,
                                    const DominatorTree *DT, unsigned Depth) {
  return ::isGuaranteedNotToBeUndefOrPoison(V, AC, CtxI, DT, Depth,
                                            Kind: UndefPoisonKind::UndefOnly);
}
7845
/// Return true if undefined behavior would provably be executed on the path to
/// OnPathTo if Root produced a poison result. Note that this doesn't say
/// anything about whether OnPathTo is actually executed or whether Root is
/// actually poison. This can be used to assess whether a new use of Root can
/// be added at a location which is control equivalent with OnPathTo (such as
/// immediately before it) without introducing UB which didn't previously
/// exist. Note that a false result conveys no information.
bool llvm::mustExecuteUBIfPoisonOnPathTo(Instruction *Root,
                                         Instruction *OnPathTo,
                                         DominatorTree *DT) {
  // Basic approach is to assume Root is poison, propagate poison forward
  // through all users we can easily track, and then check whether any of those
  // users are provable UB and must execute before our exiting block might
  // exit.

  // The set of all recursive users we've visited (which are assumed to all be
  // poison because of said visit)
  SmallPtrSet<const Value *, 16> KnownPoison;
  SmallVector<const Instruction*, 16> Worklist;
  Worklist.push_back(Elt: Root);
  while (!Worklist.empty()) {
    const Instruction *I = Worklist.pop_back_val();

    // If we know this must trigger UB on a path leading our target.
    if (mustTriggerUB(I, KnownPoison) && DT->dominates(Def: I, User: OnPathTo))
      return true;

    // If we can't analyze propagation through this instruction, just skip it
    // and transitive users. Safe as false is a conservative result.
    if (I != Root && !any_of(Range: I->operands(), P: [&KnownPoison](const Use &U) {
          return KnownPoison.contains(Ptr: U) && propagatesPoison(PoisonOp: U);
        }))
      continue;

    // First visit: mark I poison and queue its users for propagation.
    if (KnownPoison.insert(Ptr: I).second)
      for (const User *User : I->users())
        Worklist.push_back(Elt: cast<Instruction>(Val: User));
  }

  // Might be non-UB, or might have a path we couldn't prove must execute on
  // way to exiting bb.
  return false;
}
7889
// Public wrapper: overflow analysis for an existing signed add instruction.
OverflowResult llvm::computeOverflowForSignedAdd(const AddOperator *Add,
                                                 const SimplifyQuery &SQ) {
  return ::computeOverflowForSignedAdd(LHS: Add->getOperand(i_nocapture: 0), RHS: Add->getOperand(i_nocapture: 1),
                                       Add, SQ);
}
7895
// Public wrapper: overflow analysis for a hypothetical signed add of LHS+RHS
// (no existing add instruction).
OverflowResult
llvm::computeOverflowForSignedAdd(const WithCache<const Value *> &LHS,
                                  const WithCache<const Value *> &RHS,
                                  const SimplifyQuery &SQ) {
  return ::computeOverflowForSignedAdd(LHS, RHS, Add: nullptr, SQ);
}
7902
bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
  // Note: An atomic operation isn't guaranteed to return in a reasonable amount
  // of time because it's possible for another thread to interfere with it for an
  // arbitrary length of time, but programs aren't allowed to rely on that.

  // If there is no successor, then execution can't transfer to it.
  if (isa<ReturnInst>(Val: I))
    return false;
  if (isa<UnreachableInst>(Val: I))
    return false;

  // Note: Do not add new checks here; instead, change Instruction::mayThrow or
  // Instruction::willReturn.
  //
  // FIXME: Move this check into Instruction::willReturn.
  if (isa<CatchPadInst>(Val: I)) {
    switch (classifyEHPersonality(Pers: I->getFunction()->getPersonalityFn())) {
    default:
      // A catchpad may invoke exception object constructors and such, which
      // in some languages can be arbitrary code, so be conservative by default.
      return false;
    case EHPersonality::CoreCLR:
      // For CoreCLR, it just involves a type test.
      return true;
    }
  }

  // An instruction that returns without throwing must transfer control flow
  // to a successor.
  return !I->mayThrow() && I->willReturn();
}
7934
7935bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) {
7936 // TODO: This is slightly conservative for invoke instruction since exiting
7937 // via an exception *is* normal control for them.
7938 for (const Instruction &I : *BB)
7939 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
7940 return false;
7941 return true;
7942}
7943
// Convenience overload: delegate to the iterator_range version.
bool llvm::isGuaranteedToTransferExecutionToSuccessor(
    BasicBlock::const_iterator Begin, BasicBlock::const_iterator End,
    unsigned ScanLimit) {
  return isGuaranteedToTransferExecutionToSuccessor(Range: make_range(x: Begin, y: End),
                                                    ScanLimit);
}
7950
7951bool llvm::isGuaranteedToTransferExecutionToSuccessor(
7952 iterator_range<BasicBlock::const_iterator> Range, unsigned ScanLimit) {
7953 assert(ScanLimit && "scan limit must be non-zero");
7954 for (const Instruction &I : Range) {
7955 if (--ScanLimit == 0)
7956 return false;
7957 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
7958 return false;
7959 }
7960 return true;
7961}
7962
bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I,
                                                  const Loop *L) {
  // The loop header is guaranteed to be executed for every iteration.
  //
  // FIXME: Relax this constraint to cover all basic blocks that are
  // guaranteed to be executed at every iteration.
  if (I->getParent() != L->getHeader()) return false;

  // I executes every iteration iff no instruction before it in the header can
  // throw or fail to return.
  for (const Instruction &LI : *L->getHeader()) {
    if (&LI == I) return true;
    if (!isGuaranteedToTransferExecutionToSuccessor(I: &LI)) return false;
  }
  llvm_unreachable("Instruction not contained in its own parent basic block.");
}
7977
/// Return true if the given intrinsic's result is poison whenever any of its
/// (lane-corresponding) operands is poison.
bool llvm::intrinsicPropagatesPoison(Intrinsic::ID IID) {
  switch (IID) {
  // TODO: Add more intrinsics.
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::umul_with_overflow:
    // If an input is a vector containing a poison element, the
    // two output vectors (calculated results, overflow bits)'
    // corresponding lanes are poison.
    return true;
  // Integer, bit-manipulation, saturating, fixed-point, and FP math
  // intrinsics below all propagate operand poison to the result.
  case Intrinsic::ctpop:
  case Intrinsic::ctlz:
  case Intrinsic::cttz:
  case Intrinsic::abs:
  case Intrinsic::smax:
  case Intrinsic::smin:
  case Intrinsic::umax:
  case Intrinsic::umin:
  case Intrinsic::scmp:
  case Intrinsic::is_fpclass:
  case Intrinsic::ptrmask:
  case Intrinsic::ucmp:
  case Intrinsic::bitreverse:
  case Intrinsic::bswap:
  case Intrinsic::sadd_sat:
  case Intrinsic::ssub_sat:
  case Intrinsic::sshl_sat:
  case Intrinsic::uadd_sat:
  case Intrinsic::usub_sat:
  case Intrinsic::ushl_sat:
  case Intrinsic::smul_fix:
  case Intrinsic::smul_fix_sat:
  case Intrinsic::umul_fix:
  case Intrinsic::umul_fix_sat:
  case Intrinsic::pow:
  case Intrinsic::powi:
  case Intrinsic::sin:
  case Intrinsic::sinh:
  case Intrinsic::cos:
  case Intrinsic::cosh:
  case Intrinsic::sincos:
  case Intrinsic::sincospi:
  case Intrinsic::tan:
  case Intrinsic::tanh:
  case Intrinsic::asin:
  case Intrinsic::acos:
  case Intrinsic::atan:
  case Intrinsic::atan2:
  case Intrinsic::canonicalize:
  case Intrinsic::sqrt:
  case Intrinsic::exp:
  case Intrinsic::exp2:
  case Intrinsic::exp10:
  case Intrinsic::log:
  case Intrinsic::log2:
  case Intrinsic::log10:
  case Intrinsic::modf:
  case Intrinsic::floor:
  case Intrinsic::ceil:
  case Intrinsic::trunc:
  case Intrinsic::rint:
  case Intrinsic::nearbyint:
  case Intrinsic::round:
  case Intrinsic::roundeven:
  case Intrinsic::lrint:
  case Intrinsic::llrint:
  case Intrinsic::fshl:
  case Intrinsic::fshr:
    return true;
  default:
    // Unknown intrinsics are conservatively assumed not to propagate.
    return false;
  }
}
8054
/// Return true if the user of \p PoisonOp yields poison whenever the operand
/// at that use is poison.
bool llvm::propagatesPoison(const Use &PoisonOp) {
  const Operator *I = cast<Operator>(Val: PoisonOp.getUser());
  switch (I->getOpcode()) {
  case Instruction::Freeze:
  case Instruction::PHI:
  case Instruction::Invoke:
    // freeze stops poison; phi/invoke results do not unconditionally depend
    // on this operand.
    return false;
  case Instruction::Select:
    // Only a poison condition (operand 0) makes the select itself poison.
    return PoisonOp.getOperandNo() == 0;
  case Instruction::Call:
    if (auto *II = dyn_cast<IntrinsicInst>(Val: I))
      return intrinsicPropagatesPoison(IID: II->getIntrinsicID());
    return false;
  case Instruction::ICmp:
  case Instruction::FCmp:
  case Instruction::GetElementPtr:
    return true;
  default:
    if (isa<BinaryOperator>(Val: I) || isa<UnaryOperator>(Val: I) || isa<CastInst>(Val: I))
      return true;

    // Be conservative and return false.
    return false;
  }
}
8080
/// Enumerates all operands of \p I that are guaranteed to not be undef or
/// poison. If the callback \p Handle returns true, stop processing and return
/// true. Otherwise, return false.
template <typename CallableT>
static bool handleGuaranteedWellDefinedOps(const Instruction *I,
                                           const CallableT &Handle) {
  switch (I->getOpcode()) {
  // Memory-accessing pointers must be well-defined, or the access is UB.
  case Instruction::Store:
    if (Handle(cast<StoreInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  case Instruction::Load:
    if (Handle(cast<LoadInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  // Since the dereferenceable attribute implies noundef, atomic operations
  // also implicitly have noundef pointers too
  case Instruction::AtomicCmpXchg:
    if (Handle(cast<AtomicCmpXchgInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  case Instruction::AtomicRMW:
    if (Handle(cast<AtomicRMWInst>(Val: I)->getPointerOperand()))
      return true;
    break;

  case Instruction::Call:
  case Instruction::Invoke: {
    // An indirect callee, and any argument with a noundef (or
    // dereferenceability, which implies noundef) attribute, must be
    // well-defined.
    const CallBase *CB = cast<CallBase>(Val: I);
    if (CB->isIndirectCall() && Handle(CB->getCalledOperand()))
      return true;
    for (unsigned i = 0; i < CB->arg_size(); ++i)
      if ((CB->paramHasAttr(ArgNo: i, Kind: Attribute::NoUndef) ||
           CB->paramHasAttr(ArgNo: i, Kind: Attribute::Dereferenceable) ||
           CB->paramHasAttr(ArgNo: i, Kind: Attribute::DereferenceableOrNull)) &&
          Handle(CB->getArgOperand(i)))
        return true;
    break;
  }
  case Instruction::Ret:
    // A return value covered by a noundef return attribute must be
    // well-defined.
    if (I->getFunction()->hasRetAttribute(Kind: Attribute::NoUndef) &&
        Handle(I->getOperand(i: 0)))
      return true;
    break;
  // Branching on undef/poison is immediate UB, so conditions are
  // well-defined.
  case Instruction::Switch:
    if (Handle(cast<SwitchInst>(Val: I)->getCondition()))
      return true;
    break;
  case Instruction::Br: {
    auto *BR = cast<BranchInst>(Val: I);
    if (BR->isConditional() && Handle(BR->getCondition()))
      return true;
    break;
  }
  default:
    break;
  }

  return false;
}
8144
/// Enumerates all operands of \p I that are guaranteed to not be poison.
template <typename CallableT>
static bool handleGuaranteedNonPoisonOps(const Instruction *I,
                                         const CallableT &Handle) {
  // Anything that must be fully well-defined is certainly non-poison.
  if (handleGuaranteedWellDefinedOps(I, Handle))
    return true;
  switch (I->getOpcode()) {
  // Divisors of these operations are allowed to be partially undef.
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
    return Handle(I->getOperand(i: 1));
  default:
    return false;
  }
}
8162
/// Return true if executing \p I must trigger undefined behavior, given that
/// every value in \p KnownPoison is poison.
bool llvm::mustTriggerUB(const Instruction *I,
                         const SmallPtrSetImpl<const Value *> &KnownPoison) {
  return handleGuaranteedNonPoisonOps(
      I, Handle: [&](const Value *V) { return KnownPoison.count(Ptr: V); });
}
8168
/// Return true if the program has undefined behavior whenever \p V is
/// undef/poison (poison only, when \p PoisonOnly is set), by scanning forward
/// from V's definition (or the entry block for an argument).
static bool programUndefinedIfUndefOrPoison(const Value *V,
                                            bool PoisonOnly) {
  // We currently only look for uses of values within the same basic
  // block, as that makes it easier to guarantee that the uses will be
  // executed given that Inst is executed.
  //
  // FIXME: Expand this to consider uses beyond the same basic block. To do
  // this, look out for the distinction between post-dominance and strong
  // post-dominance.
  const BasicBlock *BB = nullptr;
  BasicBlock::const_iterator Begin;
  if (const auto *Inst = dyn_cast<Instruction>(Val: V)) {
    // Scan starts right after V's definition.
    BB = Inst->getParent();
    Begin = Inst->getIterator();
    Begin++;
  } else if (const auto *Arg = dyn_cast<Argument>(Val: V)) {
    if (Arg->getParent()->isDeclaration())
      return false;
    BB = &Arg->getParent()->getEntryBlock();
    Begin = BB->begin();
  } else {
    return false;
  }

  // Limit number of instructions we look at, to avoid scanning through large
  // blocks. The current limit is chosen arbitrarily.
  unsigned ScanLimit = 32;
  BasicBlock::const_iterator End = BB->end();

  if (!PoisonOnly) {
    // Since undef does not propagate eagerly, be conservative & just check
    // whether a value is directly passed to an instruction that must take
    // well-defined operands.

    for (const auto &I : make_range(x: Begin, y: End)) {
      if (--ScanLimit == 0)
        break;

      if (handleGuaranteedWellDefinedOps(I: &I, Handle: [V](const Value *WellDefinedOp) {
            return WellDefinedOp == V;
          }))
        return true;

      if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
        break;
    }
    return false;
  }

  // Set of instructions that we have proved will yield poison if Inst
  // does.
  SmallPtrSet<const Value *, 16> YieldsPoison;
  SmallPtrSet<const BasicBlock *, 4> Visited;

  YieldsPoison.insert(Ptr: V);
  Visited.insert(Ptr: BB);

  // Walk forward through single-successor blocks, propagating the set of
  // poison-yielding values and looking for an instruction whose execution
  // would be UB given that set.
  while (true) {
    for (const auto &I : make_range(x: Begin, y: End)) {
      if (--ScanLimit == 0)
        return false;
      if (mustTriggerUB(I: &I, KnownPoison: YieldsPoison))
        return true;
      if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
        return false;

      // If an operand is poison and propagates it, mark I as yielding poison.
      for (const Use &Op : I.operands()) {
        if (YieldsPoison.count(Ptr: Op) && propagatesPoison(PoisonOp: Op)) {
          YieldsPoison.insert(Ptr: &I);
          break;
        }
      }

      // Special handling for select, which returns poison if its operand 0 is
      // poison (handled in the loop above) *or* if both its true/false operands
      // are poison (handled here).
      if (I.getOpcode() == Instruction::Select &&
          YieldsPoison.count(Ptr: I.getOperand(i: 1)) &&
          YieldsPoison.count(Ptr: I.getOperand(i: 2))) {
        YieldsPoison.insert(Ptr: &I);
      }
    }

    // Continue only through a unique successor not yet visited.
    BB = BB->getSingleSuccessor();
    if (!BB || !Visited.insert(Ptr: BB).second)
      break;

    Begin = BB->getFirstNonPHIIt();
    End = BB->end();
  }
  return false;
}
8262
// Public wrapper: UB if Inst is undef or poison.
bool llvm::programUndefinedIfUndefOrPoison(const Instruction *Inst) {
  return ::programUndefinedIfUndefOrPoison(V: Inst, PoisonOnly: false);
}
8266
// Public wrapper: UB if Inst is poison (undef not considered).
bool llvm::programUndefinedIfPoison(const Instruction *Inst) {
  return ::programUndefinedIfUndefOrPoison(V: Inst, PoisonOnly: true);
}
8270
8271static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) {
8272 if (FMF.noNaNs())
8273 return true;
8274
8275 if (auto *C = dyn_cast<ConstantFP>(Val: V))
8276 return !C->isNaN();
8277
8278 if (auto *C = dyn_cast<ConstantDataVector>(Val: V)) {
8279 if (!C->getElementType()->isFloatingPointTy())
8280 return false;
8281 for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8282 if (C->getElementAsAPFloat(i: I).isNaN())
8283 return false;
8284 }
8285 return true;
8286 }
8287
8288 if (isa<ConstantAggregateZero>(Val: V))
8289 return true;
8290
8291 return false;
8292}
8293
8294static bool isKnownNonZero(const Value *V) {
8295 if (auto *C = dyn_cast<ConstantFP>(Val: V))
8296 return !C->isZero();
8297
8298 if (auto *C = dyn_cast<ConstantDataVector>(Val: V)) {
8299 if (!C->getElementType()->isFloatingPointTy())
8300 return false;
8301 for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) {
8302 if (C->getElementAsAPFloat(i: I).isZero())
8303 return false;
8304 }
8305 return true;
8306 }
8307
8308 return false;
8309}
8310
/// Match clamp pattern for float types without care about NaNs or signed zeros.
/// Given non-min/max outer cmp/select from the clamp pattern this
/// function recognizes if it can be substituted by a "canonical" min/max
/// pattern.
static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred,
                                               Value *CmpLHS, Value *CmpRHS,
                                               Value *TrueVal, Value *FalseVal,
                                               Value *&LHS, Value *&RHS) {
  // Try to match
  //   X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2))
  //   X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2))
  // and return description of the outer Max/Min.

  // First, check if select has inverse order:
  if (CmpRHS == FalseVal) {
    std::swap(a&: TrueVal, b&: FalseVal);
    Pred = CmpInst::getInversePredicate(pred: Pred);
  }

  // Assume success now. If there's no match, callers should not use these anyway.
  LHS = TrueVal;
  RHS = FalseVal;

  // The compare RHS must be the select's true value and a finite FP constant
  // (the clamp bound C1).
  const APFloat *FC1;
  if (CmpRHS != TrueVal || !match(V: CmpRHS, P: m_APFloat(Res&: FC1)) || !FC1->isFinite())
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  const APFloat *FC2;
  switch (Pred) {
  case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_OLE:
  case CmpInst::FCMP_ULT:
  case CmpInst::FCMP_ULE:
    // X < C1 ? C1 : Min(X, C2) with C1 < C2 -> outer Max of the clamp.
    if (match(V: FalseVal, P: m_OrdOrUnordFMin(L: m_Specific(V: CmpLHS), R: m_APFloat(Res&: FC2))) &&
        *FC1 < *FC2)
      return {.Flavor: SPF_FMAXNUM, .NaNBehavior: SPNB_RETURNS_ANY, .Ordered: false};
    break;
  case CmpInst::FCMP_OGT:
  case CmpInst::FCMP_OGE:
  case CmpInst::FCMP_UGT:
  case CmpInst::FCMP_UGE:
    // X > C1 ? C1 : Max(X, C2) with C1 > C2 -> outer Min of the clamp.
    if (match(V: FalseVal, P: m_OrdOrUnordFMax(L: m_Specific(V: CmpLHS), R: m_APFloat(Res&: FC2))) &&
        *FC1 > *FC2)
      return {.Flavor: SPF_FMINNUM, .NaNBehavior: SPNB_RETURNS_ANY, .Ordered: false};
    break;
  default:
    break;
  }

  return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
}
8362
/// Recognize variations of:
///   CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v)))
static SelectPatternResult matchClamp(CmpInst::Predicate Pred,
                                      Value *CmpLHS, Value *CmpRHS,
                                      Value *TrueVal, Value *FalseVal) {
  // Swap the select operands and predicate to match the patterns below.
  if (CmpRHS != TrueVal) {
    Pred = ICmpInst::getSwappedPredicate(pred: Pred);
    std::swap(a&: TrueVal, b&: FalseVal);
  }
  const APInt *C1;
  // For all of the clamp shapes below, the true arm is the compare constant C1
  // and the false arm is a min/max of the compared value against another
  // constant C2. The C1/C2 ordering check ensures the clamp bounds are sane.
  if (CmpRHS == TrueVal && match(V: CmpRHS, P: m_APInt(Res&: C1))) {
    const APInt *C2;
    // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1)
    if (match(V: FalseVal, P: m_SMin(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
        C1->slt(RHS: *C2) && Pred == CmpInst::ICMP_SLT)
      return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};

    // (X >s C1) ? C1 : SMAX(X, C2) ==> SMIN(SMAX(X, C2), C1)
    if (match(V: FalseVal, P: m_SMax(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
        C1->sgt(RHS: *C2) && Pred == CmpInst::ICMP_SGT)
      return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};

    // (X <u C1) ? C1 : UMIN(X, C2) ==> UMAX(UMIN(X, C2), C1)
    if (match(V: FalseVal, P: m_UMin(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
        C1->ult(RHS: *C2) && Pred == CmpInst::ICMP_ULT)
      return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};

    // (X >u C1) ? C1 : UMAX(X, C2) ==> UMIN(UMAX(X, C2), C1)
    if (match(V: FalseVal, P: m_UMax(L: m_Specific(V: CmpLHS), R: m_APInt(Res&: C2))) &&
        C1->ugt(RHS: *C2) && Pred == CmpInst::ICMP_UGT)
      return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
  }
  return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
}
8398
/// Recognize variations of:
///   a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c))
static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred,
                                               Value *CmpLHS, Value *CmpRHS,
                                               Value *TVal, Value *FVal,
                                               unsigned Depth) {
  // TODO: Allow FP min/max with nnan/nsz.
  assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison");

  // Both select arms must themselves be min/max patterns, and of the same
  // flavor; recurse with an incremented depth to respect the analysis budget.
  Value *A = nullptr, *B = nullptr;
  SelectPatternResult L = matchSelectPattern(V: TVal, LHS&: A, RHS&: B, CastOp: nullptr, Depth: Depth + 1);
  if (!SelectPatternResult::isMinOrMax(SPF: L.Flavor))
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  Value *C = nullptr, *D = nullptr;
  SelectPatternResult R = matchSelectPattern(V: FVal, LHS&: C, RHS&: D, CastOp: nullptr, Depth: Depth + 1);
  if (L.Flavor != R.Flavor)
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  // We have something like: x Pred y ? min(a, b) : min(c, d).
  // Try to match the compare to the min/max operations of the select operands.
  // First, make sure we have the right compare predicate.
  switch (L.Flavor) {
  case SPF_SMIN:
    // Canonicalize toward a "less than" predicate by swapping compare operands
    // when needed, so the operand checks below only handle one orientation.
    if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  case SPF_SMAX:
    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  case SPF_UMIN:
    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  case SPF_UMAX:
    if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
      Pred = ICmpInst::getSwappedPredicate(pred: Pred);
      std::swap(a&: CmpLHS, b&: CmpRHS);
    }
    if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
      break;
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  default:
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  // If there is a common operand in the already matched min/max and the other
  // min/max operands match the compare operands (either directly or inverted),
  // then this is min/max of the same flavor.

  // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
  // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b))
  if (D == B) {
    if ((CmpLHS == A && CmpRHS == C) || (match(V: C, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: A, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }
  // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
  // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d))
  if (C == B) {
    if ((CmpLHS == A && CmpRHS == D) || (match(V: D, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: A, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }
  // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
  // ~c pred ~b ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a))
  if (D == A) {
    if ((CmpLHS == B && CmpRHS == C) || (match(V: C, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: B, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }
  // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
  // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d))
  if (C == A) {
    if ((CmpLHS == B && CmpRHS == D) || (match(V: D, P: m_Not(V: m_Specific(V: CmpLHS))) &&
                                         match(V: B, P: m_Not(V: m_Specific(V: CmpRHS)))))
      return {.Flavor: L.Flavor, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
}
8493
8494/// If the input value is the result of a 'not' op, constant integer, or vector
8495/// splat of a constant integer, return the bitwise-not source value.
8496/// TODO: This could be extended to handle non-splat vector integer constants.
8497static Value *getNotValue(Value *V) {
8498 Value *NotV;
8499 if (match(V, P: m_Not(V: m_Value(V&: NotV))))
8500 return NotV;
8501
8502 const APInt *C;
8503 if (match(V, P: m_APInt(Res&: C)))
8504 return ConstantInt::get(Ty: V->getType(), V: ~(*C));
8505
8506 return nullptr;
8507}
8508
/// Match non-obvious integer minimum and maximum sequences.
static SelectPatternResult matchMinMax(CmpInst::Predicate Pred,
                                       Value *CmpLHS, Value *CmpRHS,
                                       Value *TrueVal, Value *FalseVal,
                                       Value *&LHS, Value *&RHS,
                                       unsigned Depth) {
  // Assume success. If there's no match, callers should not use these anyway.
  LHS = TrueVal;
  RHS = FalseVal;

  // Try the more structured patterns first: clamp, then min/max-of-min/max.
  SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal);
  if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
    return SPR;

  SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TVal: TrueVal, FVal: FalseVal, Depth);
  if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN)
    return SPR;

  // Look through 'not' ops to find disguised min/max.
  // (X > Y) ? ~X : ~Y ==> (~X < ~Y) ? ~X : ~Y ==> MIN(~X, ~Y)
  // (X < Y) ? ~X : ~Y ==> (~X > ~Y) ? ~X : ~Y ==> MAX(~X, ~Y)
  if (CmpLHS == getNotValue(V: TrueVal) && CmpRHS == getNotValue(V: FalseVal)) {
    switch (Pred) {
    case CmpInst::ICMP_SGT: return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_SLT: return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_UGT: return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_ULT: return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    default: break;
    }
  }

  // (X > Y) ? ~Y : ~X ==> (~X < ~Y) ? ~Y : ~X ==> MAX(~Y, ~X)
  // (X < Y) ? ~Y : ~X ==> (~X > ~Y) ? ~Y : ~X ==> MIN(~Y, ~X)
  if (CmpLHS == getNotValue(V: FalseVal) && CmpRHS == getNotValue(V: TrueVal)) {
    switch (Pred) {
    case CmpInst::ICMP_SGT: return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_SLT: return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_UGT: return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
    case CmpInst::ICMP_ULT: return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
    default: break;
    }
  }

  // The remaining patterns only arise for strict signed compares against a
  // constant.
  if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT)
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  const APInt *C1;
  if (!match(V: CmpRHS, P: m_APInt(Res&: C1)))
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  // An unsigned min/max can be written with a signed compare.
  const APInt *C2;
  if ((CmpLHS == TrueVal && match(V: FalseVal, P: m_APInt(Res&: C2))) ||
      (CmpLHS == FalseVal && match(V: TrueVal, P: m_APInt(Res&: C2)))) {
    // Is the sign bit set?
    // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX
    // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN
    if (Pred == CmpInst::ICMP_SLT && C1->isZero() && C2->isMaxSignedValue())
      return {.Flavor: CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};

    // Is the sign bit clear?
    // (X >s -1) ? MINVAL : X ==> (X <u MINVAL) ? MINVAL : X ==> UMAX
    // (X >s -1) ? X : MINVAL ==> (X <u MINVAL) ? X : MINVAL ==> UMIN
    if (Pred == CmpInst::ICMP_SGT && C1->isAllOnes() && C2->isMinSignedValue())
      return {.Flavor: CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
}
8578
bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW,
                           bool AllowPoison) {
  assert(X && Y && "Invalid operand");

  // Returns true when X is a 'sub' producing the negation of Y, honoring the
  // NeedNSW and AllowPoison constraints.
  auto IsNegationOf = [&](const Value *X, const Value *Y) {
    if (!match(V: X, P: m_Neg(V: m_Specific(V: Y))))
      return false;

    auto *BO = cast<BinaryOperator>(Val: X);
    if (NeedNSW && !BO->hasNoSignedWrap())
      return false;

    // The subtrahend constant matched by m_Neg may not be a literal zero in
    // every lane (e.g. undef/poison elements); require a true zero unless the
    // caller allows poison.
    auto *Zero = cast<Constant>(Val: BO->getOperand(i_nocapture: 0));
    if (!AllowPoison && !Zero->isNullValue())
      return false;

    return true;
  };

  // X = -Y or Y = -X
  if (IsNegationOf(X, Y) || IsNegationOf(Y, X))
    return true;

  // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A)
  Value *A, *B;
  return (!NeedNSW && (match(V: X, P: m_Sub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
                       match(V: Y, P: m_Sub(L: m_Specific(V: B), R: m_Specific(V: A))))) ||
         (NeedNSW && (match(V: X, P: m_NSWSub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
                      match(V: Y, P: m_NSWSub(L: m_Specific(V: B), R: m_Specific(V: A)))));
}
8609
bool llvm::isKnownInversion(const Value *X, const Value *Y) {
  // Handle X = icmp pred A, B, Y = icmp pred A, C.
  Value *A, *B, *C;
  CmpPredicate Pred1, Pred2;
  if (!match(V: X, P: m_ICmp(Pred&: Pred1, L: m_Value(V&: A), R: m_Value(V&: B))) ||
      !match(V: Y, P: m_c_ICmp(Pred&: Pred2, L: m_Specific(V: A), R: m_Value(V&: C))))
    return false;

  // They must both have samesign flag or not.
  if (Pred1.hasSameSign() != Pred2.hasSameSign())
    return false;

  // Same RHS: the compares are inversions exactly when the predicates are.
  if (B == C)
    return Pred1 == ICmpInst::getInversePredicate(pred: Pred2);

  // Try to infer the relationship from constant ranges.
  const APInt *RHSC1, *RHSC2;
  if (!match(V: B, P: m_APInt(Res&: RHSC1)) || !match(V: C, P: m_APInt(Res&: RHSC2)))
    return false;

  // Sign bits of two RHSCs should match.
  if (Pred1.hasSameSign() && RHSC1->isNonNegative() != RHSC2->isNonNegative())
    return false;

  // Inversion holds iff the satisfying ranges are exact complements.
  const auto CR1 = ConstantRange::makeExactICmpRegion(Pred: Pred1, Other: *RHSC1);
  const auto CR2 = ConstantRange::makeExactICmpRegion(Pred: Pred2, Other: *RHSC2);

  return CR1.inverse() == CR2;
}
8639
8640SelectPatternResult llvm::getSelectPattern(CmpInst::Predicate Pred,
8641 SelectPatternNaNBehavior NaNBehavior,
8642 bool Ordered) {
8643 switch (Pred) {
8644 default:
8645 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false}; // Equality.
8646 case ICmpInst::ICMP_UGT:
8647 case ICmpInst::ICMP_UGE:
8648 return {.Flavor: SPF_UMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8649 case ICmpInst::ICMP_SGT:
8650 case ICmpInst::ICMP_SGE:
8651 return {.Flavor: SPF_SMAX, .NaNBehavior: SPNB_NA, .Ordered: false};
8652 case ICmpInst::ICMP_ULT:
8653 case ICmpInst::ICMP_ULE:
8654 return {.Flavor: SPF_UMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8655 case ICmpInst::ICMP_SLT:
8656 case ICmpInst::ICMP_SLE:
8657 return {.Flavor: SPF_SMIN, .NaNBehavior: SPNB_NA, .Ordered: false};
8658 case FCmpInst::FCMP_UGT:
8659 case FCmpInst::FCMP_UGE:
8660 case FCmpInst::FCMP_OGT:
8661 case FCmpInst::FCMP_OGE:
8662 return {.Flavor: SPF_FMAXNUM, .NaNBehavior: NaNBehavior, .Ordered: Ordered};
8663 case FCmpInst::FCMP_ULT:
8664 case FCmpInst::FCMP_ULE:
8665 case FCmpInst::FCMP_OLT:
8666 case FCmpInst::FCMP_OLE:
8667 return {.Flavor: SPF_FMINNUM, .NaNBehavior: NaNBehavior, .Ordered: Ordered};
8668 }
8669}
8670
/// Convert a strict relational compare against a constant into its non-strict
/// equivalent (or vice versa) by adjusting the constant by one, e.g.
/// "icmp sgt X, C" --> "icmp sge X, C+1". Returns std::nullopt when the
/// constant cannot be adjusted without wrapping or cannot be analyzed.
std::optional<std::pair<CmpPredicate, Constant *>>
llvm::getFlippedStrictnessPredicateAndConstant(CmpPredicate Pred, Constant *C) {
  assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
         "Only for relational integer predicates.");
  if (isa<UndefValue>(Val: C))
    return std::nullopt;

  Type *Type = C->getType();
  bool IsSigned = ICmpInst::isSigned(predicate: Pred);

  // ULE flips to ULT by adding one; UGT flips to UGE by adding one; the other
  // two directions subtract one.
  CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred);
  bool WillIncrement =
      UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT;

  // Check if the constant operand can be safely incremented/decremented
  // without overflowing/underflowing.
  auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) {
    return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned);
  };

  Constant *SafeReplacementConstant = nullptr;
  if (auto *CI = dyn_cast<ConstantInt>(Val: C)) {
    // Bail out if the constant can't be safely incremented/decremented.
    if (!ConstantIsOk(CI))
      return std::nullopt;
  } else if (auto *FVTy = dyn_cast<FixedVectorType>(Val: Type)) {
    // Fixed vector: every (non-undef) element must be safely adjustable.
    unsigned NumElts = FVTy->getNumElements();
    for (unsigned i = 0; i != NumElts; ++i) {
      Constant *Elt = C->getAggregateElement(Elt: i);
      if (!Elt)
        return std::nullopt;

      if (isa<UndefValue>(Val: Elt))
        continue;

      // Bail out if we can't determine if this constant is min/max or if we
      // know that this constant is min/max.
      auto *CI = dyn_cast<ConstantInt>(Val: Elt);
      if (!CI || !ConstantIsOk(CI))
        return std::nullopt;

      if (!SafeReplacementConstant)
        SafeReplacementConstant = CI;
    }
  } else if (isa<VectorType>(Val: C->getType())) {
    // Handle scalable splat
    Value *SplatC = C->getSplatValue();
    auto *CI = dyn_cast_or_null<ConstantInt>(Val: SplatC);
    // Bail out if the constant can't be safely incremented/decremented.
    if (!CI || !ConstantIsOk(CI))
      return std::nullopt;
  } else {
    // ConstantExpr?
    return std::nullopt;
  }

  // It may not be safe to change a compare predicate in the presence of
  // undefined elements, so replace those elements with the first safe constant
  // that we found.
  // TODO: in case of poison, it is safe; let's replace undefs only.
  if (C->containsUndefOrPoisonElement()) {
    assert(SafeReplacementConstant && "Replacement constant not set");
    C = Constant::replaceUndefsWith(C, Replacement: SafeReplacementConstant);
  }

  CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(pred: Pred);

  // Increment or decrement the constant.
  Constant *OneOrNegOne = ConstantInt::get(Ty: Type, V: WillIncrement ? 1 : -1, IsSigned: true);
  Constant *NewC = ConstantExpr::getAdd(C1: C, C2: OneOrNegOne);

  return std::make_pair(x&: NewPred, y&: NewC);
}
8744
/// Core matcher for min/max/abs select patterns given an already-decomposed
/// compare (Pred, CmpLHS, CmpRHS) and select arms (TrueVal, FalseVal). On a
/// match, the pattern operands are reported through LHS/RHS.
static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
                                              FastMathFlags FMF,
                                              Value *CmpLHS, Value *CmpRHS,
                                              Value *TrueVal, Value *FalseVal,
                                              Value *&LHS, Value *&RHS,
                                              unsigned Depth) {
  bool HasMismatchedZeros = false;
  if (CmpInst::isFPPredicate(P: Pred)) {
    // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one
    // 0.0 operand, set the compare's 0.0 operands to that same value for the
    // purpose of identifying min/max. Disregard vector constants with undefined
    // elements because those can not be back-propagated for analysis.
    Value *OutputZeroVal = nullptr;
    if (match(V: TrueVal, P: m_AnyZeroFP()) && !match(V: FalseVal, P: m_AnyZeroFP()) &&
        !cast<Constant>(Val: TrueVal)->containsUndefOrPoisonElement())
      OutputZeroVal = TrueVal;
    else if (match(V: FalseVal, P: m_AnyZeroFP()) && !match(V: TrueVal, P: m_AnyZeroFP()) &&
             !cast<Constant>(Val: FalseVal)->containsUndefOrPoisonElement())
      OutputZeroVal = FalseVal;

    if (OutputZeroVal) {
      if (match(V: CmpLHS, P: m_AnyZeroFP()) && CmpLHS != OutputZeroVal) {
        HasMismatchedZeros = true;
        CmpLHS = OutputZeroVal;
      }
      if (match(V: CmpRHS, P: m_AnyZeroFP()) && CmpRHS != OutputZeroVal) {
        HasMismatchedZeros = true;
        CmpRHS = OutputZeroVal;
      }
    }
  }

  LHS = CmpLHS;
  RHS = CmpRHS;

  // Signed zero may return inconsistent results between implementations.
  //  (0.0 <= -0.0) ? 0.0 : -0.0 // Returns 0.0
  //  minNum(0.0, -0.0)          // May return -0.0 or 0.0 (IEEE 754-2008 5.3.1)
  // Therefore, we behave conservatively and only proceed if at least one of the
  // operands is known to not be zero or if we don't care about signed zero.
  switch (Pred) {
  default: break;
  case CmpInst::FCMP_OGT: case CmpInst::FCMP_OLT:
  case CmpInst::FCMP_UGT: case CmpInst::FCMP_ULT:
    if (!HasMismatchedZeros)
      break;
    [[fallthrough]];
  case CmpInst::FCMP_OGE: case CmpInst::FCMP_OLE:
  case CmpInst::FCMP_UGE: case CmpInst::FCMP_ULE:
    if (!FMF.noSignedZeros() && !isKnownNonZero(V: CmpLHS) &&
        !isKnownNonZero(V: CmpRHS))
      return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
  }

  SelectPatternNaNBehavior NaNBehavior = SPNB_NA;
  bool Ordered = false;

  // When given one NaN and one non-NaN input:
  //   - maxnum/minnum (C99 fmaxf()/fminf()) return the non-NaN input.
  //   - A simple C99 (a < b ? a : b) construction will return 'b' (as the
  //     ordered comparison fails), which could be NaN or non-NaN.
  // so here we discover exactly what NaN behavior is required/accepted.
  if (CmpInst::isFPPredicate(P: Pred)) {
    bool LHSSafe = isKnownNonNaN(V: CmpLHS, FMF);
    bool RHSSafe = isKnownNonNaN(V: CmpRHS, FMF);

    if (LHSSafe && RHSSafe) {
      // Both operands are known non-NaN.
      NaNBehavior = SPNB_RETURNS_ANY;
      Ordered = CmpInst::isOrdered(predicate: Pred);
    } else if (CmpInst::isOrdered(predicate: Pred)) {
      // An ordered comparison will return false when given a NaN, so it
      // returns the RHS.
      Ordered = true;
      if (LHSSafe)
        // LHS is non-NaN, so if RHS is NaN then NaN will be returned.
        NaNBehavior = SPNB_RETURNS_NAN;
      else if (RHSSafe)
        NaNBehavior = SPNB_RETURNS_OTHER;
      else
        // Completely unsafe.
        return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
    } else {
      Ordered = false;
      // An unordered comparison will return true when given a NaN, so it
      // returns the LHS.
      if (LHSSafe)
        // LHS is non-NaN, so if RHS is NaN then non-NaN will be returned.
        NaNBehavior = SPNB_RETURNS_OTHER;
      else if (RHSSafe)
        NaNBehavior = SPNB_RETURNS_NAN;
      else
        // Completely unsafe.
        return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
    }
  }

  // Canonicalize so the select arms line up with the compare operands; the
  // predicate swap also flips which operand carries a potential NaN.
  if (TrueVal == CmpRHS && FalseVal == CmpLHS) {
    std::swap(a&: CmpLHS, b&: CmpRHS);
    Pred = CmpInst::getSwappedPredicate(pred: Pred);
    if (NaNBehavior == SPNB_RETURNS_NAN)
      NaNBehavior = SPNB_RETURNS_OTHER;
    else if (NaNBehavior == SPNB_RETURNS_OTHER)
      NaNBehavior = SPNB_RETURNS_NAN;
    Ordered = !Ordered;
  }

  // ([if]cmp X, Y) ? X : Y
  if (TrueVal == CmpLHS && FalseVal == CmpRHS)
    return getSelectPattern(Pred, NaNBehavior, Ordered);

  if (isKnownNegation(X: TrueVal, Y: FalseVal)) {
    // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can
    // match against either LHS or sext(LHS).
    auto MaybeSExtCmpLHS =
        m_CombineOr(L: m_Specific(V: CmpLHS), R: m_SExt(Op: m_Specific(V: CmpLHS)));
    auto ZeroOrAllOnes = m_CombineOr(L: m_ZeroInt(), R: m_AllOnes());
    auto ZeroOrOne = m_CombineOr(L: m_ZeroInt(), R: m_One());
    if (match(V: TrueVal, P: MaybeSExtCmpLHS)) {
      // Set the return values. If the compare uses the negated value (-X >s 0),
      // swap the return values because the negated value is always 'RHS'.
      LHS = TrueVal;
      RHS = FalseVal;
      if (match(V: CmpLHS, P: m_Neg(V: m_Specific(V: FalseVal))))
        std::swap(a&: LHS, b&: RHS);

      // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X)
      // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X)
      if (Pred == ICmpInst::ICMP_SGT && match(V: CmpRHS, P: ZeroOrAllOnes))
        return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};

      // (X >=s 0) ? X : -X or (X >=s 1) ? X : -X --> ABS(X)
      if (Pred == ICmpInst::ICMP_SGE && match(V: CmpRHS, P: ZeroOrOne))
        return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};

      // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X)
      // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X)
      if (Pred == ICmpInst::ICMP_SLT && match(V: CmpRHS, P: ZeroOrOne))
        return {.Flavor: SPF_NABS, .NaNBehavior: SPNB_NA, .Ordered: false};
    }
    else if (match(V: FalseVal, P: MaybeSExtCmpLHS)) {
      // Set the return values. If the compare uses the negated value (-X >s 0),
      // swap the return values because the negated value is always 'RHS'.
      LHS = FalseVal;
      RHS = TrueVal;
      if (match(V: CmpLHS, P: m_Neg(V: m_Specific(V: TrueVal))))
        std::swap(a&: LHS, b&: RHS);

      // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X)
      // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X)
      if (Pred == ICmpInst::ICMP_SGT && match(V: CmpRHS, P: ZeroOrAllOnes))
        return {.Flavor: SPF_NABS, .NaNBehavior: SPNB_NA, .Ordered: false};

      // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X)
      // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
      if (Pred == ICmpInst::ICMP_SLT && match(V: CmpRHS, P: ZeroOrOne))
        return {.Flavor: SPF_ABS, .NaNBehavior: SPNB_NA, .Ordered: false};
    }
  }

  if (CmpInst::isIntPredicate(P: Pred))
    return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth);

  // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar
  // may return either -0.0 or 0.0, so fcmp/select pair has stricter
  // semantics than minNum. Be conservative in such case.
  if (NaNBehavior != SPNB_RETURNS_ANY ||
      (!FMF.noSignedZeros() && !isKnownNonZero(V: CmpLHS) &&
       !isKnownNonZero(V: CmpRHS)))
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS);
}
8918
/// Constant counterpart of lookThroughCast(): given a compare, the cast's
/// source type, and the constant select arm C, compute the constant in the
/// source type that would produce C when cast with *CastOp. Returns nullptr
/// when the cast cannot be reversed losslessly for this compare.
static Value *lookThroughCastConst(CmpInst *CmpI, Type *SrcTy, Constant *C,
                                   Instruction::CastOps *CastOp) {
  const DataLayout &DL = CmpI->getDataLayout();

  Constant *CastedTo = nullptr;
  switch (*CastOp) {
  case Instruction::ZExt:
    // Truncation only round-trips a zext when the compare is unsigned.
    if (CmpI->isUnsigned())
      CastedTo = ConstantExpr::getTrunc(C, Ty: SrcTy);
    break;
  case Instruction::SExt:
    // Truncation only round-trips a sext when the compare is signed.
    if (CmpI->isSigned())
      CastedTo = ConstantExpr::getTrunc(C, Ty: SrcTy, OnlyIfReduced: true);
    break;
  case Instruction::Trunc:
    Constant *CmpConst;
    if (match(V: CmpI->getOperand(i_nocapture: 1), P: m_Constant(C&: CmpConst)) &&
        CmpConst->getType() == SrcTy) {
      // Here we have the following case:
      //
      //   %cond = cmp iN %x, CmpConst
      //   %tr = trunc iN %x to iK
      //   %narrowsel = select i1 %cond, iK %t, iK C
      //
      // We can always move trunc after select operation:
      //
      //   %cond = cmp iN %x, CmpConst
      //   %widesel = select i1 %cond, iN %x, iN CmpConst
      //   %tr = trunc iN %widesel to iK
      //
      // Note that C could be extended in any way because we don't care about
      // upper bits after truncation. It can't be abs pattern, because it would
      // look like:
      //
      //   select i1 %cond, x, -x.
      //
      // So only min/max pattern could be matched. Such match requires widened C
      // == CmpConst. That is why set widened C = CmpConst, condition trunc
      // CmpConst == C is checked below.
      CastedTo = CmpConst;
    } else {
      // Extend C using the signedness implied by the compare.
      unsigned ExtOp = CmpI->isSigned() ? Instruction::SExt : Instruction::ZExt;
      CastedTo = ConstantFoldCastOperand(Opcode: ExtOp, C, DestTy: SrcTy, DL);
    }
    break;
  case Instruction::FPTrunc:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPExt, C, DestTy: SrcTy, DL);
    break;
  case Instruction::FPExt:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPTrunc, C, DestTy: SrcTy, DL);
    break;
  case Instruction::FPToUI:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::UIToFP, C, DestTy: SrcTy, DL);
    break;
  case Instruction::FPToSI:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::SIToFP, C, DestTy: SrcTy, DL);
    break;
  case Instruction::UIToFP:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPToUI, C, DestTy: SrcTy, DL);
    break;
  case Instruction::SIToFP:
    CastedTo = ConstantFoldCastOperand(Opcode: Instruction::FPToSI, C, DestTy: SrcTy, DL);
    break;
  default:
    break;
  }

  if (!CastedTo)
    return nullptr;

  // Make sure the cast doesn't lose any information.
  Constant *CastedBack =
      ConstantFoldCastOperand(Opcode: *CastOp, C: CastedTo, DestTy: C->getType(), DL);
  if (CastedBack && CastedBack != C)
    return nullptr;

  return CastedTo;
}
8997
/// Helps to match a select pattern in case of a type mismatch.
///
/// The function processes the case when type of true and false values of a
/// select instruction differs from type of the cmp instruction operands because
/// of a cast instruction. The function checks if it is legal to move the cast
/// operation after "select". If yes, it returns the new second value of
/// "select" (with the assumption that cast is moved):
/// 1. As operand of cast instruction when both values of "select" are same cast
/// instructions.
/// 2. As restored constant (by applying reverse cast operation) when the first
/// value of the "select" is a cast operation and the second value is a
/// constant. It is implemented in lookThroughCastConst().
/// 3. As one operand is cast instruction and the other is not. The operands in
/// sel(cmp) are in different type integer.
/// NOTE: We return only the new second value because the first value could be
/// accessed as operand of cast instruction.
static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2,
                              Instruction::CastOps *CastOp) {
  // Case analysis requires V1 to be a cast; *CastOp reports its opcode back to
  // the caller.
  auto *Cast1 = dyn_cast<CastInst>(Val: V1);
  if (!Cast1)
    return nullptr;

  *CastOp = Cast1->getOpcode();
  Type *SrcTy = Cast1->getSrcTy();
  if (auto *Cast2 = dyn_cast<CastInst>(Val: V2)) {
    // If V1 and V2 are both the same cast from the same type, look through V1.
    if (*CastOp == Cast2->getOpcode() && SrcTy == Cast2->getSrcTy())
      return Cast2->getOperand(i_nocapture: 0);
    return nullptr;
  }

  // Case 2: cast vs. constant — reverse the cast on the constant.
  auto *C = dyn_cast<Constant>(Val: V2);
  if (C)
    return lookThroughCastConst(CmpI, SrcTy, C, CastOp);

  Value *CastedTo = nullptr;
  if (*CastOp == Instruction::Trunc) {
    if (match(V: CmpI->getOperand(i_nocapture: 1), P: m_ZExtOrSExt(Op: m_Specific(V: V2)))) {
      // Here we have the following case:
      //   %y_ext = sext iK %y to iN
      //   %cond = cmp iN %x, %y_ext
      //   %tr = trunc iN %x to iK
      //   %narrowsel = select i1 %cond, iK %tr, iK %y
      //
      // We can always move trunc after select operation:
      //   %y_ext = sext iK %y to iN
      //   %cond = cmp iN %x, %y_ext
      //   %widesel = select i1 %cond, iN %x, iN %y_ext
      //   %tr = trunc iN %widesel to iK
      assert(V2->getType() == Cast1->getType() &&
             "V2 and Cast1 should be the same type.");
      CastedTo = CmpI->getOperand(i_nocapture: 1);
    }
  }

  return CastedTo;
}
9055SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS,
9056 Instruction::CastOps *CastOp,
9057 unsigned Depth) {
9058 if (Depth >= MaxAnalysisRecursionDepth)
9059 return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
9060
9061 SelectInst *SI = dyn_cast<SelectInst>(Val: V);
9062 if (!SI) return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
9063
9064 CmpInst *CmpI = dyn_cast<CmpInst>(Val: SI->getCondition());
9065 if (!CmpI) return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};
9066
9067 Value *TrueVal = SI->getTrueValue();
9068 Value *FalseVal = SI->getFalseValue();
9069
9070 return llvm::matchDecomposedSelectPattern(
9071 CmpI, TrueVal, FalseVal, LHS, RHS,
9072 FMF: isa<FPMathOperator>(Val: SI) ? SI->getFastMathFlags() : FastMathFlags(),
9073 CastOp, Depth);
9074}
9075
/// Match a select pattern given the compare and the select's arms directly,
/// looking through a cast on the arms when CastOp is non-null.
SelectPatternResult llvm::matchDecomposedSelectPattern(
    CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
    FastMathFlags FMF, Instruction::CastOps *CastOp, unsigned Depth) {
  CmpInst::Predicate Pred = CmpI->getPredicate();
  Value *CmpLHS = CmpI->getOperand(i_nocapture: 0);
  Value *CmpRHS = CmpI->getOperand(i_nocapture: 1);
  // An nnan flag on the compare itself is as good as nnan FMF for this query.
  if (isa<FPMathOperator>(Val: CmpI) && CmpI->hasNoNaNs())
    FMF.setNoNaNs();

  // Bail out early.
  if (CmpI->isEquality())
    return {.Flavor: SPF_UNKNOWN, .NaNBehavior: SPNB_NA, .Ordered: false};

  // Deal with type mismatches.
  if (CastOp && CmpLHS->getType() != TrueVal->getType()) {
    if (Value *C = lookThroughCast(CmpI, V1: TrueVal, V2: FalseVal, CastOp)) {
      // If this is a potential fmin/fmax with a cast to integer, then ignore
      // -0.0 because there is no corresponding integer value.
      if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
        FMF.setNoSignedZeros();
      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
                                  TrueVal: cast<CastInst>(Val: TrueVal)->getOperand(i_nocapture: 0), FalseVal: C,
                                  LHS, RHS, Depth);
    }
    if (Value *C = lookThroughCast(CmpI, V1: FalseVal, V2: TrueVal, CastOp)) {
      // If this is a potential fmin/fmax with a cast to integer, then ignore
      // -0.0 because there is no corresponding integer value.
      if (*CastOp == Instruction::FPToSI || *CastOp == Instruction::FPToUI)
        FMF.setNoSignedZeros();
      return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS,
                                  TrueVal: C, FalseVal: cast<CastInst>(Val: FalseVal)->getOperand(i_nocapture: 0),
                                  LHS, RHS, Depth);
    }
  }
  return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal,
                              LHS, RHS, Depth);
}
9113
9114CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) {
9115 if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT;
9116 if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT;
9117 if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT;
9118 if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT;
9119 if (SPF == SPF_FMINNUM)
9120 return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT;
9121 if (SPF == SPF_FMAXNUM)
9122 return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT;
9123 llvm_unreachable("unhandled!");
9124}
9125
9126Intrinsic::ID llvm::getMinMaxIntrinsic(SelectPatternFlavor SPF) {
9127 switch (SPF) {
9128 case SelectPatternFlavor::SPF_UMIN:
9129 return Intrinsic::umin;
9130 case SelectPatternFlavor::SPF_UMAX:
9131 return Intrinsic::umax;
9132 case SelectPatternFlavor::SPF_SMIN:
9133 return Intrinsic::smin;
9134 case SelectPatternFlavor::SPF_SMAX:
9135 return Intrinsic::smax;
9136 default:
9137 llvm_unreachable("Unexpected SPF");
9138 }
9139}
9140
9141SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) {
9142 if (SPF == SPF_SMIN) return SPF_SMAX;
9143 if (SPF == SPF_UMIN) return SPF_UMAX;
9144 if (SPF == SPF_SMAX) return SPF_SMIN;
9145 if (SPF == SPF_UMAX) return SPF_UMIN;
9146 llvm_unreachable("unhandled!");
9147}
9148
9149Intrinsic::ID llvm::getInverseMinMaxIntrinsic(Intrinsic::ID MinMaxID) {
9150 switch (MinMaxID) {
9151 case Intrinsic::smax: return Intrinsic::smin;
9152 case Intrinsic::smin: return Intrinsic::smax;
9153 case Intrinsic::umax: return Intrinsic::umin;
9154 case Intrinsic::umin: return Intrinsic::umax;
9155 // Please note that next four intrinsics may produce the same result for
9156 // original and inverted case even if X != Y due to NaN is handled specially.
9157 case Intrinsic::maximum: return Intrinsic::minimum;
9158 case Intrinsic::minimum: return Intrinsic::maximum;
9159 case Intrinsic::maxnum: return Intrinsic::minnum;
9160 case Intrinsic::minnum: return Intrinsic::maxnum;
9161 case Intrinsic::maximumnum:
9162 return Intrinsic::minimumnum;
9163 case Intrinsic::minimumnum:
9164 return Intrinsic::maximumnum;
9165 default: llvm_unreachable("Unexpected intrinsic");
9166 }
9167}
9168
9169APInt llvm::getMinMaxLimit(SelectPatternFlavor SPF, unsigned BitWidth) {
9170 switch (SPF) {
9171 case SPF_SMAX: return APInt::getSignedMaxValue(numBits: BitWidth);
9172 case SPF_SMIN: return APInt::getSignedMinValue(numBits: BitWidth);
9173 case SPF_UMAX: return APInt::getMaxValue(numBits: BitWidth);
9174 case SPF_UMIN: return APInt::getMinValue(numBits: BitWidth);
9175 default: llvm_unreachable("Unexpected flavor");
9176 }
9177}
9178
// Determine whether every value in \p VL is a select forming the SAME min/max
// pattern, returning {intrinsic ID, all compares single-use} on success and
// {not_intrinsic, false} otherwise.
std::pair<Intrinsic::ID, bool>
llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) {
  // Check if VL contains select instructions that can be folded into a min/max
  // vector intrinsic and return the intrinsic if it is possible.
  // TODO: Support floating point min/max.
  bool AllCmpSingleUse = true;
  SelectPatternResult SelectPattern;
  SelectPattern.Flavor = SPF_UNKNOWN;
  if (all_of(Range&: VL, P: [&SelectPattern, &AllCmpSingleUse](Value *I) {
        Value *LHS, *RHS;
        auto CurrentPattern = matchSelectPattern(V: I, LHS, RHS);
        // Only min/max flavors are foldable here.
        if (!SelectPatternResult::isMinOrMax(SPF: CurrentPattern.Flavor))
          return false;
        // Every element must agree on a single flavor.
        if (SelectPattern.Flavor != SPF_UNKNOWN &&
            SelectPattern.Flavor != CurrentPattern.Flavor)
          return false;
        SelectPattern = CurrentPattern;
        // Track whether every select's condition (the compare) has exactly
        // one use; callers use this to decide if the compares die on fold.
        AllCmpSingleUse &=
            match(V: I, P: m_Select(C: m_OneUse(SubPattern: m_Value()), L: m_Value(), R: m_Value()));
        return true;
      })) {
    switch (SelectPattern.Flavor) {
    case SPF_SMIN:
      return {Intrinsic::smin, AllCmpSingleUse};
    case SPF_UMIN:
      return {Intrinsic::umin, AllCmpSingleUse};
    case SPF_SMAX:
      return {Intrinsic::smax, AllCmpSingleUse};
    case SPF_UMAX:
      return {Intrinsic::umax, AllCmpSingleUse};
    case SPF_FMAXNUM:
      return {Intrinsic::maxnum, AllCmpSingleUse};
    case SPF_FMINNUM:
      return {Intrinsic::minnum, AllCmpSingleUse};
    default:
      // isMinOrMax filtered everything else out above.
      llvm_unreachable("unexpected select pattern flavor");
    }
  }
  return {Intrinsic::not_intrinsic, false};
}
9219
// Match \p PN as a two-predecessor recurrence PHI whose backedge value is an
// InstTy using the PHI as one of its first two operands. On success, \p Inst
// is that instruction, \p Init the start value from the other predecessor,
// and \p OtherOp the non-PHI operand (the "step").
template <typename InstTy>
static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst,
                                    Value *&Init, Value *&OtherOp) {
  // Handle the case of a simple two-predecessor recurrence PHI.
  // There's a lot more that could theoretically be done here, but
  // this is sufficient to catch some interesting cases.
  // TODO: Expand list -- gep, uadd.sat etc.
  if (PN->getNumIncomingValues() != 2)
    return false;

  // Try both incoming values as the backedge; the other one becomes Init.
  for (unsigned I = 0; I != 2; ++I) {
    if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(i: I));
        Operation && Operation->getNumOperands() >= 2) {
      Value *LHS = Operation->getOperand(0);
      Value *RHS = Operation->getOperand(1);
      // The PHI must feed the operation (in either operand position).
      if (LHS != PN && RHS != PN)
        continue;

      Inst = Operation;
      Init = PN->getIncomingValue(i: !I);
      OtherOp = (LHS == PN) ? RHS : LHS;
      return true;
    }
  }
  return false;
}
9246
// Match a simple binary-operator recurrence rooted at PHI \p P; see the
// pattern sketch below. Thin wrapper over matchTwoInputRecurrence.
bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
                                 Value *&Start, Value *&Step) {
  // We try to match a recurrence of the form:
  //   %iv = [Start, %entry], [%iv.next, %backedge]
  //   %iv.next = binop %iv, Step
  // Or:
  //   %iv = [Start, %entry], [%iv.next, %backedge]
  //   %iv.next = binop Step, %iv
  return matchTwoInputRecurrence(PN: P, Inst&: BO, Init&: Start, OtherOp&: Step);
}
9257
// Inverse entry point: given the backedge binop \p I, find the recurrence
// PHI \p P it belongs to (requiring that the recurrence found from P is I
// itself, not some other binop feeding the same PHI).
bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
                                 Value *&Start, Value *&Step) {
  BinaryOperator *BO = nullptr;
  // The PHI may be either operand of the binop.
  P = dyn_cast<PHINode>(Val: I->getOperand(i_nocapture: 0));
  if (!P)
    P = dyn_cast<PHINode>(Val: I->getOperand(i_nocapture: 1));
  return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I;
}
9266
// Like matchSimpleRecurrence, but for a recurrence whose backedge operation
// is a two-argument intrinsic call \p I (e.g. umin/smax-style intrinsics).
bool llvm::matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I,
                                                PHINode *&P, Value *&Init,
                                                Value *&OtherOp) {
  // Binary intrinsics only supported for now: require exactly two args of
  // the same type as the result (rules out e.g. scalar-shift-amount forms).
  if (I->arg_size() != 2 || I->getType() != I->getArgOperand(i: 0)->getType() ||
      I->getType() != I->getArgOperand(i: 1)->getType())
    return false;

  IntrinsicInst *II = nullptr;
  // The PHI may be either argument of the intrinsic.
  P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 0));
  if (!P)
    P = dyn_cast<PHINode>(Val: I->getArgOperand(i: 1));

  // Verify the recurrence found from P is I itself.
  return P && matchTwoInputRecurrence(PN: P, Inst&: II, Init, OtherOp) && II == I;
}
9282
9283/// Return true if "icmp Pred LHS RHS" is always true.
9284static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS,
9285 const Value *RHS) {
9286 if (ICmpInst::isTrueWhenEqual(predicate: Pred) && LHS == RHS)
9287 return true;
9288
9289 switch (Pred) {
9290 default:
9291 return false;
9292
9293 case CmpInst::ICMP_SLE: {
9294 const APInt *C;
9295
9296 // LHS s<= LHS +_{nsw} C if C >= 0
9297 // LHS s<= LHS | C if C >= 0
9298 if (match(V: RHS, P: m_NSWAdd(L: m_Specific(V: LHS), R: m_APInt(Res&: C))) ||
9299 match(V: RHS, P: m_Or(L: m_Specific(V: LHS), R: m_APInt(Res&: C))))
9300 return !C->isNegative();
9301
9302 // LHS s<= smax(LHS, V) for any V
9303 if (match(V: RHS, P: m_c_SMax(L: m_Specific(V: LHS), R: m_Value())))
9304 return true;
9305
9306 // smin(RHS, V) s<= RHS for any V
9307 if (match(V: LHS, P: m_c_SMin(L: m_Specific(V: RHS), R: m_Value())))
9308 return true;
9309
9310 // Match A to (X +_{nsw} CA) and B to (X +_{nsw} CB)
9311 const Value *X;
9312 const APInt *CLHS, *CRHS;
9313 if (match(V: LHS, P: m_NSWAddLike(L: m_Value(V&: X), R: m_APInt(Res&: CLHS))) &&
9314 match(V: RHS, P: m_NSWAddLike(L: m_Specific(V: X), R: m_APInt(Res&: CRHS))))
9315 return CLHS->sle(RHS: *CRHS);
9316
9317 return false;
9318 }
9319
9320 case CmpInst::ICMP_ULE: {
9321 // LHS u<= LHS +_{nuw} V for any V
9322 if (match(V: RHS, P: m_c_Add(L: m_Specific(V: LHS), R: m_Value())) &&
9323 cast<OverflowingBinaryOperator>(Val: RHS)->hasNoUnsignedWrap())
9324 return true;
9325
9326 // LHS u<= LHS | V for any V
9327 if (match(V: RHS, P: m_c_Or(L: m_Specific(V: LHS), R: m_Value())))
9328 return true;
9329
9330 // LHS u<= umax(LHS, V) for any V
9331 if (match(V: RHS, P: m_c_UMax(L: m_Specific(V: LHS), R: m_Value())))
9332 return true;
9333
9334 // RHS >> V u<= RHS for any V
9335 if (match(V: LHS, P: m_LShr(L: m_Specific(V: RHS), R: m_Value())))
9336 return true;
9337
9338 // RHS u/ C_ugt_1 u<= RHS
9339 const APInt *C;
9340 if (match(V: LHS, P: m_UDiv(L: m_Specific(V: RHS), R: m_APInt(Res&: C))) && C->ugt(RHS: 1))
9341 return true;
9342
9343 // RHS & V u<= RHS for any V
9344 if (match(V: LHS, P: m_c_And(L: m_Specific(V: RHS), R: m_Value())))
9345 return true;
9346
9347 // umin(RHS, V) u<= RHS for any V
9348 if (match(V: LHS, P: m_c_UMin(L: m_Specific(V: RHS), R: m_Value())))
9349 return true;
9350
9351 // Match A to (X +_{nuw} CA) and B to (X +_{nuw} CB)
9352 const Value *X;
9353 const APInt *CLHS, *CRHS;
9354 if (match(V: LHS, P: m_NUWAddLike(L: m_Value(V&: X), R: m_APInt(Res&: CLHS))) &&
9355 match(V: RHS, P: m_NUWAddLike(L: m_Specific(V: X), R: m_APInt(Res&: CRHS))))
9356 return CLHS->ule(RHS: *CRHS);
9357
9358 return false;
9359 }
9360 }
9361}
9362
9363/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
9364/// ALHS ARHS" is true. Otherwise, return std::nullopt.
/// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred
/// ALHS ARHS" is true. Otherwise, return std::nullopt.
/// Uses monotonicity: e.g. A slt/sle B implies X slt/sle Y when X s<= A and
/// B s<= Y. Never returns false ("implied false" is not decided here).
static std::optional<bool>
isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS,
                      const Value *ARHS, const Value *BLHS, const Value *BRHS) {
  switch (Pred) {
  default:
    return std::nullopt;

  case CmpInst::ICMP_SLT:
  case CmpInst::ICMP_SLE:
    // BLHS s<= ALHS (<,<=) ARHS s<= BRHS  ==>  BLHS (<,<=) BRHS.
    if (isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: BLHS, RHS: ALHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: ARHS, RHS: BRHS))
      return true;
    return std::nullopt;

  case CmpInst::ICMP_SGT:
  case CmpInst::ICMP_SGE:
    // Mirror image of the SLT/SLE case.
    if (isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: ALHS, RHS: BLHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_SLE, LHS: BRHS, RHS: ARHS))
      return true;
    return std::nullopt;

  case CmpInst::ICMP_ULT:
  case CmpInst::ICMP_ULE:
    // Unsigned analogue of the SLT/SLE case.
    if (isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: BLHS, RHS: ALHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: ARHS, RHS: BRHS))
      return true;
    return std::nullopt;

  case CmpInst::ICMP_UGT:
  case CmpInst::ICMP_UGE:
    // Unsigned analogue of the SGT/SGE case.
    if (isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: ALHS, RHS: BLHS) &&
        isTruePredicate(Pred: CmpInst::ICMP_ULE, LHS: BRHS, RHS: ARHS))
      return true;
    return std::nullopt;
  }
}
9401
9402/// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
9403/// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
9404/// Otherwise, return std::nullopt if we can't infer anything.
/// Return true if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is true.
/// Return false if "icmp LPred X, LCR" implies "icmp RPred X, RCR" is false.
/// Otherwise, return std::nullopt if we can't infer anything.
static std::optional<bool>
isImpliedCondCommonOperandWithCR(CmpPredicate LPred, const ConstantRange &LCR,
                                 CmpPredicate RPred, const ConstantRange &RCR) {
  // Given the set CR of possible values of X (from the LHS condition being
  // true), decide whether "X Pred RCR" must hold / must not hold.
  auto CRImpliesPred = [&](ConstantRange CR,
                           CmpInst::Predicate Pred) -> std::optional<bool> {
    // If all true values for lhs and true for rhs, lhs implies rhs
    if (CR.icmp(Pred, Other: RCR))
      return true;

    // If there is no overlap, lhs implies not rhs
    if (CR.icmp(Pred: CmpInst::getInversePredicate(pred: Pred), Other: RCR))
      return false;

    return std::nullopt;
  };
  // First try with the predicates as given.
  if (auto Res = CRImpliesPred(ConstantRange::makeAllowedICmpRegion(Pred: LPred, Other: LCR),
                               RPred))
    return Res;
  // If exactly one side carries the samesign flag, the two predicates view
  // signedness differently; normalize both to the same view and retry.
  if (LPred.hasSameSign() ^ RPred.hasSameSign()) {
    LPred = LPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(Pred: LPred)
                                : LPred.dropSameSign();
    RPred = RPred.hasSameSign() ? ICmpInst::getFlippedSignednessPredicate(Pred: RPred)
                                : RPred.dropSameSign();
    return CRImpliesPred(ConstantRange::makeAllowedICmpRegion(Pred: LPred, Other: LCR),
                         RPred);
  }
  return std::nullopt;
}
9433
9434/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
9435/// is true. Return false if LHS implies RHS is false. Otherwise, return
9436/// std::nullopt if we can't infer anything.
/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
/// is true. Return false if LHS implies RHS is false. Otherwise, return
/// std::nullopt if we can't infer anything.
static std::optional<bool>
isImpliedCondICmps(CmpPredicate LPred, const Value *L0, const Value *L1,
                   CmpPredicate RPred, const Value *R0, const Value *R1,
                   const DataLayout &DL, bool LHSIsTrue) {
  // The rest of the logic assumes the LHS condition is true. If that's not the
  // case, invert the predicate to make it so.
  if (!LHSIsTrue)
    LPred = ICmpInst::getInverseCmpPredicate(Pred: LPred);

  // We can have non-canonical operands, so try to normalize any common operand
  // to L0/R0.
  if (L0 == R1) {
    std::swap(a&: R0, b&: R1);
    RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
  }
  if (R0 == L1) {
    std::swap(a&: L0, b&: L1);
    LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
  }
  if (L1 == R1) {
    // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
    if (L0 != R0 || match(V: L0, P: m_ImmConstant())) {
      std::swap(a&: L0, b&: L1);
      LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
      std::swap(a&: R0, b&: R1);
      RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
    }
  }

  // See if we can infer anything if operand-0 matches and we have at least one
  // constant.
  const APInt *Unused;
  if (L0 == R0 && (match(V: L1, P: m_APInt(Res&: Unused)) || match(V: R1, P: m_APInt(Res&: Unused)))) {
    // Potential TODO: We could also further use the constant range of L0/R0 to
    // further constraint the constant ranges. At the moment this leads to
    // several regressions related to not transforming `multi_use(A + C0) eq/ne
    // C1` (see discussion: D58633).
    ConstantRange LCR = computeConstantRange(
        V: L1, ForSigned: ICmpInst::isSigned(predicate: LPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
        /*CxtI=*/CtxI: nullptr, /*DT=*/nullptr, Depth: MaxAnalysisRecursionDepth - 1);
    ConstantRange RCR = computeConstantRange(
        V: R1, ForSigned: ICmpInst::isSigned(predicate: RPred), /* UseInstrInfo=*/true, /*AC=*/nullptr,
        /*CxtI=*/CtxI: nullptr, /*DT=*/nullptr, Depth: MaxAnalysisRecursionDepth - 1);
    // Even if L1/R1 are not both constant, we can still sometimes deduce
    // relationship from a single constant. For example X u> Y implies X != 0.
    if (auto R = isImpliedCondCommonOperandWithCR(LPred, LCR, RPred, RCR))
      return R;
    // If both L1/R1 were exact constant ranges and we didn't get anything
    // here, we won't be able to deduce this.
    if (match(V: L1, P: m_APInt(Res&: Unused)) && match(V: R1, P: m_APInt(Res&: Unused)))
      return std::nullopt;
  }

  // Can we infer anything when the two compares have matching operands?
  if (L0 == R0 && L1 == R1)
    return ICmpInst::isImpliedByMatchingCmp(Pred1: LPred, Pred2: RPred);

  // It only really makes sense in the context of signed comparison for "X - Y
  // must be positive if X >= Y and no overflow".
  // Take SGT as an example: L0:x > L1:y and C >= 0
  //   ==> R0:(x -nsw y) < R1:(-C) is false
  CmpInst::Predicate SignedLPred = LPred.getPreferredSignedPredicate();
  if ((SignedLPred == ICmpInst::ICMP_SGT ||
       SignedLPred == ICmpInst::ICMP_SGE) &&
      match(V: R0, P: m_NSWSub(L: m_Specific(V: L0), R: m_Specific(V: L1)))) {
    if (match(V: R1, P: m_NonPositive()) &&
        ICmpInst::isImpliedByMatchingCmp(Pred1: SignedLPred, Pred2: RPred) == false)
      return false;
  }

  // Take SLT as an example: L0:x < L1:y and C <= 0
  //   ==> R0:(x -nsw y) < R1:(-C) is true
  if ((SignedLPred == ICmpInst::ICMP_SLT ||
       SignedLPred == ICmpInst::ICMP_SLE) &&
      match(V: R0, P: m_NSWSub(L: m_Specific(V: L0), R: m_Specific(V: L1)))) {
    if (match(V: R1, P: m_NonNegative()) &&
        ICmpInst::isImpliedByMatchingCmp(Pred1: SignedLPred, Pred2: RPred) == true)
      return true;
  }

  // a - b == NonZero -> a != b
  // ptrtoint(a) - ptrtoint(b) == NonZero -> a != b
  const APInt *L1C;
  Value *A, *B;
  if (LPred == ICmpInst::ICMP_EQ && ICmpInst::isEquality(P: RPred) &&
      match(V: L1, P: m_APInt(Res&: L1C)) && !L1C->isZero() &&
      match(V: L0, P: m_Sub(L: m_Value(V&: A), R: m_Value(V&: B))) &&
      ((A == R0 && B == R1) || (A == R1 && B == R0) ||
       (match(V: A, P: m_PtrToIntOrAddr(Op: m_Specific(V: R0))) &&
        match(V: B, P: m_PtrToIntOrAddr(Op: m_Specific(V: R1)))) ||
       (match(V: A, P: m_PtrToIntOrAddr(Op: m_Specific(V: R1))) &&
        match(V: B, P: m_PtrToIntOrAddr(Op: m_Specific(V: R0)))))) {
    // The RHS is true exactly when it asks "a != b" (signless view).
    return RPred.dropSameSign() == ICmpInst::ICMP_NE;
  }

  // L0 = R0 = L1 + R1, L0 >=u L1 implies R0 >=u R1, L0 <u L1 implies R0 <u R1
  if (L0 == R0 &&
      (LPred == ICmpInst::ICMP_ULT || LPred == ICmpInst::ICMP_UGE) &&
      (RPred == ICmpInst::ICMP_ULT || RPred == ICmpInst::ICMP_UGE) &&
      match(V: L0, P: m_c_Add(L: m_Specific(V: L1), R: m_Specific(V: R1))))
    return CmpPredicate::getMatching(A: LPred, B: RPred).has_value();

  // Last resort: same predicate on both sides, compare operands pairwise.
  if (auto P = CmpPredicate::getMatching(A: LPred, B: RPred))
    return isImpliedCondOperands(Pred: *P, ALHS: L0, ARHS: L1, BLHS: R0, BRHS: R1);

  return std::nullopt;
}
9544
9545/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
9546/// is true. Return false if LHS implies RHS is false. Otherwise, return
9547/// std::nullopt if we can't infer anything.
9548static std::optional<bool>
9549isImpliedCondFCmps(FCmpInst::Predicate LPred, const Value *L0, const Value *L1,
9550 FCmpInst::Predicate RPred, const Value *R0, const Value *R1,
9551 const DataLayout &DL, bool LHSIsTrue) {
9552 // The rest of the logic assumes the LHS condition is true. If that's not the
9553 // case, invert the predicate to make it so.
9554 if (!LHSIsTrue)
9555 LPred = FCmpInst::getInversePredicate(pred: LPred);
9556
9557 // We can have non-canonical operands, so try to normalize any common operand
9558 // to L0/R0.
9559 if (L0 == R1) {
9560 std::swap(a&: R0, b&: R1);
9561 RPred = FCmpInst::getSwappedPredicate(pred: RPred);
9562 }
9563 if (R0 == L1) {
9564 std::swap(a&: L0, b&: L1);
9565 LPred = FCmpInst::getSwappedPredicate(pred: LPred);
9566 }
9567 if (L1 == R1) {
9568 // If we have L0 == R0 and L1 == R1, then make L1/R1 the constants.
9569 if (L0 != R0 || match(V: L0, P: m_ImmConstant())) {
9570 std::swap(a&: L0, b&: L1);
9571 LPred = ICmpInst::getSwappedCmpPredicate(Pred: LPred);
9572 std::swap(a&: R0, b&: R1);
9573 RPred = ICmpInst::getSwappedCmpPredicate(Pred: RPred);
9574 }
9575 }
9576
9577 // Can we infer anything when the two compares have matching operands?
9578 if (L0 == R0 && L1 == R1) {
9579 if ((LPred & RPred) == LPred)
9580 return true;
9581 if ((LPred & ~RPred) == LPred)
9582 return false;
9583 }
9584
9585 // See if we can infer anything if operand-0 matches and we have at least one
9586 // constant.
9587 const APFloat *L1C, *R1C;
9588 if (L0 == R0 && match(V: L1, P: m_APFloat(Res&: L1C)) && match(V: R1, P: m_APFloat(Res&: R1C))) {
9589 if (std::optional<ConstantFPRange> DomCR =
9590 ConstantFPRange::makeExactFCmpRegion(Pred: LPred, Other: *L1C)) {
9591 if (std::optional<ConstantFPRange> ImpliedCR =
9592 ConstantFPRange::makeExactFCmpRegion(Pred: RPred, Other: *R1C)) {
9593 if (ImpliedCR->contains(CR: *DomCR))
9594 return true;
9595 }
9596 if (std::optional<ConstantFPRange> ImpliedCR =
9597 ConstantFPRange::makeExactFCmpRegion(
9598 Pred: FCmpInst::getInversePredicate(pred: RPred), Other: *R1C)) {
9599 if (ImpliedCR->contains(CR: *DomCR))
9600 return false;
9601 }
9602 }
9603 }
9604
9605 return std::nullopt;
9606}
9607
9608/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is
9609/// false. Otherwise, return std::nullopt if we can't infer anything. We
9610/// expect the RHS to be an icmp and the LHS to be an 'and', 'or', or a 'select'
9611/// instruction.
/// Return true if LHS implies RHS (expanded to its components as "R0 RPred R1")
/// is true. Return false if LHS implies RHS is false. Otherwise, return
/// std::nullopt if we can't infer anything. We expect the RHS to be an icmp
/// and the LHS to be an 'and', 'or', or a 'select' instruction.
static std::optional<bool>
isImpliedCondAndOr(const Instruction *LHS, CmpPredicate RHSPred,
                   const Value *RHSOp0, const Value *RHSOp1,
                   const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
  // The LHS must be an 'or', 'and', or a 'select' instruction.
  assert((LHS->getOpcode() == Instruction::And ||
          LHS->getOpcode() == Instruction::Or ||
          LHS->getOpcode() == Instruction::Select) &&
         "Expected LHS to be 'and', 'or', or 'select'.");

  assert(Depth <= MaxAnalysisRecursionDepth && "Hit recursion limit");

  // If the result of an 'or' is false, then we know both legs of the 'or' are
  // false. Similarly, if the result of an 'and' is true, then we know both
  // legs of the 'and' are true. Either leg alone may then carry the
  // implication, so each is tried independently.
  const Value *ALHS, *ARHS;
  if ((!LHSIsTrue && match(V: LHS, P: m_LogicalOr(L: m_Value(V&: ALHS), R: m_Value(V&: ARHS)))) ||
      (LHSIsTrue && match(V: LHS, P: m_LogicalAnd(L: m_Value(V&: ALHS), R: m_Value(V&: ARHS))))) {
    // FIXME: Make this non-recursion.
    if (std::optional<bool> Implication = isImpliedCondition(
            LHS: ALHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth: Depth + 1))
      return Implication;
    if (std::optional<bool> Implication = isImpliedCondition(
            LHS: ARHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth: Depth + 1))
      return Implication;
    return std::nullopt;
  }
  return std::nullopt;
}
9641
// Core dispatcher: does condition LHS (being LHSIsTrue) imply the compare
// "RHSOp0 RHSPred RHSOp1"? Routes to the icmp, fcmp, or and/or/select
// handlers based on the shape of LHS and the operand types.
std::optional<bool>
llvm::isImpliedCondition(const Value *LHS, CmpPredicate RHSPred,
                         const Value *RHSOp0, const Value *RHSOp1,
                         const DataLayout &DL, bool LHSIsTrue, unsigned Depth) {
  // Bail out when we hit the limit.
  if (Depth == MaxAnalysisRecursionDepth)
    return std::nullopt;

  // A mismatch occurs when we compare a scalar cmp to a vector cmp, for
  // example.
  if (RHSOp0->getType()->isVectorTy() != LHS->getType()->isVectorTy())
    return std::nullopt;

  assert(LHS->getType()->isIntOrIntVectorTy(1) &&
         "Expected integer type only!");

  // Match not: !X is true exactly when X is false.
  if (match(V: LHS, P: m_Not(V: m_Value(V&: LHS))))
    LHSIsTrue = !LHSIsTrue;

  // Both LHS and RHS are icmps.
  if (RHSOp0->getType()->getScalarType()->isIntOrPtrTy()) {
    if (const auto *LHSCmp = dyn_cast<ICmpInst>(Val: LHS))
      return isImpliedCondICmps(LPred: LHSCmp->getCmpPredicate(),
                                L0: LHSCmp->getOperand(i_nocapture: 0), L1: LHSCmp->getOperand(i_nocapture: 1),
                                RPred: RHSPred, R0: RHSOp0, R1: RHSOp1, DL, LHSIsTrue);
    // A nuw trunc to i1 is nonzero exactly when its operand is nonzero, so
    // treat it as "V != 0".
    const Value *V;
    if (match(V: LHS, P: m_NUWTrunc(Op: m_Value(V))))
      return isImpliedCondICmps(LPred: CmpInst::ICMP_NE, L0: V,
                                L1: ConstantInt::get(Ty: V->getType(), V: 0), RPred: RHSPred,
                                R0: RHSOp0, R1: RHSOp1, DL, LHSIsTrue);
  } else {
    assert(RHSOp0->getType()->isFPOrFPVectorTy() &&
           "Expected floating point type only!");
    if (const auto *LHSCmp = dyn_cast<FCmpInst>(Val: LHS))
      return isImpliedCondFCmps(LPred: LHSCmp->getPredicate(), L0: LHSCmp->getOperand(i_nocapture: 0),
                                L1: LHSCmp->getOperand(i_nocapture: 1), RPred: RHSPred, R0: RHSOp0, R1: RHSOp1,
                                DL, LHSIsTrue);
  }

  /// The LHS should be an 'or', 'and', or a 'select' instruction. We expect
  /// the RHS to be an icmp.
  /// FIXME: Add support for and/or/select on the RHS.
  if (const Instruction *LHSI = dyn_cast<Instruction>(Val: LHS)) {
    if ((LHSI->getOpcode() == Instruction::And ||
         LHSI->getOpcode() == Instruction::Or ||
         LHSI->getOpcode() == Instruction::Select))
      return isImpliedCondAndOr(LHS: LHSI, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue,
                                Depth);
  }
  return std::nullopt;
}
9694
// Value-to-value form: does LHS (being LHSIsTrue) imply RHS is true/false?
// Handles identity, negation, cmp RHS (by expanding it), nuw-trunc RHS, and
// logical and/or RHS by recursing into its legs.
std::optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS,
                                             const DataLayout &DL,
                                             bool LHSIsTrue, unsigned Depth) {
  // LHS ==> RHS by definition
  if (LHS == RHS)
    return LHSIsTrue;

  // Match not: LHS ==> !LHS is the negation of LHSIsTrue; otherwise strip
  // the not and remember to invert the final answer.
  bool InvertRHS = false;
  if (match(V: RHS, P: m_Not(V: m_Value(V&: RHS)))) {
    if (LHS == RHS)
      return !LHSIsTrue;
    InvertRHS = true;
  }

  // RHS is an integer/pointer compare: expand it and dispatch.
  if (const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(Val: RHS)) {
    if (auto Implied = isImpliedCondition(
            LHS, RHSPred: RHSCmp->getCmpPredicate(), RHSOp0: RHSCmp->getOperand(i_nocapture: 0),
            RHSOp1: RHSCmp->getOperand(i_nocapture: 1), DL, LHSIsTrue, Depth))
      return InvertRHS ? !*Implied : *Implied;
    return std::nullopt;
  }
  // RHS is a floating-point compare: likewise.
  if (const FCmpInst *RHSCmp = dyn_cast<FCmpInst>(Val: RHS)) {
    if (auto Implied = isImpliedCondition(
            LHS, RHSPred: RHSCmp->getPredicate(), RHSOp0: RHSCmp->getOperand(i_nocapture: 0),
            RHSOp1: RHSCmp->getOperand(i_nocapture: 1), DL, LHSIsTrue, Depth))
      return InvertRHS ? !*Implied : *Implied;
    return std::nullopt;
  }

  // A nuw trunc to i1 is equivalent to "V != 0".
  const Value *V;
  if (match(V: RHS, P: m_NUWTrunc(Op: m_Value(V)))) {
    if (auto Implied = isImpliedCondition(LHS, RHSPred: CmpInst::ICMP_NE, RHSOp0: V,
                                          RHSOp1: ConstantInt::get(Ty: V->getType(), V: 0), DL,
                                          LHSIsTrue, Depth))
      return InvertRHS ? !*Implied : *Implied;
    return std::nullopt;
  }

  if (Depth == MaxAnalysisRecursionDepth)
    return std::nullopt;

  // LHS ==> (RHS1 || RHS2) if LHS ==> RHS1 or LHS ==> RHS2
  // LHS ==> !(RHS1 && RHS2) if LHS ==> !RHS1 or LHS ==> !RHS2
  const Value *RHS1, *RHS2;
  if (match(V: RHS, P: m_LogicalOr(L: m_Value(V&: RHS1), R: m_Value(V&: RHS2)))) {
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS1, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == true)
        return !InvertRHS;
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS2, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == true)
        return !InvertRHS;
  }
  if (match(V: RHS, P: m_LogicalAnd(L: m_Value(V&: RHS1), R: m_Value(V&: RHS2)))) {
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS1, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == false)
        return InvertRHS;
    if (std::optional<bool> Imp =
            isImpliedCondition(LHS, RHS: RHS2, DL, LHSIsTrue, Depth: Depth + 1))
      if (*Imp == false)
        return InvertRHS;
  }

  return std::nullopt;
}
9763
9764// Returns a pair (Condition, ConditionIsTrue), where Condition is a branch
9765// condition dominating ContextI or nullptr, if no condition is found.
// Returns a pair (Condition, ConditionIsTrue), where Condition is a branch
// condition dominating ContextI or nullptr, if no condition is found.
static std::pair<Value *, bool>
getDomPredecessorCondition(const Instruction *ContextI) {
  if (!ContextI || !ContextI->getParent())
    return {nullptr, false};

  // TODO: This is a poor/cheap way to determine dominance. Should we use a
  // dominator tree (eg, from a SimplifyQuery) instead?
  const BasicBlock *ContextBB = ContextI->getParent();
  const BasicBlock *PredBB = ContextBB->getSinglePredecessor();
  if (!PredBB)
    return {nullptr, false};

  // We need a conditional branch in the predecessor.
  Value *PredCond;
  BasicBlock *TrueBB, *FalseBB;
  if (!match(V: PredBB->getTerminator(), P: m_Br(C: m_Value(V&: PredCond), T&: TrueBB, F&: FalseBB)))
    return {nullptr, false};

  // The branch should get simplified. Don't bother simplifying this condition.
  if (TrueBB == FalseBB)
    return {nullptr, false};

  assert((TrueBB == ContextBB || FalseBB == ContextBB) &&
         "Predecessor block does not point to successor?");

  // Is this condition implied by the predecessor condition? The second
  // element records which edge we arrived on, i.e. whether PredCond is known
  // true (TrueBB edge) or false (FalseBB edge) in ContextBB.
  return {PredCond, TrueBB == ContextBB};
}
9794
// Does the branch condition dominating ContextI imply Cond (true/false)?
std::optional<bool> llvm::isImpliedByDomCondition(const Value *Cond,
                                                  const Instruction *ContextI,
                                                  const DataLayout &DL) {
  assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool");
  auto PredCond = getDomPredecessorCondition(ContextI);
  if (PredCond.first)
    return isImpliedCondition(LHS: PredCond.first, RHS: Cond, DL, LHSIsTrue: PredCond.second);
  return std::nullopt;
}
9804
// Overload for an expanded compare "LHS Pred RHS" as the implied condition.
std::optional<bool> llvm::isImpliedByDomCondition(CmpPredicate Pred,
                                                  const Value *LHS,
                                                  const Value *RHS,
                                                  const Instruction *ContextI,
                                                  const DataLayout &DL) {
  auto PredCond = getDomPredecessorCondition(ContextI);
  if (PredCond.first)
    return isImpliedCondition(LHS: PredCond.first, RHSPred: Pred, RHSOp0: LHS, RHSOp1: RHS, DL,
                              LHSIsTrue: PredCond.second);
  return std::nullopt;
}
9816
9817static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower,
9818 APInt &Upper, const InstrInfoQuery &IIQ,
9819 bool PreferSignedRange) {
9820 unsigned Width = Lower.getBitWidth();
9821 const APInt *C;
9822 switch (BO.getOpcode()) {
9823 case Instruction::Sub:
9824 if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9825 bool HasNSW = IIQ.hasNoSignedWrap(Op: &BO);
9826 bool HasNUW = IIQ.hasNoUnsignedWrap(Op: &BO);
9827
9828 // If the caller expects a signed compare, then try to use a signed range.
9829 // Otherwise if both no-wraps are set, use the unsigned range because it
9830 // is never larger than the signed range. Example:
9831 // "sub nuw nsw i8 -2, x" is unsigned [0, 254] vs. signed [-128, 126].
9832 // "sub nuw nsw i8 2, x" is unsigned [0, 2] vs. signed [-125, 127].
9833 if (PreferSignedRange && HasNSW && HasNUW)
9834 HasNUW = false;
9835
9836 if (HasNUW) {
9837 // 'sub nuw c, x' produces [0, C].
9838 Upper = *C + 1;
9839 } else if (HasNSW) {
9840 if (C->isNegative()) {
9841 // 'sub nsw -C, x' produces [SINT_MIN, -C - SINT_MIN].
9842 Lower = APInt::getSignedMinValue(numBits: Width);
9843 Upper = *C - APInt::getSignedMaxValue(numBits: Width);
9844 } else {
9845 // Note that sub 0, INT_MIN is not NSW. It techically is a signed wrap
9846 // 'sub nsw C, x' produces [C - SINT_MAX, SINT_MAX].
9847 Lower = *C - APInt::getSignedMaxValue(numBits: Width);
9848 Upper = APInt::getSignedMinValue(numBits: Width);
9849 }
9850 }
9851 }
9852 break;
9853 case Instruction::Add:
9854 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && !C->isZero()) {
9855 bool HasNSW = IIQ.hasNoSignedWrap(Op: &BO);
9856 bool HasNUW = IIQ.hasNoUnsignedWrap(Op: &BO);
9857
9858 // If the caller expects a signed compare, then try to use a signed
9859 // range. Otherwise if both no-wraps are set, use the unsigned range
9860 // because it is never larger than the signed range. Example: "add nuw
9861 // nsw i8 X, -2" is unsigned [254,255] vs. signed [-128, 125].
9862 if (PreferSignedRange && HasNSW && HasNUW)
9863 HasNUW = false;
9864
9865 if (HasNUW) {
9866 // 'add nuw x, C' produces [C, UINT_MAX].
9867 Lower = *C;
9868 } else if (HasNSW) {
9869 if (C->isNegative()) {
9870 // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C].
9871 Lower = APInt::getSignedMinValue(numBits: Width);
9872 Upper = APInt::getSignedMaxValue(numBits: Width) + *C + 1;
9873 } else {
9874 // 'add nsw x, +C' produces [SINT_MIN + C, SINT_MAX].
9875 Lower = APInt::getSignedMinValue(numBits: Width) + *C;
9876 Upper = APInt::getSignedMaxValue(numBits: Width) + 1;
9877 }
9878 }
9879 }
9880 break;
9881
9882 case Instruction::And:
9883 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
9884 // 'and x, C' produces [0, C].
9885 Upper = *C + 1;
9886 // X & -X is a power of two or zero. So we can cap the value at max power of
9887 // two.
9888 if (match(V: BO.getOperand(i_nocapture: 0), P: m_Neg(V: m_Specific(V: BO.getOperand(i_nocapture: 1)))) ||
9889 match(V: BO.getOperand(i_nocapture: 1), P: m_Neg(V: m_Specific(V: BO.getOperand(i_nocapture: 0)))))
9890 Upper = APInt::getSignedMinValue(numBits: Width) + 1;
9891 break;
9892
9893 case Instruction::Or:
9894 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
9895 // 'or x, C' produces [C, UINT_MAX].
9896 Lower = *C;
9897 break;
9898
9899 case Instruction::AShr:
9900 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
9901 // 'ashr x, C' produces [INT_MIN >> C, INT_MAX >> C].
9902 Lower = APInt::getSignedMinValue(numBits: Width).ashr(ShiftAmt: *C);
9903 Upper = APInt::getSignedMaxValue(numBits: Width).ashr(ShiftAmt: *C) + 1;
9904 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9905 unsigned ShiftAmount = Width - 1;
9906 if (!C->isZero() && IIQ.isExact(Op: &BO))
9907 ShiftAmount = C->countr_zero();
9908 if (C->isNegative()) {
9909 // 'ashr C, x' produces [C, C >> (Width-1)]
9910 Lower = *C;
9911 Upper = C->ashr(ShiftAmt: ShiftAmount) + 1;
9912 } else {
9913 // 'ashr C, x' produces [C >> (Width-1), C]
9914 Lower = C->ashr(ShiftAmt: ShiftAmount);
9915 Upper = *C + 1;
9916 }
9917 }
9918 break;
9919
9920 case Instruction::LShr:
9921 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
9922 // 'lshr x, C' produces [0, UINT_MAX >> C].
9923 Upper = APInt::getAllOnes(numBits: Width).lshr(ShiftAmt: *C) + 1;
9924 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9925 // 'lshr C, x' produces [C >> (Width-1), C].
9926 unsigned ShiftAmount = Width - 1;
9927 if (!C->isZero() && IIQ.isExact(Op: &BO))
9928 ShiftAmount = C->countr_zero();
9929 Lower = C->lshr(shiftAmt: ShiftAmount);
9930 Upper = *C + 1;
9931 }
9932 break;
9933
9934 case Instruction::Shl:
9935 if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9936 if (IIQ.hasNoUnsignedWrap(Op: &BO)) {
9937 // 'shl nuw C, x' produces [C, C << CLZ(C)]
9938 Lower = *C;
9939 Upper = Lower.shl(shiftAmt: Lower.countl_zero()) + 1;
9940 } else if (BO.hasNoSignedWrap()) { // TODO: What if both nuw+nsw?
9941 if (C->isNegative()) {
9942 // 'shl nsw C, x' produces [C << CLO(C)-1, C]
9943 unsigned ShiftAmount = C->countl_one() - 1;
9944 Lower = C->shl(shiftAmt: ShiftAmount);
9945 Upper = *C + 1;
9946 } else {
9947 // 'shl nsw C, x' produces [C, C << CLZ(C)-1]
9948 unsigned ShiftAmount = C->countl_zero() - 1;
9949 Lower = *C;
9950 Upper = C->shl(shiftAmt: ShiftAmount) + 1;
9951 }
9952 } else {
9953 // If lowbit is set, value can never be zero.
9954 if ((*C)[0])
9955 Lower = APInt::getOneBitSet(numBits: Width, BitNo: 0);
9956 // If we are shifting a constant the largest it can be is if the longest
9957 // sequence of consecutive ones is shifted to the highbits (breaking
9958 // ties for which sequence is higher). At the moment we take a liberal
9959 // upper bound on this by just popcounting the constant.
9960 // TODO: There may be a bitwise trick for it longest/highest
9961 // consecutative sequence of ones (naive method is O(Width) loop).
9962 Upper = APInt::getHighBitsSet(numBits: Width, hiBitsSet: C->popcount()) + 1;
9963 }
9964 } else if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && C->ult(RHS: Width)) {
9965 Upper = APInt::getBitsSetFrom(numBits: Width, loBit: C->getZExtValue()) + 1;
9966 }
9967 break;
9968
9969 case Instruction::SDiv:
9970 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
9971 APInt IntMin = APInt::getSignedMinValue(numBits: Width);
9972 APInt IntMax = APInt::getSignedMaxValue(numBits: Width);
9973 if (C->isAllOnes()) {
9974 // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX]
9975 // where C != -1 and C != 0 and C != 1
9976 Lower = IntMin + 1;
9977 Upper = IntMax + 1;
9978 } else if (C->countl_zero() < Width - 1) {
9979 // 'sdiv x, C' produces [INT_MIN / C, INT_MAX / C]
9980 // where C != -1 and C != 0 and C != 1
9981 Lower = IntMin.sdiv(RHS: *C);
9982 Upper = IntMax.sdiv(RHS: *C);
9983 if (Lower.sgt(RHS: Upper))
9984 std::swap(a&: Lower, b&: Upper);
9985 Upper = Upper + 1;
9986 assert(Upper != Lower && "Upper part of range has wrapped!");
9987 }
9988 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
9989 if (C->isMinSignedValue()) {
9990 // 'sdiv INT_MIN, x' produces [INT_MIN, INT_MIN / -2].
9991 Lower = *C;
9992 Upper = Lower.lshr(shiftAmt: 1) + 1;
9993 } else {
9994 // 'sdiv C, x' produces [-|C|, |C|].
9995 Upper = C->abs() + 1;
9996 Lower = (-Upper) + 1;
9997 }
9998 }
9999 break;
10000
10001 case Instruction::UDiv:
10002 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)) && !C->isZero()) {
10003 // 'udiv x, C' produces [0, UINT_MAX / C].
10004 Upper = APInt::getMaxValue(numBits: Width).udiv(RHS: *C) + 1;
10005 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
10006 // 'udiv C, x' produces [0, C].
10007 Upper = *C + 1;
10008 }
10009 break;
10010
10011 case Instruction::SRem:
10012 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
10013 // 'srem x, C' produces (-|C|, |C|).
10014 Upper = C->abs();
10015 Lower = (-Upper) + 1;
10016 } else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
10017 if (C->isNegative()) {
10018 // 'srem -|C|, x' produces [-|C|, 0].
10019 Upper = 1;
10020 Lower = *C;
10021 } else {
10022 // 'srem |C|, x' produces [0, |C|].
10023 Upper = *C + 1;
10024 }
10025 }
10026 break;
10027
10028 case Instruction::URem:
10029 if (match(V: BO.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
10030 // 'urem x, C' produces [0, C).
10031 Upper = *C;
10032 else if (match(V: BO.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)))
10033 // 'urem C, x' produces [0, C].
10034 Upper = *C + 1;
10035 break;
10036
10037 default:
10038 break;
10039 }
10040}
10041
/// Compute a conservative range for the result of intrinsic \p II using only
/// its intrinsic ID and any constant arguments. Returns the full range when
/// nothing better can be deduced.
static ConstantRange getRangeForIntrinsic(const IntrinsicInst &II,
                                          bool UseInstrInfo) {
  unsigned Width = II.getType()->getScalarSizeInBits();
  const APInt *C;
  switch (II.getIntrinsicID()) {
  case Intrinsic::ctlz:
  case Intrinsic::cttz: {
    APInt Upper(Width, Width);
    // If the is-zero-poison flag is set (and we may use it), a zero input is
    // excluded, so the result is at most Width - 1; otherwise the result may
    // be exactly Width, hence the extra +1 on the exclusive upper bound.
    if (!UseInstrInfo || !match(V: II.getArgOperand(i: 1), P: m_One()))
      Upper += 1;
    // Maximum of set/clear bits is the bit width.
    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper);
  }
  case Intrinsic::ctpop:
    // Maximum of set/clear bits is the bit width.
    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                      Upper: APInt(Width, Width) + 1);
  case Intrinsic::uadd_sat:
    // uadd.sat(x, C) produces [C, UINT_MAX].
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) ||
        match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: Width));
    break;
  case Intrinsic::sadd_sat:
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) ||
        match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
      if (C->isNegative())
        // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)].
        return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                          Upper: APInt::getSignedMaxValue(numBits: Width) + *C +
                                              1);

      // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX].
      return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width) + *C,
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
    }
    break;
  case Intrinsic::usub_sat:
    // usub.sat(C, x) produces [0, C].
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper: *C + 1);

    // usub.sat(x, C) produces [0, UINT_MAX - C].
    if (match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                        Upper: APInt::getMaxValue(numBits: Width) - *C + 1);
    break;
  case Intrinsic::ssub_sat:
    if (match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C))) {
      if (C->isNegative())
        // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)].
        return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                          Upper: *C - APInt::getSignedMinValue(numBits: Width) +
                                              1);

      // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX].
      return ConstantRange::getNonEmpty(Lower: *C - APInt::getSignedMaxValue(numBits: Width),
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
    } else if (match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C))) {
      if (C->isNegative())
        // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]:
        return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width) - *C,
                                          Upper: APInt::getSignedMaxValue(numBits: Width) + 1);

      // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C].
      return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                        Upper: APInt::getSignedMaxValue(numBits: Width) - *C +
                                            1);
    }
    break;
  case Intrinsic::umin:
  case Intrinsic::umax:
  case Intrinsic::smin:
  case Intrinsic::smax:
    // A min/max only bounds the result if one operand is a constant.
    if (!match(V: II.getOperand(i_nocapture: 0), P: m_APInt(Res&: C)) &&
        !match(V: II.getOperand(i_nocapture: 1), P: m_APInt(Res&: C)))
      break;

    switch (II.getIntrinsicID()) {
    case Intrinsic::umin:
      // umin(x, C) produces [0, C].
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width), Upper: *C + 1);
    case Intrinsic::umax:
      // umax(x, C) produces [C, UINT_MAX].
      return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: Width));
    case Intrinsic::smin:
      // smin(x, C) produces [SINT_MIN, C].
      return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: Width),
                                        Upper: *C + 1);
    case Intrinsic::smax:
      // smax(x, C) produces [C, SINT_MAX].
      return ConstantRange::getNonEmpty(Lower: *C,
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);
    default:
      llvm_unreachable("Must be min/max intrinsic");
    }
    break;
  case Intrinsic::abs:
    // If abs of SIGNED_MIN is poison, then the result is [0..SIGNED_MAX],
    // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
    if (match(V: II.getOperand(i_nocapture: 1), P: m_One()))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                        Upper: APInt::getSignedMaxValue(numBits: Width) + 1);

    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: Width),
                                      Upper: APInt::getSignedMinValue(numBits: Width) + 1);
  case Intrinsic::vscale:
    // The vscale range is read off the enclosing function, so bail out if the
    // call has not been inserted into a function yet.
    if (!II.getParent() || !II.getFunction())
      break;
    return getVScaleRange(F: II.getFunction(), BitWidth: Width);
  default:
    break;
  }

  return ConstantRange::getFull(BitWidth: Width);
}
10154
/// Compute a conservative range for a select that matches a known
/// min/max/abs/nabs pattern (as recognized by matchSelectPattern). Returns the
/// full range when the select does not match any such pattern.
static ConstantRange getRangeForSelectPattern(const SelectInst &SI,
                                              const InstrInfoQuery &IIQ) {
  unsigned BitWidth = SI.getType()->getScalarSizeInBits();
  const Value *LHS = nullptr, *RHS = nullptr;
  SelectPatternResult R = matchSelectPattern(V: &SI, LHS, RHS);
  if (R.Flavor == SPF_UNKNOWN)
    return ConstantRange::getFull(BitWidth);

  if (R.Flavor == SelectPatternFlavor::SPF_ABS) {
    // If the negation part of the abs (in RHS) has the NSW flag,
    // then the result of abs(X) is [0..SIGNED_MAX],
    // otherwise it is [0..SIGNED_MIN], as -SIGNED_MIN == SIGNED_MIN.
    if (match(V: RHS, P: m_Neg(V: m_Specific(V: LHS))) &&
        IIQ.hasNoSignedWrap(Op: cast<Instruction>(Val: RHS)))
      return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth),
                                        Upper: APInt::getSignedMaxValue(numBits: BitWidth) + 1);

    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth),
                                      Upper: APInt::getSignedMinValue(numBits: BitWidth) + 1);
  }

  if (R.Flavor == SelectPatternFlavor::SPF_NABS) {
    // The result of -abs(X) is <= 0.
    return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: BitWidth),
                                      Upper: APInt(BitWidth, 1));
  }

  // The min/max flavors only bound the result when one side is a constant.
  const APInt *C;
  if (!match(V: LHS, P: m_APInt(Res&: C)) && !match(V: RHS, P: m_APInt(Res&: C)))
    return ConstantRange::getFull(BitWidth);

  switch (R.Flavor) {
  case SPF_UMIN:
    // umin(x, C) produces [0, C].
    return ConstantRange::getNonEmpty(Lower: APInt::getZero(numBits: BitWidth), Upper: *C + 1);
  case SPF_UMAX:
    // umax(x, C) produces [C, UINT_MAX].
    return ConstantRange::getNonEmpty(Lower: *C, Upper: APInt::getZero(numBits: BitWidth));
  case SPF_SMIN:
    // smin(x, C) produces [SINT_MIN, C].
    return ConstantRange::getNonEmpty(Lower: APInt::getSignedMinValue(numBits: BitWidth),
                                      Upper: *C + 1);
  case SPF_SMAX:
    // smax(x, C) produces [C, SINT_MAX].
    return ConstantRange::getNonEmpty(Lower: *C,
                                      Upper: APInt::getSignedMaxValue(numBits: BitWidth) + 1);
  default:
    return ConstantRange::getFull(BitWidth);
  }
}
10201
10202static void setLimitForFPToI(const Instruction *I, APInt &Lower, APInt &Upper) {
10203 // The maximum representable value of a half is 65504. For floats the maximum
10204 // value is 3.4e38 which requires roughly 129 bits.
10205 unsigned BitWidth = I->getType()->getScalarSizeInBits();
10206 if (!I->getOperand(i: 0)->getType()->getScalarType()->isHalfTy())
10207 return;
10208 if (isa<FPToSIInst>(Val: I) && BitWidth >= 17) {
10209 Lower = APInt(BitWidth, -65504, true);
10210 Upper = APInt(BitWidth, 65505);
10211 }
10212
10213 if (isa<FPToUIInst>(Val: I) && BitWidth >= 16) {
10214 // For a fptoui the lower limit is left as 0.
10215 Upper = APInt(BitWidth, 65505);
10216 }
10217}
10218
ConstantRange llvm::computeConstantRange(const Value *V, bool ForSigned,
                                         bool UseInstrInfo, AssumptionCache *AC,
                                         const Instruction *CtxI,
                                         const DominatorTree *DT,
                                         unsigned Depth) {
  assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction");

  // Limit recursion (selects and assume operands recurse below).
  if (Depth == MaxAnalysisRecursionDepth)
    return ConstantRange::getFull(BitWidth: V->getType()->getScalarSizeInBits());

  if (auto *C = dyn_cast<Constant>(Val: V))
    return C->toConstantRange();

  unsigned BitWidth = V->getType()->getScalarSizeInBits();
  InstrInfoQuery IIQ(UseInstrInfo);
  ConstantRange CR = ConstantRange::getFull(BitWidth);
  if (auto *BO = dyn_cast<BinaryOperator>(Val: V)) {
    APInt Lower = APInt(BitWidth, 0);
    APInt Upper = APInt(BitWidth, 0);
    // TODO: Return ConstantRange.
    setLimitsForBinOp(BO: *BO, Lower, Upper, IIQ, PreferSignedRange: ForSigned);
    // Lower == Upper means "no bound computed" and yields the full set here.
    CR = ConstantRange::getNonEmpty(Lower, Upper);
  } else if (auto *II = dyn_cast<IntrinsicInst>(Val: V))
    CR = getRangeForIntrinsic(II: *II, UseInstrInfo);
  else if (auto *SI = dyn_cast<SelectInst>(Val: V)) {
    // A select's range is the union of its arms, further refined by any
    // min/max/abs pattern the select as a whole implements.
    ConstantRange CRTrue = computeConstantRange(
        V: SI->getTrueValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth: Depth + 1);
    ConstantRange CRFalse = computeConstantRange(
        V: SI->getFalseValue(), ForSigned, UseInstrInfo, AC, CtxI, DT, Depth: Depth + 1);
    CR = CRTrue.unionWith(CR: CRFalse);
    CR = CR.intersectWith(CR: getRangeForSelectPattern(SI: *SI, IIQ));
  } else if (isa<FPToUIInst>(Val: V) || isa<FPToSIInst>(Val: V)) {
    APInt Lower = APInt(BitWidth, 0);
    APInt Upper = APInt(BitWidth, 0);
    // TODO: Return ConstantRange.
    setLimitForFPToI(I: cast<Instruction>(Val: V), Lower, Upper);
    CR = ConstantRange::getNonEmpty(Lower, Upper);
  } else if (const auto *A = dyn_cast<Argument>(Val: V))
    // Use a range attribute on the argument, if present.
    if (std::optional<ConstantRange> Range = A->getRange())
      CR = *Range;

  if (auto *I = dyn_cast<Instruction>(Val: V)) {
    // Refine with !range metadata and call-site range attributes.
    if (auto *Range = IIQ.getMetadata(I, KindID: LLVMContext::MD_range))
      CR = CR.intersectWith(CR: getConstantRangeFromMetadata(RangeMD: *Range));

    if (const auto *CB = dyn_cast<CallBase>(Val: V))
      if (std::optional<ConstantRange> Range = CB->getRange())
        CR = CR.intersectWith(CR: *Range);
  }

  if (CtxI && AC) {
    // Try to restrict the range based on information from assumptions.
    for (auto &AssumeVH : AC->assumptionsFor(V)) {
      if (!AssumeVH)
        continue;
      CallInst *I = cast<CallInst>(Val&: AssumeVH);
      assert(I->getParent()->getParent() == CtxI->getParent()->getParent() &&
             "Got assumption for the wrong function!");
      assert(I->getIntrinsicID() == Intrinsic::assume &&
             "must be an assume intrinsic");

      // The assumption must dominate (or otherwise be valid at) the context.
      if (!isValidAssumeForContext(Inv: I, CxtI: CtxI, DT))
        continue;
      Value *Arg = I->getArgOperand(i: 0);
      ICmpInst *Cmp = dyn_cast<ICmpInst>(Val: Arg);
      // Currently we just use information from comparisons.
      if (!Cmp || Cmp->getOperand(i_nocapture: 0) != V)
        continue;
      // TODO: Set "ForSigned" parameter via Cmp->isSigned()?
      ConstantRange RHS =
          computeConstantRange(V: Cmp->getOperand(i_nocapture: 1), /* ForSigned */ false,
                               UseInstrInfo, AC, CtxI: I, DT, Depth: Depth + 1);
      CR = CR.intersectWith(
          CR: ConstantRange::makeAllowedICmpRegion(Pred: Cmp->getPredicate(), Other: RHS));
    }
  }

  return CR;
}
10298
10299static void
10300addValueAffectedByCondition(Value *V,
10301 function_ref<void(Value *)> InsertAffected) {
10302 assert(V != nullptr);
10303 if (isa<Argument>(Val: V) || isa<GlobalValue>(Val: V)) {
10304 InsertAffected(V);
10305 } else if (auto *I = dyn_cast<Instruction>(Val: V)) {
10306 InsertAffected(V);
10307
10308 // Peek through unary operators to find the source of the condition.
10309 Value *Op;
10310 if (match(V: I, P: m_CombineOr(L: m_PtrToIntOrAddr(Op: m_Value(V&: Op)),
10311 R: m_Trunc(Op: m_Value(V&: Op))))) {
10312 if (isa<Instruction>(Val: Op) || isa<Argument>(Val: Op))
10313 InsertAffected(Op);
10314 }
10315 }
10316}
10317
void llvm::findValuesAffectedByCondition(
    Value *Cond, bool IsAssume, function_ref<void(Value *)> InsertAffected) {
  // Report a value (and, via addValueAffectedByCondition, the source of a
  // cast feeding it) as affected by Cond.
  auto AddAffected = [&InsertAffected](Value *V) {
    addValueAffectedByCondition(V, InsertAffected);
  };

  // For assumes both comparison operands are useful; for branch conditions we
  // only learn something about the LHS when the RHS is a constant.
  auto AddCmpOperands = [&AddAffected, IsAssume](Value *LHS, Value *RHS) {
    if (IsAssume) {
      AddAffected(LHS);
      AddAffected(RHS);
    } else if (match(V: RHS, P: m_Constant()))
      AddAffected(LHS);
  };

  // Worklist of condition sub-expressions still to decompose.
  SmallVector<Value *, 8> Worklist;
  SmallPtrSet<Value *, 8> Visited;
  Worklist.push_back(Elt: Cond);
  while (!Worklist.empty()) {
    Value *V = Worklist.pop_back_val();
    if (!Visited.insert(Ptr: V).second)
      continue;

    CmpPredicate Pred;
    Value *A, *B, *X;

    if (IsAssume) {
      AddAffected(V);
      if (match(V, P: m_Not(V: m_Value(V&: X))))
        AddAffected(X);
    }

    if (match(V, P: m_LogicalOp(L: m_Value(V&: A), R: m_Value(V&: B)))) {
      // assume(A && B) is split to -> assume(A); assume(B);
      // assume(!(A || B)) is split to -> assume(!A); assume(!B);
      // Finally, assume(A || B) / assume(!(A && B)) generally don't provide
      // enough information to be worth handling (intersection of information as
      // opposed to union).
      if (!IsAssume) {
        Worklist.push_back(Elt: A);
        Worklist.push_back(Elt: B);
      }
    } else if (match(V, P: m_ICmp(Pred, L: m_Value(V&: A), R: m_Value(V&: B)))) {
      bool HasRHSC = match(V: B, P: m_ConstantInt());
      if (ICmpInst::isEquality(P: Pred)) {
        AddAffected(A);
        if (IsAssume)
          AddAffected(B);
        if (HasRHSC) {
          Value *Y;
          // (X << C) or (X >>_s C) or (X >>_u C).
          if (match(V: A, P: m_Shift(L: m_Value(V&: X), R: m_ConstantInt())))
            AddAffected(X);
          // (X & C) or (X | C).
          else if (match(V: A, P: m_And(L: m_Value(V&: X), R: m_Value(V&: Y))) ||
                   match(V: A, P: m_Or(L: m_Value(V&: X), R: m_Value(V&: Y)))) {
            AddAffected(X);
            AddAffected(Y);
          }
          // X - Y
          else if (match(V: A, P: m_Sub(L: m_Value(V&: X), R: m_Value(V&: Y)))) {
            AddAffected(X);
            AddAffected(Y);
          }
        }
      } else {
        AddCmpOperands(A, B);
        if (HasRHSC) {
          // Handle (A + C1) u< C2, which is the canonical form of
          // A > C3 && A < C4.
          if (match(V: A, P: m_AddLike(L: m_Value(V&: X), R: m_ConstantInt())))
            AddAffected(X);

          if (ICmpInst::isUnsigned(predicate: Pred)) {
            Value *Y;
            // X & Y u> C -> X >u C && Y >u C
            // X | Y u< C -> X u< C && Y u< C
            // X nuw+ Y u< C -> X u< C && Y u< C
            if (match(V: A, P: m_And(L: m_Value(V&: X), R: m_Value(V&: Y))) ||
                match(V: A, P: m_Or(L: m_Value(V&: X), R: m_Value(V&: Y))) ||
                match(V: A, P: m_NUWAdd(L: m_Value(V&: X), R: m_Value(V&: Y)))) {
              AddAffected(X);
              AddAffected(Y);
            }
            // X nuw- Y u> C -> X u> C
            if (match(V: A, P: m_NUWSub(L: m_Value(V&: X), R: m_Value())))
              AddAffected(X);
          }
        }

        // Handle icmp slt/sgt (bitcast X to int), 0/-1, which is supported
        // by computeKnownFPClass().
        if (match(V: A, P: m_ElementWiseBitCast(Op: m_Value(V&: X)))) {
          if (Pred == ICmpInst::ICMP_SLT && match(V: B, P: m_Zero()))
            InsertAffected(X);
          else if (Pred == ICmpInst::ICMP_SGT && match(V: B, P: m_AllOnes()))
            InsertAffected(X);
        }
      }

      // ctpop(X) compared against a constant bounds X's known bits.
      if (HasRHSC && match(V: A, P: m_Intrinsic<Intrinsic::ctpop>(Op0: m_Value(V&: X))))
        AddAffected(X);
    } else if (match(V, P: m_FCmp(Pred, L: m_Value(V&: A), R: m_Value(V&: B)))) {
      AddCmpOperands(A, B);

      // fcmp fneg(x), y
      // fcmp fabs(x), y
      // fcmp fneg(fabs(x)), y
      if (match(V: A, P: m_FNeg(X: m_Value(V&: A))))
        AddAffected(A);
      if (match(V: A, P: m_FAbs(Op0: m_Value(V&: A))))
        AddAffected(A);

    } else if (match(V, P: m_Intrinsic<Intrinsic::is_fpclass>(Op0: m_Value(V&: A),
                                                            Op1: m_Value()))) {
      // Handle patterns that computeKnownFPClass() support.
      AddAffected(A);
    } else if (!IsAssume && match(V, P: m_Trunc(Op: m_Value(V&: X)))) {
      // Assume is checked here as X is already added above for assumes in
      // addValueAffectedByCondition
      AddAffected(X);
    } else if (!IsAssume && match(V, P: m_Not(V: m_Value(V&: X)))) {
      // Assume is checked here to avoid issues with ephemeral values
      Worklist.push_back(Elt: X);
    }
  }
}
10444
10445const Value *llvm::stripNullTest(const Value *V) {
10446 // (X >> C) or/add (X & mask(C) != 0)
10447 if (const auto *BO = dyn_cast<BinaryOperator>(Val: V)) {
10448 if (BO->getOpcode() == Instruction::Add ||
10449 BO->getOpcode() == Instruction::Or) {
10450 const Value *X;
10451 const APInt *C1, *C2;
10452 if (match(V: BO, P: m_c_BinOp(L: m_LShr(L: m_Value(V&: X), R: m_APInt(Res&: C1)),
10453 R: m_ZExt(Op: m_SpecificICmp(
10454 MatchPred: ICmpInst::ICMP_NE,
10455 L: m_And(L: m_Deferred(V: X), R: m_LowBitMask(V&: C2)),
10456 R: m_Zero())))) &&
10457 C2->popcount() == C1->getZExtValue())
10458 return X;
10459 }
10460 }
10461 return nullptr;
10462}
10463
10464Value *llvm::stripNullTest(Value *V) {
10465 return const_cast<Value *>(stripNullTest(V: const_cast<const Value *>(V)));
10466}
10467
/// Collect into \p Constants the set of immediate constants that \p V may
/// evaluate to, looking through selects and phis. Returns false if the set
/// would exceed \p MaxCount entries, if a non-constant leaf other than a
/// select/phi is reached, or (when \p AllowUndefOrPoison is false) if a
/// constant that may be undef/poison is encountered; in that case the
/// contents of \p Constants are meaningless.
bool llvm::collectPossibleValues(const Value *V,
                                 SmallPtrSetImpl<const Constant *> &Constants,
                                 unsigned MaxCount, bool AllowUndefOrPoison) {
  SmallPtrSet<const Instruction *, 8> Visited;
  SmallVector<const Instruction *, 8> Worklist;
  // Classify one value: record it if it is an immediate constant, queue it if
  // it is an instruction we may be able to look through, otherwise give up.
  auto Push = [&](const Value *V) -> bool {
    Constant *C;
    if (match(V: const_cast<Value *>(V), P: m_ImmConstant(C))) {
      if (!AllowUndefOrPoison && !isGuaranteedNotToBeUndefOrPoison(V: C))
        return false;
      // Check existence first to avoid unnecessary allocations.
      if (Constants.contains(Ptr: C))
        return true;
      if (Constants.size() == MaxCount)
        return false;
      Constants.insert(Ptr: C);
      return true;
    }

    if (auto *Inst = dyn_cast<Instruction>(Val: V)) {
      if (Visited.insert(Ptr: Inst).second)
        Worklist.push_back(Elt: Inst);
      return true;
    }
    return false;
  };
  if (!Push(V))
    return false;
  while (!Worklist.empty()) {
    const Instruction *CurInst = Worklist.pop_back_val();
    switch (CurInst->getOpcode()) {
    case Instruction::Select:
      // Either arm of the select may be the result.
      if (!Push(CurInst->getOperand(i: 1)))
        return false;
      if (!Push(CurInst->getOperand(i: 2)))
        return false;
      break;
    case Instruction::PHI:
      for (Value *IncomingValue : cast<PHINode>(Val: CurInst)->incoming_values()) {
        // Fast path for recurrence PHI.
        if (IncomingValue == CurInst)
          continue;
        if (!Push(IncomingValue))
          return false;
      }
      break;
    default:
      // Any other instruction may produce an arbitrary value.
      return false;
    }
  }
  return true;
}
10520