//===- InstCombineSimplifyDemanded.cpp ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains logic for simplifying instructions based on information
// about how they are used.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"

using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "instcombine"

static cl::opt<bool>
    VerifyKnownBits("instcombine-verify-known-bits",
                    cl::desc("Verify that computeKnownBits() and "
                             "SimplifyDemandedBits() are consistent"),
                    cl::Hidden, cl::init(false));

static cl::opt<unsigned> SimplifyDemandedVectorEltsDepthLimit(
    "instcombine-simplify-vector-elts-depth",
    cl::desc(
        "Depth limit when simplifying vector instructions and their operands"),
    cl::Hidden, cl::init(10));

/// Check to see if the specified operand of the specified instruction is a
/// constant integer. If so, check to see if there are any bits set in the
/// constant that are not demanded. If so, shrink the constant and return true.
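/// For example, with DemandedMask == 0x0F, the constant in "and X, 0xFF" can
/// be shrunk to "and X, 0x0F"; the bits cleared by the shrink are never read.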
static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
                                   const APInt &Demanded) {
  assert(I && "No instruction?");
  assert(OpNo < I->getNumOperands() && "Operand index too large");

  // The operand must be a constant integer or splat integer.
  Value *Op = I->getOperand(OpNo);
  const APInt *C;
  if (!match(Op, m_APInt(C)))
    return false;

  // If there are no bits set that aren't demanded, nothing to do.
  if (C->isSubsetOf(Demanded))
    return false;

  // This instruction is producing bits that are not demanded. Shrink the RHS.
  I->setOperand(OpNo, ConstantInt::get(Op->getType(), *C & Demanded));

  return true;
}

/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
  if (unsigned BitWidth = Ty->getScalarSizeInBits())
    return BitWidth;

  return DL.getPointerTypeSizeInBits(Ty);
}

/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst,
                                                       KnownBits &Known) {
  APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth()));
  Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known,
                                     SQ.getWithInstruction(&Inst));
  if (!V) return false;
  if (V == &Inst) return true;
  replaceInstUsesWith(Inst, V);
  return true;
}

/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if
/// the instruction has any properties that allow us to simplify its operands.
bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) {
  KnownBits Known(getBitWidth(Inst.getType(), DL));
  return SimplifyDemandedInstructionBits(Inst, Known);
}

/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
                                            const APInt &DemandedMask,
                                            KnownBits &Known,
                                            const SimplifyQuery &Q,
                                            unsigned Depth) {
  Use &U = I->getOperandUse(OpNo);
  Value *V = U.get();
  if (isa<Constant>(V)) {
    llvm::computeKnownBits(V, Known, Q, Depth);
    return false;
  }

  Known.resetAll();
  if (DemandedMask.isZero()) {
    // Not demanding any bits from V.
    replaceUse(U, UndefValue::get(V->getType()));
    return true;
  }

  Instruction *VInst = dyn_cast<Instruction>(V);
  if (!VInst) {
    llvm::computeKnownBits(V, Known, Q, Depth);
    return false;
  }

  if (Depth == MaxAnalysisRecursionDepth)
    return false;

  Value *NewVal;
  if (VInst->hasOneUse()) {
    // If the instruction has one use, we can directly simplify it.
    NewVal = SimplifyDemandedUseBits(VInst, DemandedMask, Known, Q, Depth);
  } else {
    // If there are multiple uses of this instruction, then we can simplify
    // VInst to some other value, but not modify the instruction.
    NewVal =
        SimplifyMultipleUseDemandedBits(VInst, DemandedMask, Known, Q, Depth);
  }
  if (!NewVal) return false;
  if (Instruction *OpInst = dyn_cast<Instruction>(U))
    salvageDebugInfo(*OpInst);

  replaceUse(U, NewVal);
  return true;
}
/// This function attempts to replace V with a simpler value based on the
/// demanded bits. When this function is called, it is known that only the
/// bits set in DemandedMask of the result of V are ever used downstream.
/// Consequently, depending on the mask and V, it may be possible to replace V
/// with a constant or one of its operands. In such cases, this function does
/// the replacement and returns a non-null value. In all other cases, it
/// returns null after analyzing the expression, setting Known.One to the bits
/// of the expression that are known to be one and Known.Zero to the bits that
/// are known to be zero. These are provided to potentially allow the caller
/// (which might recursively be SimplifyDemandedBits itself) to simplify the
/// expression.
/// Known.One and Known.Zero always follow the invariant that:
///   Known.One & Known.Zero == 0.
/// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero
/// are accurate even for bits not in DemandedMask. Note also that the
/// bitwidth of V, DemandedMask, Known.Zero and Known.One must all be the
/// same.
///
/// This returns null if it did not change anything and it permits no
/// simplification. This returns V itself if it did some simplification of V's
/// operands based on the information about what bits are demanded. This
/// returns some other non-null value if it found out that V is equal to
/// another value in the context where the specified bits are demanded, but
/// not for all users.
Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
                                                 const APInt &DemandedMask,
                                                 KnownBits &Known,
                                                 const SimplifyQuery &Q,
                                                 unsigned Depth) {
  assert(I != nullptr && "Null pointer of Value???");
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
  uint32_t BitWidth = DemandedMask.getBitWidth();
  Type *VTy = I->getType();
  assert(
      (!VTy->isIntOrIntVectorTy() || VTy->getScalarSizeInBits() == BitWidth) &&
      Known.getBitWidth() == BitWidth &&
      "Value *V, DemandedMask and Known must have same BitWidth");

  KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth);

  // Update flags after simplifying an operand based on the fact that some high
  // order bits are not demanded.
  auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I,
                                                  unsigned NLZ) {
    if (NLZ > 0) {
      // Disable the nsw and nuw flags here: We can no longer guarantee that
      // we won't wrap after simplification. Removing the nsw/nuw flags is
      // legal here because the top bit is not demanded.
      I->setHasNoSignedWrap(false);
      I->setHasNoUnsignedWrap(false);
    }
    return I;
  };

  // If the high bits of an ADD/SUB/MUL are not demanded, then we do not care
  // about the high bits of the operands.
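  // For example, if only the low 8 bits of an i32 add are demanded, the upper
  // 24 bits of both operands are irrelevant, since carries only propagate
  // from low bits to high bits.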
  auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
    unsigned NLZ = DemandedMask.countl_zero();
    // Right fill the mask of bits for the operands to demand the most
    // significant bit and all those below it.
    DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
    if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
        SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Q, Depth + 1) ||
        ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Q, Depth + 1)) {
      disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
      return true;
    }
    return false;
  };

  switch (I->getOpcode()) {
  default:
    llvm::computeKnownBits(I, Known, Q, Depth);
    break;
  case Instruction::And: {
    // If either the LHS or the RHS are Zero, the result is zero.
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Q, Depth + 1) ||
        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.Zero, LHSKnown, Q,
                             Depth + 1))
      return I;

    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Q, Depth);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known 1 on one side, return the other.
    // These bits cannot contribute to the result of the 'and'.
    if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
      return I->getOperand(1);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero))
      return I;

    break;
  }
  case Instruction::Or: {
    // If either the LHS or the RHS are One, the result is One.
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Q, Depth + 1) ||
        SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown, Q,
                             Depth + 1)) {
      // The disjoint flag may no longer hold.
      I->dropPoisonGeneratingFlags();
      return I;
    }

    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Q, Depth);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known zero on one side, return the
    // other. These bits cannot contribute to the result of the 'or'.
    if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
      return I->getOperand(1);

    // If the RHS is a constant, see if we can simplify it.
    if (ShrinkDemandedConstant(I, 1, DemandedMask))
      return I;

    // Infer disjoint flag if no common bits are set.
    if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) {
      WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown),
          RHSCache(I->getOperand(1), RHSKnown);
      if (haveNoCommonBitsSet(LHSCache, RHSCache, Q)) {
        cast<PossiblyDisjointInst>(I)->setIsDisjoint(true);
        return I;
      }
    }

    break;
  }
  case Instruction::Xor: {
    if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Q, Depth + 1) ||
        SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Q, Depth + 1))
      return I;
    Value *LHS, *RHS;
    if (DemandedMask == 1 &&
        match(I->getOperand(0), m_Intrinsic<Intrinsic::ctpop>(m_Value(LHS))) &&
        match(I->getOperand(1), m_Intrinsic<Intrinsic::ctpop>(m_Value(RHS)))) {
      // (ctpop(X) ^ ctpop(Y)) & 1 --> ctpop(X^Y) & 1
      IRBuilderBase::InsertPointGuard Guard(Builder);
      Builder.SetInsertPoint(I);
      auto *Xor = Builder.CreateXor(LHS, RHS);
      return Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, Xor);
    }

    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Q, Depth);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(VTy, Known.One);

    // If all of the demanded bits are known zero on one side, return the
    // other. These bits cannot contribute to the result of the 'xor'.
    if (DemandedMask.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    // If all of the demanded bits are known to be zero on one side or the
    // other, turn this into an *inclusive* or.
    // e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) {
      Instruction *Or =
          BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1));
      if (DemandedMask.isAllOnes())
        cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true);
      Or->takeName(I);
      return InsertNewInstWith(Or, I->getIterator());
    }

    // If all of the demanded bits on one side are known, and all of the set
    // bits on that side are also known to be set on the other side, turn this
    // into an AND, as we know the bits will be cleared.
    // e.g. (X | C1) ^ C2 --> (X | C1) & ~C2 iff (C1&C2) == C2
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | RHSKnown.One) &&
        RHSKnown.One.isSubsetOf(LHSKnown.One)) {
      Constant *AndC = Constant::getIntegerValue(VTy,
                                                 ~RHSKnown.One & DemandedMask);
      Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
      return InsertNewInstWith(And, I->getIterator());
    }

    // If the RHS is a constant, see if we can change it. Don't alter a -1
    // constant because that's a canonical 'not' op, and that is better for
    // combining, SCEV, and codegen.
    const APInt *C;
    if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) {
      if ((*C | ~DemandedMask).isAllOnes()) {
        // Force bits to 1 to create a 'not' op.
        I->setOperand(1, ConstantInt::getAllOnesValue(VTy));
        return I;
      }
      // If we can't turn this into a 'not', try to shrink the constant.
      if (ShrinkDemandedConstant(I, 1, DemandedMask))
        return I;
    }

    // If our LHS is an 'and' and if it has one use, and if any of the bits we
    // are flipping are known to be set, then the xor is just resetting those
    // bits to zero. We can just knock out bits from the 'and' and the 'xor',
    // simplifying both of them.
    if (Instruction *LHSInst = dyn_cast<Instruction>(I->getOperand(0))) {
      ConstantInt *AndRHS, *XorRHS;
      if (LHSInst->getOpcode() == Instruction::And && LHSInst->hasOneUse() &&
          match(I->getOperand(1), m_ConstantInt(XorRHS)) &&
          match(LHSInst->getOperand(1), m_ConstantInt(AndRHS)) &&
          (LHSKnown.One & RHSKnown.One & DemandedMask) != 0) {
        APInt NewMask = ~(LHSKnown.One & RHSKnown.One & DemandedMask);

        Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue());
        Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
        InsertNewInstWith(NewAnd, I->getIterator());

        Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue());
        Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC);
        return InsertNewInstWith(NewXor, I->getIterator());
      }
    }
    break;
  }
  case Instruction::Select: {
    if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnown, Q, Depth + 1) ||
        SimplifyDemandedBits(I, 1, DemandedMask, LHSKnown, Q, Depth + 1))
      return I;

    // If the operands are constants, see if we can simplify them.
    // This is similar to ShrinkDemandedConstant, but for a select we want to
    // try to keep the selected constants the same as icmp value constants, if
    // we can. This helps not break apart (or helps put back together)
    // canonical patterns like min and max.
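    // e.g. for "select (icmp ugt X, 7), 7, X" (a umin pattern), keeping the
    // select constant identical to the icmp constant 7 preserves the idiom.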
    auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo,
                                         const APInt &DemandedMask) {
      const APInt *SelC;
      if (!match(I->getOperand(OpNo), m_APInt(SelC)))
        return false;

      // Get the constant out of the ICmp, if there is one.
      // Only try this when exactly 1 operand is a constant (if both operands
      // are constant, the icmp should eventually simplify). Otherwise, we may
      // invert the transform that reduces set bits and infinite-loop.
      Value *X;
      const APInt *CmpC;
      if (!match(I->getOperand(0), m_ICmp(m_Value(X), m_APInt(CmpC))) ||
          isa<Constant>(X) || CmpC->getBitWidth() != SelC->getBitWidth())
        return ShrinkDemandedConstant(I, OpNo, DemandedMask);

      // If the constant is already the same as the ICmp, leave it as-is.
      if (*CmpC == *SelC)
        return false;
      // If the constants are not already the same, but can be with the demand
      // mask, use the constant value from the ICmp.
      if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) {
        I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC));
        return true;
      }
      return ShrinkDemandedConstant(I, OpNo, DemandedMask);
    };
    if (CanonicalizeSelectConstant(I, 1, DemandedMask) ||
        CanonicalizeSelectConstant(I, 2, DemandedMask))
      return I;

    // Only known if known in both the LHS and RHS.
    adjustKnownBitsForSelectArm(LHSKnown, I->getOperand(0), I->getOperand(1),
                                /*Invert=*/false, Q, Depth);
    adjustKnownBitsForSelectArm(RHSKnown, I->getOperand(0), I->getOperand(2),
                                /*Invert=*/true, Q, Depth);
    Known = LHSKnown.intersectWith(RHSKnown);
    break;
  }
  case Instruction::Trunc: {
    // If we do not demand the high bits of a right-shifted and truncated
    // value, then we may be able to truncate it before the shift.
    Value *X;
    const APInt *C;
    if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) {
      // The shift amount must be valid (not poison) in the narrow type, and
      // it must not be greater than the high bits demanded of the result.
      if (C->ult(VTy->getScalarSizeInBits()) &&
          C->ule(DemandedMask.countl_zero())) {
        // trunc (lshr X, C) --> lshr (trunc X), C
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *Trunc = Builder.CreateTrunc(X, VTy);
        return Builder.CreateLShr(Trunc, C->getZExtValue());
      }
    }
  }
    [[fallthrough]];
  case Instruction::ZExt: {
    unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();

    APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth);
    KnownBits InputKnown(SrcBitWidth);
    if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Q,
                             Depth + 1)) {
      // For zext nneg, we may have dropped the instruction which made the
      // input non-negative.
      I->dropPoisonGeneratingFlags();
      return I;
    }
    assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?");
    if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() &&
        !InputKnown.isNegative())
      InputKnown.makeNonNegative();
    Known = InputKnown.zextOrTrunc(BitWidth);

    break;
  }
  case Instruction::SExt: {
    // Compute the bits in the result that are not present in the input.
    unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();

    APInt InputDemandedBits = DemandedMask.trunc(SrcBitWidth);

    // If any of the sign extended bits are demanded, we know that the sign
    // bit is demanded.
    if (DemandedMask.getActiveBits() > SrcBitWidth)
      InputDemandedBits.setBit(SrcBitWidth - 1);

    KnownBits InputKnown(SrcBitWidth);
    if (SimplifyDemandedBits(I, 0, InputDemandedBits, InputKnown, Q, Depth + 1))
      return I;

    // If the input sign bit is known zero, or if the NewBits are not demanded,
    // convert this into a zero extension.
    if (InputKnown.isNonNegative() ||
        DemandedMask.getActiveBits() <= SrcBitWidth) {
      // Convert to ZExt cast.
      CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy);
      NewCast->takeName(I);
      return InsertNewInstWith(NewCast, I->getIterator());
    }

    // If the sign bit of the input is known set or clear, then we know the
    // top bits of the result.
    Known = InputKnown.sext(BitWidth);
    break;
  }
  case Instruction::Add: {
    if ((DemandedMask & 1) == 0) {
      // If we do not need the low bit, try to convert bool math to logic:
      // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN
      Value *X, *Y;
      if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))),
                           m_OneUse(m_SExt(m_Value(Y))))) &&
          X->getType()->isIntOrIntVectorTy(1) &&
          X->getType() == Y->getType()) {
        // Truth table for inputs and output signbits:
        //       X:0 | X:1
        //      ----------
        // Y:0  |  0 | 0 |
        // Y:1  | -1 | 0 |
        //      ----------
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y);
        return Builder.CreateSExt(AndNot, VTy);
      }

      // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN
      if (match(I, m_Add(m_SExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
          X->getType()->isIntOrIntVectorTy(1) &&
          X->getType() == Y->getType() &&
          (I->getOperand(0)->hasOneUse() || I->getOperand(1)->hasOneUse())) {

        // Truth table for inputs and output signbits:
        //       X:0 | X:1
        //      -----------
        // Y:0  | -1 | -1 |
        // Y:1  | -1 |  0 |
        //      -----------
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        Value *Or = Builder.CreateOr(X, Y);
        return Builder.CreateSExt(Or, VTy);
      }
    }

    // Right fill the mask of bits for the operands to demand the most
    // significant bit and all those below it.
    unsigned NLZ = DemandedMask.countl_zero();
    APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
    if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Q, Depth + 1))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If low order bits are not demanded and known to be zero in one operand,
    // then we don't need to demand them from the other operand, since they
    // can't cause overflow into any bits that are demanded in the result.
    unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
    APInt DemandedFromLHS = DemandedFromOps;
    DemandedFromLHS.clearLowBits(NTZ);
    if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Q, Depth + 1))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If we are known to be adding zeros to every bit below
    // the highest demanded bit, we just return the other side.
    if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);
    // (add X, C) --> (xor X, C) IFF C is equal to the top bit of the
    // DemandedMask.
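    // e.g. with DemandedMask == 0xFF and C == 0x80: any carry out of bit 7
    // lands in undemanded bits, so the add is just a flip of bit 7.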
    {
      const APInt *C;
      if (match(I->getOperand(1), m_APInt(C)) &&
          C->isOneBitSet(DemandedMask.getActiveBits() - 1)) {
        IRBuilderBase::InsertPointGuard Guard(Builder);
        Builder.SetInsertPoint(I);
        return Builder.CreateXor(I->getOperand(0), ConstantInt::get(VTy, *C));
      }
    }

    // Otherwise just compute the known bits of the result.
    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
    Known = KnownBits::add(LHSKnown, RHSKnown, NSW, NUW);
    break;
  }
  case Instruction::Sub: {
    // Right fill the mask of bits for the operands to demand the most
    // significant bit and all those below it.
    unsigned NLZ = DemandedMask.countl_zero();
    APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
    if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
        SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Q, Depth + 1))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If low order bits are not demanded and are known to be zero in RHS,
    // then we don't need to demand them from LHS, since they can't cause a
    // borrow from any bits that are demanded in the result.
    unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countr_one();
    APInt DemandedFromLHS = DemandedFromOps;
    DemandedFromLHS.clearLowBits(NTZ);
    if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
        SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Q, Depth + 1))
      return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);

    // If we are known to be subtracting zeros from every bit below
    // the highest demanded bit, we just return the other side.
    if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    // We can't do this with the LHS for subtraction, unless we are only
    // demanding the LSB.
    if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    // Canonicalize sub mask, X -> ~X
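    // e.g. "sub 7, X" with only the low 3 bits demanded --> "not X", because
    // subtracting from a low-bit mask of all ones can never borrow.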
    const APInt *LHSC;
    if (match(I->getOperand(0), m_LowBitMask(LHSC)) &&
        DemandedFromOps.isSubsetOf(*LHSC)) {
      IRBuilderBase::InsertPointGuard Guard(Builder);
      Builder.SetInsertPoint(I);
      return Builder.CreateNot(I->getOperand(1));
    }

    // Otherwise just compute the known bits of the result.
    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
    Known = KnownBits::sub(LHSKnown, RHSKnown, NSW, NUW);
    break;
  }
  case Instruction::Mul: {
    APInt DemandedFromOps;
    if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
      return I;

    if (DemandedMask.isPowerOf2()) {
      // The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
      // If we demand exactly one bit N and we have "X * (C' << N)" where C' is
      // odd (has LSB set), then the left-shifted low bit of X is the answer.
      unsigned CTZ = DemandedMask.countr_zero();
      const APInt *C;
      if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) {
        Constant *ShiftC = ConstantInt::get(VTy, CTZ);
        Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
        return InsertNewInstWith(Shl, I->getIterator());
      }
    }
    // For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
    // X * X is odd iff X is odd.
    // 'Quadratic Reciprocity': X * X -> 0 for bit[1]
    if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
      Constant *One = ConstantInt::get(VTy, 1);
      Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
      return InsertNewInstWith(And1, I->getIterator());
    }

    llvm::computeKnownBits(I, Known, Q, Depth);
    break;
  }
  case Instruction::Shl: {
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      const APInt *ShrAmt;
      if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt))))
        if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0)))
          if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA,
                                                    DemandedMask, Known))
            return R;

      // Do not simplify if shl is part of a funnel-shift pattern.
      if (I->hasOneUse()) {
        auto *Inst = dyn_cast<Instruction>(I->user_back());
        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
          if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) {
            auto [IID, FShiftArgs] = *Opt;
            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
                FShiftArgs[0] == FShiftArgs[1]) {
              llvm::computeKnownBits(I, Known, Q, Depth);
              break;
            }
          }
        }
      }

      // If we only want bits that already match the signbit, then we don't
      // need to shift.
      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);
      if (DemandedMask.countr_zero() >= ShiftAmt) {
        if (I->hasNoSignedWrap()) {
          unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
          unsigned SignBits =
              ComputeNumSignBits(I->getOperand(0), Q.CxtI, Depth + 1);
          if (SignBits > ShiftAmt && SignBits - ShiftAmt >= NumHiDemandedBits)
            return I->getOperand(0);
        }

        // If we can pre-shift a right-shifted constant to the left without
        // losing any high bits and we don't demand the low bits, then
        // eliminate the left-shift:
        // (C >> X) << LeftShiftAmtC --> (C << LeftShiftAmtC) >> X
        Value *X;
        Constant *C;
        if (match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
          Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
          Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
                                                        LeftShiftAmtC, DL);
          if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC,
                                           LeftShiftAmtC, DL) == C) {
            Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
            return InsertNewInstWith(Lshr, I->getIterator());
          }
        }
      }

      APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));

      // If the shift is NUW/NSW, then it does demand the high bits.
      ShlOperator *IOp = cast<ShlOperator>(I);
      if (IOp->hasNoSignedWrap())
        DemandedMaskIn.setHighBits(ShiftAmt + 1);
      else if (IOp->hasNoUnsignedWrap())
        DemandedMaskIn.setHighBits(ShiftAmt);

      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Q, Depth + 1))
        return I;

      Known = KnownBits::shl(Known,
                             KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
                             /* NUW */ IOp->hasNoUnsignedWrap(),
                             /* NSW */ IOp->hasNoSignedWrap());
    } else {
      // This is a variable shift, so we can't shift the demand mask by a known
      // amount. But if we are not demanding high bits, then we are not
      // demanding those bits from the pre-shifted operand either.
      if (unsigned CTLZ = DemandedMask.countl_zero()) {
        APInt DemandedFromOp(APInt::getLowBitsSet(BitWidth, BitWidth - CTLZ));
        if (SimplifyDemandedBits(I, 0, DemandedFromOp, Known, Q, Depth + 1)) {
          // We can't guarantee that nsw/nuw hold after simplifying the
          // operand.
          I->dropPoisonGeneratingFlags();
          return I;
        }
      }
      llvm::computeKnownBits(I, Known, Q, Depth);
    }
    break;
  }
  case Instruction::LShr: {
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);

      // Do not simplify if lshr is part of a funnel-shift pattern.
      if (I->hasOneUse()) {
        auto *Inst = dyn_cast<Instruction>(I->user_back());
        if (Inst && Inst->getOpcode() == BinaryOperator::Or) {
          if (auto Opt = convertOrOfShiftsToFunnelShift(*Inst)) {
            auto [IID, FShiftArgs] = *Opt;
            if ((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
                FShiftArgs[0] == FShiftArgs[1]) {
              llvm::computeKnownBits(I, Known, Q, Depth);
              break;
            }
          }
        }
      }

      // If we are just demanding the shifted sign bit and below, then this can
      // be treated as an ASHR in disguise.
      if (DemandedMask.countl_zero() >= ShiftAmt) {
        // If we only want bits that already match the signbit, then we don't
        // need to shift.
        unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
        unsigned SignBits =
            ComputeNumSignBits(I->getOperand(0), Q.CxtI, Depth + 1);
        if (SignBits >= NumHiDemandedBits)
          return I->getOperand(0);

        // If we can pre-shift a left-shifted constant to the right without
        // losing any low bits (we already know we don't demand the high bits),
        // then eliminate the right-shift:
        // (C << X) >> RightShiftAmtC --> (C >> RightShiftAmtC) << X
        Value *X;
        Constant *C;
        if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) {
          Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
          Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C,
                                                        RightShiftAmtC, DL);
          if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC,
                                           RightShiftAmtC, DL) == C) {
            Instruction *Shl = BinaryOperator::CreateShl(NewC, X);
            return InsertNewInstWith(Shl, I->getIterator());
          }
        }

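        // When the high ShiftAmt bits are not demanded, a multiply by a
        // constant with at least ShiftAmt trailing zeros absorbs the shift:
        // (X * (C << ShiftAmt)) >> ShiftAmt --> X * C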
        const APInt *Factor;
        if (match(I->getOperand(0),
                  m_OneUse(m_Mul(m_Value(X), m_APInt(Factor)))) &&
            Factor->countr_zero() >= ShiftAmt) {
          BinaryOperator *Mul = BinaryOperator::CreateMul(
              X, ConstantInt::get(X->getType(), Factor->lshr(ShiftAmt)));
          return InsertNewInstWith(Mul, I->getIterator());
        }
      }

      // Unsigned shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Q, Depth + 1)) {
        // The exact flag may no longer hold.
        I->dropPoisonGeneratingFlags();
        return I;
      }
      Known.Zero.lshrInPlace(ShiftAmt);
      Known.One.lshrInPlace(ShiftAmt);
      if (ShiftAmt)
        Known.Zero.setHighBits(ShiftAmt); // high bits known zero.
    } else {
      llvm::computeKnownBits(I, Known, Q, Depth);
    }
    break;
  }
  case Instruction::AShr: {
    unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Q.CxtI, Depth + 1);

    // If we only want bits that already match the signbit, then we don't need
    // to shift.
    unsigned NumHiDemandedBits = BitWidth - DemandedMask.countr_zero();
    if (SignBits >= NumHiDemandedBits)
      return I->getOperand(0);

    // If this is an arithmetic shift right and only the low-bit is set, we can
    // always convert this into a logical shr, even if the shift amount is
    // variable. The low bit of the shift cannot be an input sign bit unless
    // the shift amount is >= the size of the datatype, which is undefined.
    if (DemandedMask.isOne()) {
      // Perform the logical shift right.
      Instruction *NewVal = BinaryOperator::CreateLShr(
          I->getOperand(0), I->getOperand(1), I->getName());
      return InsertNewInstWith(NewVal, I->getIterator());
    }

    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth - 1);

      // Signed shift right.
      APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
      // If any of the bits being shifted in are demanded, then we should set
      // the sign bit as demanded.
      bool ShiftedInBitsDemanded = DemandedMask.countl_zero() < ShiftAmt;
      if (ShiftedInBitsDemanded)
        DemandedMaskIn.setSignBit();
      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Q, Depth + 1)) {
        // The exact flag may no longer hold.
        I->dropPoisonGeneratingFlags();
        return I;
      }

      // If the input sign bit is known to be zero, or if none of the shifted
      // in bits are demanded, turn this into an unsigned shift right.
      if (Known.Zero[BitWidth - 1] || !ShiftedInBitsDemanded) {
        BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0),
                                                          I->getOperand(1));
        LShr->setIsExact(cast<BinaryOperator>(I)->isExact());
        LShr->takeName(I);
        return InsertNewInstWith(LShr, I->getIterator());
      }

      Known = KnownBits::ashr(
          Known, KnownBits::makeConstant(APInt(BitWidth, ShiftAmt)),
          ShiftAmt != 0, I->isExact());
    } else {
      llvm::computeKnownBits(I, Known, Q, Depth);
    }
    break;
  }
  case Instruction::UDiv: {
    // UDiv doesn't demand low bits that are zero in the divisor.
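    // e.g. "udiv X, 8" is "lshr X, 3" in disguise: the low 3 bits of X are
    // shifted out and can never influence the quotient.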
    const APInt *SA;
    if (match(I->getOperand(1), m_APInt(SA))) {
      // TODO: Take the demanded mask of the result into account.
      unsigned RHSTrailingZeros = SA->countr_zero();
      APInt DemandedMaskIn =
          APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
      if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Q, Depth + 1)) {
        // We can't guarantee that "exact" is still true after changing the
        // dividend.
        I->dropPoisonGeneratingFlags();
        return I;
      }

      Known = KnownBits::udiv(LHSKnown, KnownBits::makeConstant(*SA),
                              cast<BinaryOperator>(I)->isExact());
    } else {
      llvm::computeKnownBits(I, Known, Q, Depth);
    }
    break;
  }
  case Instruction::SRem: {
    const APInt *Rem;
    if (match(I->getOperand(1), m_APInt(Rem)) && Rem->isPowerOf2()) {
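      // A power-of-2 srem subtracts a multiple of Rem from the dividend, so
      // the bits below log2(Rem) pass through unchanged.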
      if (DemandedMask.ult(*Rem)) // srem won't affect demanded bits
        return I->getOperand(0);

      APInt LowBits = *Rem - 1;
      APInt Mask2 = LowBits | APInt::getSignMask(BitWidth);
      if (SimplifyDemandedBits(I, 0, Mask2, LHSKnown, Q, Depth + 1))
        return I;
      Known = KnownBits::srem(LHSKnown, KnownBits::makeConstant(*Rem));
      break;
    }

    llvm::computeKnownBits(I, Known, Q, Depth);
    break;
  }
  case Instruction::Call: {
    bool KnownBitsComputed = false;
    if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs: {
        if (DemandedMask == 1)
          return II->getArgOperand(0);
        break;
      }
      case Intrinsic::ctpop: {
        // Checking if the number of clear bits is odd (parity)? If the type
        // has an even number of bits, that's the same as checking if the
        // number of set bits is odd, so we can eliminate the 'not' op.
        Value *X;
        if (DemandedMask == 1 && VTy->getScalarSizeInBits() % 2 == 0 &&
            match(II->getArgOperand(0), m_Not(m_Value(X)))) {
          Function *Ctpop = Intrinsic::getOrInsertDeclaration(
              II->getModule(), Intrinsic::ctpop, VTy);
          return InsertNewInstWith(CallInst::Create(Ctpop, {X}),
                                   I->getIterator());
        }
        break;
      }
      case Intrinsic::bswap: {
        // If the only bits demanded come from one byte of the bswap result,
        // just shift the input byte into position to eliminate the bswap.
        unsigned NLZ = DemandedMask.countl_zero();
        unsigned NTZ = DemandedMask.countr_zero();

        // Round NTZ down to the next byte. If we have 11 trailing zeros, then
        // we need all the bits down to bit 8. Likewise, round NLZ. If we
        // have 14 leading zeros, round to 8.
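        // e.g. for i32 with DemandedMask == 0x00FF0000 (NLZ == 8, NTZ == 16),
        // bswap moves input byte 1 into byte 2, so "shl X, 8" suffices.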
        NLZ = alignDown(NLZ, 8);
        NTZ = alignDown(NTZ, 8);
        // If we need exactly one byte, we can do this transformation.
        if (BitWidth - NLZ - NTZ == 8) {
          // Replace this with either a left or right shift to get the byte
          // into the right place.
          Instruction *NewVal;
          if (NLZ > NTZ)
            NewVal = BinaryOperator::CreateLShr(
                II->getArgOperand(0), ConstantInt::get(VTy, NLZ - NTZ));
          else
            NewVal = BinaryOperator::CreateShl(
                II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ));
          NewVal->takeName(I);
          return InsertNewInstWith(NewVal, I->getIterator());
        }
        break;
      }
      case Intrinsic::ptrmask: {
        unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
        RHSKnown = KnownBits(MaskWidth);
        // If either the LHS or the RHS are Zero, the result is zero.
        if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Q, Depth + 1) ||
            SimplifyDemandedBits(
                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
                RHSKnown, Q, Depth + 1))
          return I;

        // TODO: Should be 1-extend
        RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);

        Known = LHSKnown & RHSKnown;
        KnownBitsComputed = true;

        // If the client is only demanding bits we know to be zero, return
        // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
        // provenance, but making the mask zero will be easily optimizable in
        // the backend.
        if (DemandedMask.isSubsetOf(Known.Zero) &&
            !match(I->getOperand(1), m_Zero()))
          return replaceOperand(
              *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));

        // Mask in demanded space does nothing.
        // NOTE: We may have attributes associated with the return value of the
        // llvm.ptrmask intrinsic that will be lost when we just return the
        // operand. We should try to preserve them.
        if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
          return I->getOperand(0);

        // If the RHS is a constant, see if we can simplify it.
        if (ShrinkDemandedConstant(
                I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth)))
          return I;

        // Combine:
        // (ptrmask (getelementptr i8, ptr p, imm i), imm mask)
        //   -> (ptrmask (getelementptr i8, ptr p, imm (i & mask)), imm mask)
        // where only the low bits known to be zero in the pointer are changed
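        // e.g. with a pointer known to be 8-byte aligned:
        // (ptrmask (getelementptr i8, ptr p, 17), -8)
        //   -> (ptrmask (getelementptr i8, ptr p, 16), -8)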
        Value *InnerPtr;
        uint64_t GEPIndex;
        uint64_t PtrMaskImmediate;
        if (match(I, m_Intrinsic<Intrinsic::ptrmask>(
                         m_PtrAdd(m_Value(InnerPtr), m_ConstantInt(GEPIndex)),
                         m_ConstantInt(PtrMaskImmediate)))) {

          LHSKnown = computeKnownBits(InnerPtr, I, Depth + 1);
          if (!LHSKnown.isZero()) {
            const unsigned trailingZeros = LHSKnown.countMinTrailingZeros();
            uint64_t PointerAlignBits = (uint64_t(1) << trailingZeros) - 1;

            uint64_t HighBitsGEPIndex = GEPIndex & ~PointerAlignBits;
            uint64_t MaskedLowBitsGEPIndex =
                GEPIndex & PointerAlignBits & PtrMaskImmediate;

            uint64_t MaskedGEPIndex = HighBitsGEPIndex | MaskedLowBitsGEPIndex;

            if (MaskedGEPIndex != GEPIndex) {
              auto *GEP = cast<GEPOperator>(II->getArgOperand(0));
              Builder.SetInsertPoint(I);
              Type *GEPIndexType =
                  DL.getIndexType(GEP->getPointerOperand()->getType());
              Value *MaskedGEP = Builder.CreateGEP(
                  GEP->getSourceElementType(), InnerPtr,
                  ConstantInt::get(GEPIndexType, MaskedGEPIndex),
                  GEP->getName(), GEP->isInBounds());

              replaceOperand(*I, 0, MaskedGEP);
              return I;
            }
          }
        }

        break;
      }

      case Intrinsic::fshr:
      case Intrinsic::fshl: {
        const APInt *SA;
        if (!match(I->getOperand(2), m_APInt(SA)))
          break;

        // Normalize to funnel shift left. APInt shifts of BitWidth are well-
        // defined, so no need to special-case zero shifts here.
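        // e.g. on i32, fshr(X, Y, 3) == fshl(X, Y, 29): both concatenate the
        // low 3 bits of X above the high 29 bits of Y.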
        uint64_t ShiftAmt = SA->urem(BitWidth);
        if (II->getIntrinsicID() == Intrinsic::fshr)
          ShiftAmt = BitWidth - ShiftAmt;

        APInt DemandedMaskLHS(DemandedMask.lshr(ShiftAmt));
        APInt DemandedMaskRHS(DemandedMask.shl(BitWidth - ShiftAmt));
        if (I->getOperand(0) != I->getOperand(1)) {
          if (SimplifyDemandedBits(I, 0, DemandedMaskLHS, LHSKnown, Q,
                                   Depth + 1) ||
              SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Q,
                                   Depth + 1)) {
            // Range attribute may no longer hold.
            I->dropPoisonGeneratingReturnAttributes();
            return I;
          }
        } else { // fshl is a rotate
          // Avoid converting rotate into funnel shift.
          // Only simplify if one operand is constant.
          LHSKnown = computeKnownBits(I->getOperand(0), I, Depth + 1);
          if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) &&
              !match(I->getOperand(0), m_SpecificInt(LHSKnown.One))) {
            replaceOperand(*I, 0, Constant::getIntegerValue(VTy, LHSKnown.One));
            return I;
          }

          RHSKnown = computeKnownBits(I->getOperand(1), I, Depth + 1);
          if (DemandedMaskRHS.isSubsetOf(RHSKnown.Zero | RHSKnown.One) &&
              !match(I->getOperand(1), m_SpecificInt(RHSKnown.One))) {
            replaceOperand(*I, 1, Constant::getIntegerValue(VTy, RHSKnown.One));
            return I;
          }
        }

        Known.Zero = LHSKnown.Zero.shl(ShiftAmt) |
                     RHSKnown.Zero.lshr(BitWidth - ShiftAmt);
        Known.One = LHSKnown.One.shl(ShiftAmt) |
                    RHSKnown.One.lshr(BitWidth - ShiftAmt);
        KnownBitsComputed = true;
        break;
      }
      case Intrinsic::umax: {
        // UMax(A, C) == A if ...
        // The lowest non-zero bit of DemandedMask is higher than the highest
        // non-zero bit of C.
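        // e.g. umax(X, 7) with DemandedMask == 0xF8: the result differs from
        // X only when X < 7, and then both X and 7 are zero in bits 3 and up.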
        const APInt *C;
        unsigned CTZ = DemandedMask.countr_zero();
        if (match(II->getArgOperand(1), m_APInt(C)) &&
            CTZ >= C->getActiveBits())
          return II->getArgOperand(0);
        break;
      }
      case Intrinsic::umin: {
        // UMin(A, C) == A if ...
        // The lowest non-zero bit of DemandedMask is higher than the highest
        // non-one bit of C.
        // This comes from using DeMorgan's on the above umax example.
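        // e.g. umin(X, 0xFFFFFFF8) with DemandedMask == 0xF8: the result
        // differs from X only when X > 0xFFFFFFF8, and then X and the
        // constant agree in every bit from bit 3 upward.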
        const APInt *C;
        unsigned CTZ = DemandedMask.countr_zero();
        if (match(II->getArgOperand(1), m_APInt(C)) &&
            CTZ >= C->getBitWidth() - C->countl_one())
          return II->getArgOperand(0);
        break;
      }
      default: {
        // Handle target specific intrinsics
        std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
            *II, DemandedMask, Known, KnownBitsComputed);
        if (V)
          return *V;
        break;
      }
      }
    }

    if (!KnownBitsComputed)
      llvm::computeKnownBits(I, Known, Q, Depth);
    break;
  }
  }

  if (I->getType()->isPointerTy()) {
    Align Alignment = I->getPointerAlignment(DL);
    Known.Zero.setLowBits(Log2(Alignment));
  }

  // If the client is only demanding bits that we know, return the known
  // constant. We can't directly simplify pointers as a constant because of
  // pointer provenance.
  // TODO: We could return `(inttoptr const)` for pointers.
  if (!I->getType()->isPointerTy() &&
      DemandedMask.isSubsetOf(Known.Zero | Known.One))
    return Constant::getIntegerValue(VTy, Known.One);

  if (VerifyKnownBits) {
    KnownBits ReferenceKnown = llvm::computeKnownBits(I, Q, Depth);
    if (Known != ReferenceKnown) {
      errs() << "Mismatched known bits for " << *I << " in "
             << I->getFunction()->getName() << "\n";
      errs() << "computeKnownBits(): " << ReferenceKnown << "\n";
      errs() << "SimplifyDemandedBits(): " << Known << "\n";
      std::abort();
    }
  }

  return nullptr;
}

/// Helper routine of SimplifyDemandedUseBits. It computes Known
/// bits. It also tries to handle simplifications that can be done based on
/// DemandedMask, but without modifying the Instruction.
Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
    Instruction *I, const APInt &DemandedMask, KnownBits &Known,
    const SimplifyQuery &Q, unsigned Depth) {
  unsigned BitWidth = DemandedMask.getBitWidth();
  Type *ITy = I->getType();

  KnownBits LHSKnown(BitWidth);
  KnownBits RHSKnown(BitWidth);

  // Despite the fact that we can't simplify this instruction in every user's
  // context, we can at least compute the known bits, and we can
  // do simplifications that apply to *just* the one user if we know that
  // this instruction has a simpler value in that context.
  switch (I->getOpcode()) {
  case Instruction::And: {
    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1);
    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1);
    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Q, Depth);
    computeKnownBitsFromContext(I, Known, Q, Depth);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    // If all of the demanded bits are known 1 on one side, return the other.
    // These bits cannot contribute to the result of the 'and' in this context.
    if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
      return I->getOperand(1);

    break;
  }
  case Instruction::Or: {
    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1);
    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1);
    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Q, Depth);
    computeKnownBitsFromContext(I, Known, Q, Depth);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    // We can simplify (X|Y) -> X or Y in the user's context if we know that
    // only bits from X or Y are demanded.
    // If all of the demanded bits are known zero on one side, return the
    // other. These bits cannot contribute to the result of the 'or' in this
    // context.
    if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
      return I->getOperand(1);

    break;
  }
  case Instruction::Xor: {
    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1);
    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1);
    Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
                                         Q, Depth);
    computeKnownBitsFromContext(I, Known, Q, Depth);

    // If the client is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    // We can simplify (X^Y) -> X or Y in the user's context if we know that
    // only bits from X or Y are demanded.
    // If all of the demanded bits are known zero on one side, return the
    // other.
    if (DemandedMask.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);
    if (DemandedMask.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    break;
  }
  case Instruction::Add: {
    unsigned NLZ = DemandedMask.countl_zero();
    APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);

    // If an operand adds zeros to every bit below the highest demanded bit,
    // that operand doesn't change the result. Return the other side.
    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1);
    if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);

    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1);
    if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
      return I->getOperand(1);

    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
    Known = KnownBits::add(LHSKnown, RHSKnown, NSW, NUW);
    computeKnownBitsFromContext(I, Known, Q, Depth);
    break;
  }
  case Instruction::Sub: {
    unsigned NLZ = DemandedMask.countl_zero();
    APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);

    // If an operand subtracts zeros from every bit below the highest demanded
    // bit, that operand doesn't change the result. Return the other side.
    llvm::computeKnownBits(I->getOperand(1), RHSKnown, Q, Depth + 1);
    if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
      return I->getOperand(0);

    bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
    bool NUW = cast<OverflowingBinaryOperator>(I)->hasNoUnsignedWrap();
    llvm::computeKnownBits(I->getOperand(0), LHSKnown, Q, Depth + 1);
    Known = KnownBits::sub(LHSKnown, RHSKnown, NSW, NUW);
    computeKnownBitsFromContext(I, Known, Q, Depth);
    break;
  }
  case Instruction::AShr: {
    // Compute the Known bits to simplify things downstream.
    llvm::computeKnownBits(I, Known, Q, Depth);

    // If this user is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    // If operand 0 of the right shift is the result of a left shift by the
    // same amount, this is probably a zero/sign extension, which may be
    // unnecessary. If we do not demand any of the new sign bits, return the
    // original operand instead.
    const APInt *ShiftRC;
    const APInt *ShiftLC;
    Value *X;
    unsigned BitWidth = DemandedMask.getBitWidth();
    if (match(I,
              m_AShr(m_Shl(m_Value(X), m_APInt(ShiftLC)), m_APInt(ShiftRC))) &&
        ShiftLC == ShiftRC && ShiftLC->ult(BitWidth) &&
        DemandedMask.isSubsetOf(APInt::getLowBitsSet(
            BitWidth, BitWidth - ShiftRC->getZExtValue()))) {
      return X;
    }

    break;
  }
  default:
    // Compute the Known bits to simplify things downstream.
    llvm::computeKnownBits(I, Known, Q, Depth);

    // If this user is only demanding bits that we know, return the known
    // constant.
    if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
      return Constant::getIntegerValue(ITy, Known.One);

    break;
  }

  return nullptr;
}
1304
1305/// Helper routine of SimplifyDemandedUseBits. It tries to simplify
1306/// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into
1307/// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign
1308/// of "C2-C1".
1309///
1310/// Suppose E1 and E2 are generally different in bits S={bm, bm+1,
1311/// ..., bn}, without considering the specific value X is holding.
1312/// This transformation is legal iff one of following conditions is hold:
1313/// 1) All the bit in S are 0, in this case E1 == E2.
1314/// 2) We don't care those bits in S, per the input DemandedMask.
1315/// 3) Combination of 1) and 2). Some bits in S are 0, and we don't care the
1316/// rest bits.
1317///
1318/// Currently we only test condition 2).
1319///
1320/// As with SimplifyDemandedUseBits, it returns NULL if the simplification was
1321/// not successful.
1322Value *InstCombinerImpl::simplifyShrShlDemandedBits(
1323 Instruction *Shr, const APInt &ShrOp1, Instruction *Shl,
1324 const APInt &ShlOp1, const APInt &DemandedMask, KnownBits &Known) {
1325 if (!ShlOp1 || !ShrOp1)
1326 return nullptr; // No-op.
1327
1328 Value *VarX = Shr->getOperand(i: 0);
1329 Type *Ty = VarX->getType();
1330 unsigned BitWidth = Ty->getScalarSizeInBits();
1331 if (ShlOp1.uge(RHS: BitWidth) || ShrOp1.uge(RHS: BitWidth))
1332 return nullptr; // Undef.
1333
1334 unsigned ShlAmt = ShlOp1.getZExtValue();
1335 unsigned ShrAmt = ShrOp1.getZExtValue();
1336
1337 Known.One.clearAllBits();
1338 Known.Zero.setLowBits(ShlAmt - 1);
1339 Known.Zero &= DemandedMask;
1340
1341 APInt BitMask1(APInt::getAllOnes(numBits: BitWidth));
1342 APInt BitMask2(APInt::getAllOnes(numBits: BitWidth));
1343
1344 bool isLshr = (Shr->getOpcode() == Instruction::LShr);
1345 BitMask1 = isLshr ? (BitMask1.lshr(shiftAmt: ShrAmt) << ShlAmt) :
1346 (BitMask1.ashr(ShiftAmt: ShrAmt) << ShlAmt);
1347
1348 if (ShrAmt <= ShlAmt) {
1349 BitMask2 <<= (ShlAmt - ShrAmt);
1350 } else {
1351 BitMask2 = isLshr ? BitMask2.lshr(shiftAmt: ShrAmt - ShlAmt):
1352 BitMask2.ashr(ShiftAmt: ShrAmt - ShlAmt);
1353 }
1354
1355 // Check if condition-2 (see the comment to this function) is satified.
1356 if ((BitMask1 & DemandedMask) == (BitMask2 & DemandedMask)) {
1357 if (ShrAmt == ShlAmt)
1358 return VarX;
1359
1360 if (!Shr->hasOneUse())
1361 return nullptr;
1362
1363 BinaryOperator *New;
1364 if (ShrAmt < ShlAmt) {
1365 Constant *Amt = ConstantInt::get(Ty: VarX->getType(), V: ShlAmt - ShrAmt);
1366 New = BinaryOperator::CreateShl(V1: VarX, V2: Amt);
1367 BinaryOperator *Orig = cast<BinaryOperator>(Val: Shl);
1368 New->setHasNoSignedWrap(Orig->hasNoSignedWrap());
1369 New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap());
1370 } else {
1371 Constant *Amt = ConstantInt::get(Ty: VarX->getType(), V: ShrAmt - ShlAmt);
1372 New = isLshr ? BinaryOperator::CreateLShr(V1: VarX, V2: Amt) :
1373 BinaryOperator::CreateAShr(V1: VarX, V2: Amt);
1374 if (cast<BinaryOperator>(Val: Shr)->isExact())
1375 New->setIsExact(true);
1376 }
1377
1378 return InsertNewInstWith(New, Old: Shl->getIterator());
1379 }
1380
1381 return nullptr;
1382}
1383
1384/// The specified value produces a vector with any number of elements.
1385/// This method analyzes which elements of the operand are poison and
1386/// returns that information in PoisonElts.
1387///
1388/// DemandedElts contains the set of elements that are actually used by the
1389/// caller, and by default (AllowMultipleUsers equals false) the value is
1390/// simplified only if it has a single caller. If AllowMultipleUsers is set
1391/// to true, DemandedElts refers to the union of sets of elements that are
1392/// used by all callers.
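///
/// For example, if only lane 0 of a <4 x i32> value is used by its caller,
/// DemandedElts is 0b0001 and the remaining three lanes are free to become
/// poison.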
1393///
1394/// If the information about demanded elements can be used to simplify the
1395/// operation, the operation is simplified, then the resultant value is
1396/// returned. This returns null if no change was made.
1397Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
1398 APInt DemandedElts,
1399 APInt &PoisonElts,
1400 unsigned Depth,
1401 bool AllowMultipleUsers) {
1402 // Cannot analyze scalable type. The number of vector elements is not a
1403 // compile-time constant.
1404 if (isa<ScalableVectorType>(Val: V->getType()))
1405 return nullptr;
1406
1407 unsigned VWidth = cast<FixedVectorType>(Val: V->getType())->getNumElements();
1408 APInt EltMask(APInt::getAllOnes(numBits: VWidth));
1409 assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
1410
1411 if (match(V, P: m_Poison())) {
1412 // If the entire vector is poison, just return this info.
1413 PoisonElts = EltMask;
1414 return nullptr;
1415 }
1416
1417 if (DemandedElts.isZero()) { // If nothing is demanded, provide poison.
1418 PoisonElts = EltMask;
1419 return PoisonValue::get(T: V->getType());
1420 }
1421
1422 PoisonElts = 0;
1423
1424 if (auto *C = dyn_cast<Constant>(Val: V)) {
1425 // Check if this is identity. If so, return 0 since we are not simplifying
1426 // anything.
1427 if (DemandedElts.isAllOnes())
1428 return nullptr;
1429
1430 Type *EltTy = cast<VectorType>(Val: V->getType())->getElementType();
1431 Constant *Poison = PoisonValue::get(T: EltTy);
1432 SmallVector<Constant*, 16> Elts;
1433 for (unsigned i = 0; i != VWidth; ++i) {
1434 if (!DemandedElts[i]) { // If not demanded, set to poison.
1435 Elts.push_back(Elt: Poison);
1436 PoisonElts.setBit(i);
1437 continue;
1438 }
1439
1440 Constant *Elt = C->getAggregateElement(Elt: i);
1441 if (!Elt) return nullptr;
1442
1443 Elts.push_back(Elt);
1444 if (isa<PoisonValue>(Val: Elt)) // Already poison.
1445 PoisonElts.setBit(i);
1446 }
1447
1448 // If we changed the constant, return it.
1449 Constant *NewCV = ConstantVector::get(V: Elts);
1450 return NewCV != C ? NewCV : nullptr;
1451 }
1452
1453 // Limit search depth.
1454 if (Depth == SimplifyDemandedVectorEltsDepthLimit)
1455 return nullptr;
1456
1457 if (!AllowMultipleUsers) {
1458 // If multiple users are using the root value, proceed with
1459 // simplification conservatively assuming that all elements
1460 // are needed.
1461 if (!V->hasOneUse()) {
1462 // Quit if we find multiple users of a non-root value though.
1463 // They'll be handled when it's their turn to be visited by
1464 // the main instcombine process.
1465 if (Depth != 0)
1466 // TODO: Just compute the PoisonElts information recursively.
1467 return nullptr;
1468
1469 // Conservatively assume that all elements are needed.
1470 DemandedElts = EltMask;
1471 }
1472 }
1473
1474 Instruction *I = dyn_cast<Instruction>(Val: V);
1475 if (!I) return nullptr; // Only analyze instructions.
1476
1477 bool MadeChange = false;
1478 auto simplifyAndSetOp = [&](Instruction *Inst, unsigned OpNum,
1479 APInt Demanded, APInt &Undef) {
1480 auto *II = dyn_cast<IntrinsicInst>(Val: Inst);
1481 Value *Op = II ? II->getArgOperand(i: OpNum) : Inst->getOperand(i: OpNum);
1482 if (Value *V = SimplifyDemandedVectorElts(V: Op, DemandedElts: Demanded, PoisonElts&: Undef, Depth: Depth + 1)) {
1483 replaceOperand(I&: *Inst, OpNum, V);
1484 MadeChange = true;
1485 }
1486 };
1487
1488 APInt PoisonElts2(VWidth, 0);
1489 APInt PoisonElts3(VWidth, 0);
1490 switch (I->getOpcode()) {
1491 default: break;
1492
1493 case Instruction::GetElementPtr: {
1494 // The LangRef requires that struct geps have all constant indices. As
1495 // such, we can't convert any operand to partial undef.
1496 auto mayIndexStructType = [](GetElementPtrInst &GEP) {
1497 for (auto I = gep_type_begin(GEP), E = gep_type_end(GEP);
1498 I != E; I++)
1499 if (I.isStruct())
1500 return true;
1501 return false;
1502 };
1503 if (mayIndexStructType(cast<GetElementPtrInst>(Val&: *I)))
1504 break;
1505
1506 // Conservatively track the demanded elements back through any vector
1507 // operands we may have. We know there must be at least one, or we
1508 // wouldn't have a vector result to get here. Note that we intentionally
1509 // merge the undef bits here since gepping with either an poison base or
1510 // index results in poison.
1511 for (unsigned i = 0; i < I->getNumOperands(); i++) {
1512 if (i == 0 ? match(V: I->getOperand(i), P: m_Undef())
1513 : match(V: I->getOperand(i), P: m_Poison())) {
1514 // If the entire vector is undefined, just return this info.
1515 PoisonElts = EltMask;
1516 return nullptr;
1517 }
1518 if (I->getOperand(i)->getType()->isVectorTy()) {
1519 APInt PoisonEltsOp(VWidth, 0);
1520 simplifyAndSetOp(I, i, DemandedElts, PoisonEltsOp);
1521 // gep(x, undef) is not undef, so skip considering idx ops here
1522 // Note that we could propagate poison, but we can't distinguish between
1523 // undef & poison bits ATM
1524 if (i == 0)
1525 PoisonElts |= PoisonEltsOp;
1526 }
1527 }
1528
1529 break;
1530 }
1531 case Instruction::InsertElement: {
1532 // If this is a variable index, we don't know which element it overwrites.
1533 // demand exactly the same input as we produce.
1534 ConstantInt *Idx = dyn_cast<ConstantInt>(Val: I->getOperand(i: 2));
1535 if (!Idx) {
1536 // Note that we can't propagate undef elt info, because we don't know
1537 // which elt is getting updated.
1538 simplifyAndSetOp(I, 0, DemandedElts, PoisonElts2);
1539 break;
1540 }
1541
1542 // The element inserted overwrites whatever was there, so the input demanded
1543 // set is simpler than the output set.
1544 unsigned IdxNo = Idx->getZExtValue();
1545 APInt PreInsertDemandedElts = DemandedElts;
1546 if (IdxNo < VWidth)
1547 PreInsertDemandedElts.clearBit(BitPosition: IdxNo);
1548
1549 // If we only demand the element that is being inserted and that element
1550 // was extracted from the same index in another vector with the same type,
1551 // replace this insert with that other vector.
1552 // Note: This is attempted before the call to simplifyAndSetOp because that
1553 // may change PoisonElts to a value that does not match with Vec.
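    //
    // E.g. (sketch), with only lane 1 demanded:
    //   %e = extractelement <4 x i32> %vec, i64 1
    //   %i = insertelement <4 x i32> %x, i32 %e, i64 1
    // --> replace %i with %vec.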
    Value *Vec;
    if (PreInsertDemandedElts == 0 &&
        match(I->getOperand(1),
              m_ExtractElt(m_Value(Vec), m_SpecificInt(IdxNo))) &&
        Vec->getType() == I->getType()) {
      return Vec;
    }

    simplifyAndSetOp(I, 0, PreInsertDemandedElts, PoisonElts);

    // If this is inserting an element that isn't demanded, remove this
    // insertelement.
    if (IdxNo >= VWidth || !DemandedElts[IdxNo]) {
      Worklist.push(I);
      return I->getOperand(0);
    }

    // The inserted element is defined.
    PoisonElts.clearBit(IdxNo);
    break;
  }
  case Instruction::ShuffleVector: {
    auto *Shuffle = cast<ShuffleVectorInst>(I);
    assert(Shuffle->getOperand(0)->getType() ==
               Shuffle->getOperand(1)->getType() &&
           "Expected shuffle operands to have same type");
    unsigned OpWidth = cast<FixedVectorType>(Shuffle->getOperand(0)->getType())
                           ->getNumElements();
    // Handle trivial case of a splat. Only check the first element of LHS
    // operand.
    if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) &&
        DemandedElts.isAllOnes()) {
      if (!isa<PoisonValue>(I->getOperand(1))) {
        I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType()));
        MadeChange = true;
      }
      APInt LeftDemanded(OpWidth, 1);
      APInt LHSPoisonElts(OpWidth, 0);
      simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts);
      if (LHSPoisonElts[0])
        PoisonElts = EltMask;
      else
        PoisonElts.clearAllBits();
      break;
    }

    APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0);
    for (unsigned i = 0; i < VWidth; i++) {
      if (DemandedElts[i]) {
        unsigned MaskVal = Shuffle->getMaskValue(i);
        if (MaskVal != -1u) {
          assert(MaskVal < OpWidth * 2 &&
                 "shufflevector mask index out of range!");
          if (MaskVal < OpWidth)
            LeftDemanded.setBit(MaskVal);
          else
            RightDemanded.setBit(MaskVal - OpWidth);
        }
      }
    }

    APInt LHSPoisonElts(OpWidth, 0);
    simplifyAndSetOp(I, 0, LeftDemanded, LHSPoisonElts);

    APInt RHSPoisonElts(OpWidth, 0);
    simplifyAndSetOp(I, 1, RightDemanded, RHSPoisonElts);

    // If this shuffle does not change the vector length and the elements
    // demanded by this shuffle are an identity mask, then this shuffle is
    // unnecessary.
    //
    // We are assuming canonical form for the mask, so the source vector is
    // operand 0 and operand 1 is not used.
    //
    // Note that if an element is demanded and this shuffle mask is undefined
    // for that element, then the shuffle is not considered an identity
    // operation. The shuffle prevents poison from the operand vector from
    // leaking to the result by replacing poison with an undefined value.
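    //
    // E.g. (sketch) shuffling <4 x i32> %v with mask <0, poison, 2, 3> is an
    // identity on lanes {0, 2, 3}; if lane 1 is not demanded, the shuffle can
    // be replaced by %v.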
    if (VWidth == OpWidth) {
      bool IsIdentityShuffle = true;
      for (unsigned i = 0; i < VWidth; i++) {
        unsigned MaskVal = Shuffle->getMaskValue(i);
        if (DemandedElts[i] && i != MaskVal) {
          IsIdentityShuffle = false;
          break;
        }
      }
      if (IsIdentityShuffle)
        return Shuffle->getOperand(0);
    }

    bool NewPoisonElts = false;
    unsigned LHSIdx = -1u, LHSValIdx = -1u;
    unsigned RHSIdx = -1u, RHSValIdx = -1u;
    bool LHSUniform = true;
    bool RHSUniform = true;
    for (unsigned i = 0; i < VWidth; i++) {
      unsigned MaskVal = Shuffle->getMaskValue(i);
      if (MaskVal == -1u) {
        PoisonElts.setBit(i);
      } else if (!DemandedElts[i]) {
        NewPoisonElts = true;
        PoisonElts.setBit(i);
      } else if (MaskVal < OpWidth) {
        if (LHSPoisonElts[MaskVal]) {
          NewPoisonElts = true;
          PoisonElts.setBit(i);
        } else {
          LHSIdx = LHSIdx == -1u ? i : OpWidth;
          LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth;
          LHSUniform = LHSUniform && (MaskVal == i);
        }
      } else {
        if (RHSPoisonElts[MaskVal - OpWidth]) {
          NewPoisonElts = true;
          PoisonElts.setBit(i);
        } else {
          RHSIdx = RHSIdx == -1u ? i : OpWidth;
          RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth;
          RHSUniform = RHSUniform && (MaskVal - OpWidth == i);
        }
      }
    }

    // Try to transform a shuffle with a constant vector, where only a single
    // element of that constant vector is used, into a single insertelement
    // instruction:
    // shufflevector V, C, <v1, v2, .., ci, .., vm> ->
    // insertelement V, C[ci], ci-n
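    //
    // E.g. (sketch):
    //   shufflevector <4 x i32> %V,
    //                 <4 x i32> <i32 poison, i32 7, i32 poison, i32 poison>,
    //                 <4 x i32> <i32 0, i32 5, i32 2, i32 3>
    // selects C[1] == 7 into lane 1 and is otherwise an identity of %V, so
    // it becomes "insertelement <4 x i32> %V, i32 7, i64 1".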
    if (OpWidth ==
        cast<FixedVectorType>(Shuffle->getType())->getNumElements()) {
      Value *Op = nullptr;
      Constant *Value = nullptr;
      unsigned Idx = -1u;

      // Find the constant vector with the single used element in the shuffle
      // (LHS or RHS).
      if (LHSIdx < OpWidth && RHSUniform) {
        if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) {
          Op = Shuffle->getOperand(1);
          Value = CV->getOperand(LHSValIdx);
          Idx = LHSIdx;
        }
      }
      if (RHSIdx < OpWidth && LHSUniform) {
        if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) {
          Op = Shuffle->getOperand(0);
          Value = CV->getOperand(RHSValIdx);
          Idx = RHSIdx;
        }
      }
      // Found a constant vector with a single used element - convert to
      // insertelement.
      if (Op && Value) {
        Instruction *New = InsertElementInst::Create(
            Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx),
            Shuffle->getName());
        InsertNewInstWith(New, Shuffle->getIterator());
        return New;
      }
    }
    if (NewPoisonElts) {
      // Add additional discovered poison elements to the mask.
      SmallVector<int, 16> Elts;
      for (unsigned i = 0; i < VWidth; ++i) {
        if (PoisonElts[i])
          Elts.push_back(PoisonMaskElem);
        else
          Elts.push_back(Shuffle->getMaskValue(i));
      }
      Shuffle->setShuffleMask(Elts);
      MadeChange = true;
    }
    break;
  }
  case Instruction::Select: {
    // If this is a vector select, try to transform the select condition based
    // on the current demanded elements.
    SelectInst *Sel = cast<SelectInst>(I);
    if (Sel->getCondition()->getType()->isVectorTy()) {
      // TODO: We are not doing anything with PoisonElts based on this call.
      // It is overwritten below based on the other select operands. If an
      // element of the select condition is known undef, then we are free to
      // choose the output value from either arm of the select. If we know
      // that one of those values is undef, then the output can be undef.
      simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
    }

    // Next, see if we can transform the arms of the select.
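    // E.g. (sketch) with constant condition <i1 true, i1 false>, lane 0 can
    // only come from the true arm and lane 1 only from the false arm, so
    // each arm's other lane is not demanded.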
    APInt DemandedLHS(DemandedElts), DemandedRHS(DemandedElts);
    if (auto *CV = dyn_cast<ConstantVector>(Sel->getCondition())) {
      for (unsigned i = 0; i < VWidth; i++) {
        Constant *CElt = CV->getAggregateElement(i);

        // isNullValue() always returns false when called on a ConstantExpr.
        if (CElt->isNullValue())
          DemandedLHS.clearBit(i);
        else if (CElt->isOneValue())
          DemandedRHS.clearBit(i);
      }
    }

    simplifyAndSetOp(I, 1, DemandedLHS, PoisonElts2);
    simplifyAndSetOp(I, 2, DemandedRHS, PoisonElts3);

    // Output elements are undefined if the element from each arm is
    // undefined.
    // TODO: This can be improved. See comment in select condition handling.
    PoisonElts = PoisonElts2 & PoisonElts3;
    break;
  }
  case Instruction::BitCast: {
    // Vector->vector casts only.
    VectorType *VTy = dyn_cast<VectorType>(I->getOperand(0)->getType());
    if (!VTy) break;
    unsigned InVWidth = cast<FixedVectorType>(VTy)->getNumElements();
    APInt InputDemandedElts(InVWidth, 0);
    PoisonElts2 = APInt(InVWidth, 0);
    unsigned Ratio;

    if (VWidth == InVWidth) {
      // If we are converting from <4 x i32> -> <4 x f32>, we demand the same
      // elements as are demanded of us.
      Ratio = 1;
      InputDemandedElts = DemandedElts;
    } else if ((VWidth % InVWidth) == 0) {
      // If the number of elements in the output is a multiple of the number
      // of elements in the input then an input element is live if any of the
      // corresponding output elements are live.
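      // E.g. (sketch) for "bitcast <2 x i64> %x to <4 x i32>", Ratio is 2
      // and demanding output lanes {2, 3} demands input lane 1.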
      Ratio = VWidth / InVWidth;
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
        if (DemandedElts[OutIdx])
          InputDemandedElts.setBit(OutIdx / Ratio);
    } else if ((InVWidth % VWidth) == 0) {
      // If the number of elements in the input is a multiple of the number
      // of elements in the output then an input element is live if the
      // corresponding output element is live.
      Ratio = InVWidth / VWidth;
      for (unsigned InIdx = 0; InIdx != InVWidth; ++InIdx)
        if (DemandedElts[InIdx / Ratio])
          InputDemandedElts.setBit(InIdx);
    } else {
      // Unsupported so far.
      break;
    }

    simplifyAndSetOp(I, 0, InputDemandedElts, PoisonElts2);

    if (VWidth == InVWidth) {
      PoisonElts = PoisonElts2;
    } else if ((VWidth % InVWidth) == 0) {
      // If the number of elements in the output is a multiple of the number
      // of elements in the input then an output element is undef if the
      // corresponding input element is undef.
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx)
        if (PoisonElts2[OutIdx / Ratio])
          PoisonElts.setBit(OutIdx);
    } else if ((InVWidth % VWidth) == 0) {
      // If the number of elements in the input is a multiple of the number
      // of elements in the output then an output element is undef if all of
      // the corresponding input elements are undef.
      for (unsigned OutIdx = 0; OutIdx != VWidth; ++OutIdx) {
        APInt SubUndef = PoisonElts2.lshr(OutIdx * Ratio).zextOrTrunc(Ratio);
        if (SubUndef.popcount() == Ratio)
          PoisonElts.setBit(OutIdx);
      }
    } else {
      llvm_unreachable("Unimp");
    }
    break;
  }
  case Instruction::FPTrunc:
  case Instruction::FPExt:
    simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
    break;

  case Instruction::Call: {
    IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
    if (!II) break;
    switch (II->getIntrinsicID()) {
    case Intrinsic::masked_gather: // fallthrough
    case Intrinsic::masked_load: {
      // Subtlety: If we load from a pointer, the pointer must be valid
      // regardless of whether the element is demanded. Doing otherwise risks
      // segfaults which didn't exist in the original program.
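      //
      // E.g. (sketch) for a gather/load with constant mask <i1 1, i1 0>:
      // result lane 1 comes from the passthrough, so (for a gather) pointer
      // lane 1 is not demanded, while lane 0 comes from memory, so
      // passthrough lane 0 is not demanded.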
      APInt DemandedPtrs(APInt::getAllOnes(VWidth)),
          DemandedPassThrough(DemandedElts);
      if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2)))
        for (unsigned i = 0; i < VWidth; i++) {
          Constant *CElt = CV->getAggregateElement(i);
          if (CElt->isNullValue())
            DemandedPtrs.clearBit(i);
          else if (CElt->isAllOnesValue())
            DemandedPassThrough.clearBit(i);
        }
      if (II->getIntrinsicID() == Intrinsic::masked_gather)
        simplifyAndSetOp(II, 0, DemandedPtrs, PoisonElts2);
      simplifyAndSetOp(II, 3, DemandedPassThrough, PoisonElts3);

      // Output elements are undefined if the elements from both sources are.
      // TODO: can strengthen via mask as well.
      PoisonElts = PoisonElts2 & PoisonElts3;
      break;
    }
    default: {
      // Handle target specific intrinsics.
      std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
          *II, DemandedElts, PoisonElts, PoisonElts2, PoisonElts3,
          simplifyAndSetOp);
      if (V)
        return *V;
      break;
    }
    } // switch on IntrinsicID
    break;
  } // case Call
  } // switch on Opcode

  // TODO: We bail completely on integer div/rem and shifts because they have
  // UB/poison potential, but that should be refined.
  BinaryOperator *BO;
  if (match(I, m_BinOp(BO)) && !BO->isIntDivRem() && !BO->isShift()) {
    Value *X = BO->getOperand(0);
    Value *Y = BO->getOperand(1);

    // Look for an equivalent binop except that one operand has been shuffled.
    // If the demand for this binop only includes elements that are the same as
    // the other binop, then we may be able to replace this binop with a use of
    // the earlier one.
    //
    // Example:
    // %other_bo = bo (shuf X, {0}), Y
    // %this_extracted_bo = extelt (bo X, Y), 0
    // -->
    // %other_bo = bo (shuf X, {0}), Y
    // %this_extracted_bo = extelt %other_bo, 0
    //
    // TODO: Handle demand of an arbitrary single element or more than one
    // element instead of just element 0.
    // TODO: Unlike general demanded elements transforms, this should be safe
    // for any (div/rem/shift) opcode too.
    if (DemandedElts == 1 && !X->hasOneUse() && !Y->hasOneUse() &&
        BO->hasOneUse()) {

      auto findShufBO = [&](bool MatchShufAsOp0) -> User * {
        // Try to use shuffle-of-operand in place of an operand:
        // bo X, Y --> bo (shuf X), Y
        // bo X, Y --> bo X, (shuf Y)

        Value *OtherOp = MatchShufAsOp0 ? Y : X;
        if (!OtherOp->hasUseList())
          return nullptr;

        BinaryOperator::BinaryOps Opcode = BO->getOpcode();
        Value *ShufOp = MatchShufAsOp0 ? X : Y;

        for (User *U : OtherOp->users()) {
          ArrayRef<int> Mask;
          auto Shuf = m_Shuffle(m_Specific(ShufOp), m_Value(), m_Mask(Mask));
          if (BO->isCommutative()
                  ? match(U, m_c_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
                  : MatchShufAsOp0
                        ? match(U, m_BinOp(Opcode, Shuf, m_Specific(OtherOp)))
                        : match(U, m_BinOp(Opcode, m_Specific(OtherOp), Shuf)))
            if (match(Mask, m_ZeroMask()) && Mask[0] != PoisonMaskElem)
              if (DT.dominates(U, I))
                return U;
        }
        return nullptr;
      };

      if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ true))
        return ShufBO;
      if (User *ShufBO = findShufBO(/* MatchShufAsOp0 */ false))
        return ShufBO;
    }

    simplifyAndSetOp(I, 0, DemandedElts, PoisonElts);
    simplifyAndSetOp(I, 1, DemandedElts, PoisonElts2);

    // Output elements are undefined if both are undefined. Consider things
    // like undef & 0. The result is known zero, not undef.
    PoisonElts &= PoisonElts2;
  }

  // If we've proven all of the lanes poison, return a poison value.
  // TODO: Intersect w/demanded lanes
  if (PoisonElts.isAllOnes())
    return PoisonValue::get(I->getType());

  return MadeChange ? I : nullptr;
}

/// For floating-point classes that resolve to a single bit pattern, return
/// that value.
static Constant *getFPClassConstant(Type *Ty, FPClassTest Mask) {
  if (Mask == fcNone)
    return PoisonValue::get(Ty);

  if (Mask == fcPosZero)
    return Constant::getNullValue(Ty);

  // TODO: Support aggregate types that are allowed by FPMathOperator.
  if (Ty->isAggregateType())
    return nullptr;

  switch (Mask) {
  case fcNegZero:
    return ConstantFP::getZero(Ty, true);
  case fcPosInf:
    return ConstantFP::getInfinity(Ty);
  case fcNegInf:
    return ConstantFP::getInfinity(Ty, true);
  default:
    return nullptr;
  }
}

Value *InstCombinerImpl::SimplifyDemandedUseFPClass(Value *V,
                                                    FPClassTest DemandedMask,
                                                    KnownFPClass &Known,
                                                    Instruction *CxtI,
                                                    unsigned Depth) {
  assert(Depth <= MaxAnalysisRecursionDepth && "Limit Search Depth");
  Type *VTy = V->getType();

  assert(Known == KnownFPClass() && "expected uninitialized state");

  if (DemandedMask == fcNone)
    return isa<UndefValue>(V) ? nullptr : PoisonValue::get(VTy);

  if (Depth == MaxAnalysisRecursionDepth)
    return nullptr;

  Instruction *I = dyn_cast<Instruction>(V);
  if (!I) {
    // Handle constants and arguments.
    Known = computeKnownFPClass(V, fcAllFlags, CxtI, Depth + 1);
    Value *FoldedToConst =
        getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses);
    return FoldedToConst == V ? nullptr : FoldedToConst;
  }

  if (!I->hasOneUse())
    return nullptr;

  if (auto *FPOp = dyn_cast<FPMathOperator>(I)) {
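    // E.g. (sketch) an operation marked nnan yields poison whenever it would
    // produce NaN, so NaN behavior need not be preserved and fcNan can be
    // dropped from the demanded set; ninf and fcInf are analogous.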
    if (FPOp->hasNoNaNs())
      DemandedMask &= ~fcNan;
    if (FPOp->hasNoInfs())
      DemandedMask &= ~fcInf;
  }
  switch (I->getOpcode()) {
  case Instruction::FNeg: {
    if (SimplifyDemandedFPClass(I, 0, llvm::fneg(DemandedMask), Known,
                                Depth + 1))
      return I;
    Known.fneg();
    break;
  }
  case Instruction::Call: {
    CallInst *CI = cast<CallInst>(I);
    switch (CI->getIntrinsicID()) {
    case Intrinsic::fabs:
      if (SimplifyDemandedFPClass(I, 0, llvm::inverse_fabs(DemandedMask),
                                  Known, Depth + 1))
        return I;
      Known.fabs();
      break;
    case Intrinsic::arithmetic_fence:
      if (SimplifyDemandedFPClass(I, 0, DemandedMask, Known, Depth + 1))
        return I;
      break;
    case Intrinsic::copysign: {
      // Flip on more potentially demanded classes.
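      // E.g. (sketch) if only fcPosInf is demanded of the copysign result,
      // the magnitude operand may still arrive with either sign, so we
      // demand fcPosInf | fcNegInf from operand 0.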
      const FPClassTest DemandedMaskAnySign = llvm::unknown_sign(DemandedMask);
      if (SimplifyDemandedFPClass(I, 0, DemandedMaskAnySign, Known, Depth + 1))
        return I;

      if ((DemandedMask & fcNegative) == DemandedMask) {
        // Roundabout way of replacing with fneg(fabs).
        I->setOperand(1, ConstantFP::get(VTy, -1.0));
        return I;
      }

      if ((DemandedMask & fcPositive) == DemandedMask) {
        // Roundabout way of replacing with fabs.
        I->setOperand(1, ConstantFP::getZero(VTy));
        return I;
      }

      KnownFPClass KnownSign =
          computeKnownFPClass(I->getOperand(1), fcAllFlags, CxtI, Depth + 1);
      Known.copysign(KnownSign);
      break;
    }
    default:
      Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1);
      break;
    }

    break;
  }
  case Instruction::Select: {
    KnownFPClass KnownLHS, KnownRHS;
    if (SimplifyDemandedFPClass(I, 2, DemandedMask, KnownRHS, Depth + 1) ||
        SimplifyDemandedFPClass(I, 1, DemandedMask, KnownLHS, Depth + 1))
      return I;

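    // E.g. (sketch) if only fcNan is demanded and the true arm is known
    // never NaN, the result only matters when the false arm is chosen, so
    // the select can be replaced by the false arm.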
    if (KnownLHS.isKnownNever(DemandedMask))
      return I->getOperand(2);
    if (KnownRHS.isKnownNever(DemandedMask))
      return I->getOperand(1);

    // TODO: Recognize clamping patterns.
    Known = KnownLHS | KnownRHS;
    break;
  }
  default:
    Known = computeKnownFPClass(I, ~DemandedMask, CxtI, Depth + 1);
    break;
  }

  return getFPClassConstant(VTy, DemandedMask & Known.KnownFPClasses);
}

bool InstCombinerImpl::SimplifyDemandedFPClass(Instruction *I, unsigned OpNo,
                                               FPClassTest DemandedMask,
                                               KnownFPClass &Known,
                                               unsigned Depth) {
  Use &U = I->getOperandUse(OpNo);
  Value *NewVal =
      SimplifyDemandedUseFPClass(U.get(), DemandedMask, Known, I, Depth);
  if (!NewVal)
    return false;
  if (Instruction *OpInst = dyn_cast<Instruction>(U))
    salvageDebugInfo(*OpInst);

  replaceUse(U, NewVal);
  return true;
}