//===- InstCombineCasts.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the visit functions for cast operations.
//
//===----------------------------------------------------------------------===//

#include "InstCombineInternal.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "instcombine"

/// Given an expression that CanEvaluateTruncated or CanEvaluateSExtd returns
/// true for, actually insert the code to evaluate the expression.
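/// For example, evaluating (add i32 %a, %b) in type i16 produces
/// (add i16 (trunc %a), (trunc %b)), with the operands rewritten recursively.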
Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
                                                 bool isSigned) {
  if (Constant *C = dyn_cast<Constant>(V))
    return ConstantFoldIntegerCast(C, Ty, isSigned, DL);

  // Otherwise, it must be an instruction.
  Instruction *I = cast<Instruction>(V);
  Instruction *Res = nullptr;
  unsigned Opc = I->getOpcode();
  switch (Opc) {
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::AShr:
  case Instruction::LShr:
  case Instruction::Shl:
  case Instruction::UDiv:
  case Instruction::URem: {
    Value *LHS = EvaluateInDifferentType(I->getOperand(0), Ty, isSigned);
    Value *RHS = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
    Res = BinaryOperator::Create((Instruction::BinaryOps)Opc, LHS, RHS);
    break;
  }
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
    // If the source type of the cast is the type we're trying for then we can
    // just return the source. There's no need to insert it because it is not
    // new.
    if (I->getOperand(0)->getType() == Ty)
      return I->getOperand(0);

    // Otherwise, must be the same type of cast, so just reinsert a new one.
    // This also handles the case of zext(trunc(x)) -> zext(x).
    Res = CastInst::CreateIntegerCast(I->getOperand(0), Ty,
                                      Opc == Instruction::SExt);
    break;
  case Instruction::Select: {
    Value *True = EvaluateInDifferentType(I->getOperand(1), Ty, isSigned);
    Value *False = EvaluateInDifferentType(I->getOperand(2), Ty, isSigned);
    Res = SelectInst::Create(I->getOperand(0), True, False);
    break;
  }
  case Instruction::PHI: {
    PHINode *OPN = cast<PHINode>(I);
    PHINode *NPN = PHINode::Create(Ty, OPN->getNumIncomingValues());
    for (unsigned i = 0, e = OPN->getNumIncomingValues(); i != e; ++i) {
      Value *V =
          EvaluateInDifferentType(OPN->getIncomingValue(i), Ty, isSigned);
      NPN->addIncoming(V, OPN->getIncomingBlock(i));
    }
    Res = NPN;
    break;
  }
  case Instruction::FPToUI:
  case Instruction::FPToSI:
    Res = CastInst::Create(
        static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
    break;
  case Instruction::Call:
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      switch (II->getIntrinsicID()) {
      default:
        llvm_unreachable("Unsupported call!");
      case Intrinsic::vscale: {
        Function *Fn =
            Intrinsic::getDeclaration(I->getModule(), Intrinsic::vscale, {Ty});
        Res = CallInst::Create(Fn->getFunctionType(), Fn);
        break;
      }
      }
    }
    break;
  case Instruction::ShuffleVector: {
    auto *ScalarTy = cast<VectorType>(Ty)->getElementType();
    auto *VTy = cast<VectorType>(I->getOperand(0)->getType());
    auto *FixedTy = VectorType::get(ScalarTy, VTy->getElementCount());
    Value *Op0 = EvaluateInDifferentType(I->getOperand(0), FixedTy, isSigned);
    Value *Op1 = EvaluateInDifferentType(I->getOperand(1), FixedTy, isSigned);
    Res = new ShuffleVectorInst(Op0, Op1,
                                cast<ShuffleVectorInst>(I)->getShuffleMask());
    break;
  }
  default:
    // TODO: Can handle more cases here.
    llvm_unreachable("Unreachable!");
  }

  Res->takeName(I);
  return InsertNewInstWith(Res, I->getIterator());
}

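/// Determine whether the pair of casts (CI1 feeding CI2) can be eliminated or
/// collapsed into a single cast, and if so return the opcode of the resulting
/// cast; a result of zero means the pair must be kept as-is.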
Instruction::CastOps
InstCombinerImpl::isEliminableCastPair(const CastInst *CI1,
                                       const CastInst *CI2) {
  Type *SrcTy = CI1->getSrcTy();
  Type *MidTy = CI1->getDestTy();
  Type *DstTy = CI2->getDestTy();

  Instruction::CastOps firstOp = CI1->getOpcode();
  Instruction::CastOps secondOp = CI2->getOpcode();
  Type *SrcIntPtrTy =
      SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr;
  Type *MidIntPtrTy =
      MidTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(MidTy) : nullptr;
  Type *DstIntPtrTy =
      DstTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(DstTy) : nullptr;
  unsigned Res = CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy,
                                                DstTy, SrcIntPtrTy, MidIntPtrTy,
                                                DstIntPtrTy);

  // We don't want to form an inttoptr or ptrtoint that converts to an integer
  // type that differs from the pointer size.
  if ((Res == Instruction::IntToPtr && SrcTy != DstIntPtrTy) ||
      (Res == Instruction::PtrToInt && DstTy != SrcIntPtrTy))
    Res = 0;

  return Instruction::CastOps(Res);
}

/// Implement the transforms common to all CastInst visitors.
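/// These include constant-folding the cast, collapsing a cast of a cast, and
/// pushing the cast into a select, phi, or unary shuffle operand.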
Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) {
  Value *Src = CI.getOperand(0);
  Type *Ty = CI.getType();

  if (auto *SrcC = dyn_cast<Constant>(Src))
    if (Constant *Res = ConstantFoldCastOperand(CI.getOpcode(), SrcC, Ty, DL))
      return replaceInstUsesWith(CI, Res);

  // Try to eliminate a cast of a cast.
  if (auto *CSrc = dyn_cast<CastInst>(Src)) { // A->B->C cast
    if (Instruction::CastOps NewOpc = isEliminableCastPair(CSrc, &CI)) {
      // The first cast (CSrc) is eliminable so we need to fix up or replace
      // the second cast (CI). CSrc will then have a good chance of being dead.
      auto *Res = CastInst::Create(NewOpc, CSrc->getOperand(0), Ty);
      // Point debug users of the dying cast to the new one.
      if (CSrc->hasOneUse())
        replaceAllDbgUsesWith(*CSrc, *Res, CI, DT);
      return Res;
    }
  }

  if (auto *Sel = dyn_cast<SelectInst>(Src)) {
    // We are casting a select. Try to fold the cast into the select if the
    // select does not have a compare instruction with matching operand types
    // or the select is likely better done in a narrow type.
    // Creating a select with operands that are different sizes than its
    // condition may inhibit other folds and lead to worse codegen.
    auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition());
    if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() ||
        (CI.getOpcode() == Instruction::Trunc &&
         shouldChangeType(CI.getSrcTy(), CI.getType()))) {

      // If it's a bitcast involving vectors, make sure it has the same number
      // of elements on both sides.
      if (CI.getOpcode() != Instruction::BitCast ||
          match(&CI, m_ElementWiseBitCast(m_Value()))) {
        if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) {
          replaceAllDbgUsesWith(*Sel, *NV, CI, DT);
          return NV;
        }
      }
    }
  }

  // If we are casting a PHI, then fold the cast into the PHI.
  if (auto *PN = dyn_cast<PHINode>(Src)) {
    // Don't do this if it would create a PHI node with an illegal type from a
    // legal type.
    if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() ||
        shouldChangeType(CI.getSrcTy(), CI.getType()))
      if (Instruction *NV = foldOpIntoPhi(CI, PN))
        return NV;
  }

  // Canonicalize a unary shuffle after the cast if neither operation changes
  // the size or element size of the input vector.
  // TODO: We could allow size-changing ops if that doesn't harm codegen.
  // cast (shuffle X, Mask) --> shuffle (cast X), Mask
  Value *X;
  ArrayRef<int> Mask;
  if (match(Src, m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(Mask))))) {
    // TODO: Allow scalable vectors?
    auto *SrcTy = dyn_cast<FixedVectorType>(X->getType());
    auto *DestTy = dyn_cast<FixedVectorType>(Ty);
    if (SrcTy && DestTy &&
        SrcTy->getNumElements() == DestTy->getNumElements() &&
        SrcTy->getPrimitiveSizeInBits() == DestTy->getPrimitiveSizeInBits()) {
      Value *CastX = Builder.CreateCast(CI.getOpcode(), X, DestTy);
      return new ShuffleVectorInst(CastX, Mask);
    }
  }

  return nullptr;
}

/// Constants and extensions/truncates from the destination type are always
/// free to be evaluated in that type. This is a helper for canEvaluate*.
static bool canAlwaysEvaluateInType(Value *V, Type *Ty) {
  if (isa<Constant>(V))
    return match(V, m_ImmConstant());

  Value *X;
  if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) &&
      X->getType() == Ty)
    return true;

  return false;
}

/// Filter out values that we can not evaluate in the destination type for free.
/// This is a helper for canEvaluate*.
static bool canNotEvaluateInType(Value *V, Type *Ty) {
  if (!isa<Instruction>(V))
    return true;
  // We don't extend or shrink something that has multiple uses -- doing so
  // would require duplicating the instruction which isn't profitable.
  if (!V->hasOneUse())
    return true;

  return false;
}

/// Return true if we can evaluate the specified expression tree as type Ty
/// instead of its larger type, and arrive with the same value.
/// This is used by code that tries to eliminate truncates.
///
/// Ty will always be a type smaller than V. We should return true if trunc(V)
/// can be computed by computing V in the smaller type. If V is an instruction,
/// then trunc(inst(x,y)) can be computed as inst(trunc(x),trunc(y)), which only
/// makes sense if x and y can be efficiently truncated.
///
/// This function works on both vectors and scalars.
///
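/// For example, (trunc (and i32 %x, 15) to i8) can be computed as
/// (and i8 (trunc %x), 15), because 'and' narrows freely.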
static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
                                 Instruction *CxtI) {
  if (canAlwaysEvaluateInType(V, Ty))
    return true;
  if (canNotEvaluateInType(V, Ty))
    return false;

  auto *I = cast<Instruction>(V);
  Type *OrigTy = V->getType();
  switch (I->getOpcode()) {
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
    // These operators can all arbitrarily be extended or truncated.
    return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
           canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);

  case Instruction::UDiv:
  case Instruction::URem: {
    // UDiv and URem can be truncated if all the truncated bits are zero.
    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
    uint32_t BitWidth = Ty->getScalarSizeInBits();
    assert(BitWidth < OrigBitWidth && "Unexpected bitwidths!");
    APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
    // Do not preserve the original context instruction. Simplifying div/rem
    // based on later context may introduce a trap.
    if (IC.MaskedValueIsZero(I->getOperand(0), Mask, 0, I) &&
        IC.MaskedValueIsZero(I->getOperand(1), Mask, 0, I)) {
      return canEvaluateTruncated(I->getOperand(0), Ty, IC, I) &&
             canEvaluateTruncated(I->getOperand(1), Ty, IC, I);
    }
    break;
  }
  case Instruction::Shl: {
    // If we are truncating the result of this SHL, and if it's a shift of an
    // inrange amount, we can always perform a SHL in a smaller type.
    uint32_t BitWidth = Ty->getScalarSizeInBits();
    KnownBits AmtKnownBits =
        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
    if (AmtKnownBits.getMaxValue().ult(BitWidth))
      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
    break;
  }
  case Instruction::LShr: {
    // If this is a truncate of a logical shr, we can truncate it to a smaller
    // lshr iff we know that the bits we would otherwise be shifting in are
    // already zeros.
    // TODO: It is enough to check that the bits we would be shifting in are
    // zero - use AmtKnownBits.getMaxValue().
    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
    uint32_t BitWidth = Ty->getScalarSizeInBits();
    KnownBits AmtKnownBits =
        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
    APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
    if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
        IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) {
      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
    }
    break;
  }
  case Instruction::AShr: {
    // If this is a truncate of an arithmetic shr, we can truncate it to a
    // smaller ashr iff we know that all the bits from the sign bit of the
    // original type and the sign bit of the truncate type are similar.
    // TODO: It is enough to check that the bits we would be shifting in are
    // similar to sign bit of the truncate type.
    uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits();
    uint32_t BitWidth = Ty->getScalarSizeInBits();
    KnownBits AmtKnownBits =
        llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout());
    unsigned ShiftedBits = OrigBitWidth - BitWidth;
    if (AmtKnownBits.getMaxValue().ult(BitWidth) &&
        ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI))
      return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
             canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
    break;
  }
  case Instruction::Trunc:
    // trunc(trunc(x)) -> trunc(x)
    return true;
  case Instruction::ZExt:
  case Instruction::SExt:
    // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
    // trunc(ext(x)) -> trunc(x) if the source type is larger than the new dest
    return true;
  case Instruction::Select: {
    SelectInst *SI = cast<SelectInst>(I);
    return canEvaluateTruncated(SI->getTrueValue(), Ty, IC, CxtI) &&
           canEvaluateTruncated(SI->getFalseValue(), Ty, IC, CxtI);
  }
  case Instruction::PHI: {
    // We can change a phi if we can change all operands. Note that we never
    // get into trouble with cyclic PHIs here because we only consider
    // instructions with a single use.
    PHINode *PN = cast<PHINode>(I);
    for (Value *IncValue : PN->incoming_values())
      if (!canEvaluateTruncated(IncValue, Ty, IC, CxtI))
        return false;
    return true;
  }
  case Instruction::FPToUI:
  case Instruction::FPToSI: {
    // If the integer type can hold the max FP value, it is safe to cast
    // directly to that type. Otherwise, we may create poison via overflow
    // that did not exist in the original code.
    Type *InputTy = I->getOperand(0)->getType()->getScalarType();
    const fltSemantics &Semantics = InputTy->getFltSemantics();
    uint32_t MinBitWidth =
        APFloatBase::semanticsIntSizeInBits(Semantics,
                                            I->getOpcode() == Instruction::FPToSI);
    return Ty->getScalarSizeInBits() >= MinBitWidth;
  }
  case Instruction::ShuffleVector:
    return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) &&
           canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI);
  default:
    // TODO: Can handle more cases here.
    break;
  }

  return false;
}

/// Given a vector that is bitcast to an integer, optionally logically
/// right-shifted, and truncated, convert it to an extractelement.
/// Example (big endian):
///   trunc (lshr (bitcast <4 x i32> %X to i128), 32) to i32
///   --->
///   extractelement <4 x i32> %X, 1
static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
                                         InstCombinerImpl &IC) {
  Value *TruncOp = Trunc.getOperand(0);
  Type *DestType = Trunc.getType();
  if (!TruncOp->hasOneUse() || !isa<IntegerType>(DestType))
    return nullptr;

  Value *VecInput = nullptr;
  ConstantInt *ShiftVal = nullptr;
  if (!match(TruncOp, m_CombineOr(m_BitCast(m_Value(VecInput)),
                                  m_LShr(m_BitCast(m_Value(VecInput)),
                                         m_ConstantInt(ShiftVal)))) ||
      !isa<VectorType>(VecInput->getType()))
    return nullptr;

  VectorType *VecType = cast<VectorType>(VecInput->getType());
  unsigned VecWidth = VecType->getPrimitiveSizeInBits();
  unsigned DestWidth = DestType->getPrimitiveSizeInBits();
  unsigned ShiftAmount = ShiftVal ? ShiftVal->getZExtValue() : 0;

  if ((VecWidth % DestWidth != 0) || (ShiftAmount % DestWidth != 0))
    return nullptr;

  // If the element type of the vector doesn't match the result type,
  // bitcast it to a vector type that we can extract from.
  unsigned NumVecElts = VecWidth / DestWidth;
  if (VecType->getElementType() != DestType) {
    VecType = FixedVectorType::get(DestType, NumVecElts);
    VecInput = IC.Builder.CreateBitCast(VecInput, VecType, "bc");
  }

  unsigned Elt = ShiftAmount / DestWidth;
  if (IC.getDataLayout().isBigEndian())
    Elt = NumVecElts - 1 - Elt;

  return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
}

/// Funnel/Rotate left/right may occur in a wider type than necessary because of
/// type promotion rules. Try to narrow the inputs and convert to funnel shift.
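/// For example, a rotate of an i8 value that was promoted to i32:
///   trunc (or (shl (zext i8 %x to i32), %amt),
///             (lshr (zext i8 %x to i32), (sub 8, %amt))) to i8
/// can become a call to llvm.fshl.i8(%x, %x, trunc %amt).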
Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
  assert((isa<VectorType>(Trunc.getSrcTy()) ||
          shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
         "Don't narrow to an illegal scalar type");

  // Bail out on strange types. It is possible to handle some of these patterns
  // even with non-power-of-2 sizes, but it is not a likely scenario.
  Type *DestTy = Trunc.getType();
  unsigned NarrowWidth = DestTy->getScalarSizeInBits();
  unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
  if (!isPowerOf2_32(NarrowWidth))
    return nullptr;

  // First, find an or'd pair of opposite shifts:
  // trunc (or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1))
  BinaryOperator *Or0, *Or1;
  if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_BinOp(Or0), m_BinOp(Or1)))))
    return nullptr;

  Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
      !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
      Or0->getOpcode() == Or1->getOpcode())
    return nullptr;

  // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
  if (Or0->getOpcode() == BinaryOperator::LShr) {
    std::swap(Or0, Or1);
    std::swap(ShVal0, ShVal1);
    std::swap(ShAmt0, ShAmt1);
  }
  assert(Or0->getOpcode() == BinaryOperator::Shl &&
         Or1->getOpcode() == BinaryOperator::LShr &&
         "Illegal or(shift,shift) pair");

  // Match the shift amount operands for a funnel/rotate pattern. This always
  // matches a subtraction on the R operand.
  auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
    // The shift amounts may add up to the narrow bit width:
    // (shl ShVal0, L) | (lshr ShVal1, Width - L)
    // If this is a funnel shift (different operands are shifted), then the
    // shift amount can not over-shift (create poison) in the narrow type.
    unsigned MaxShiftAmountWidth = Log2_32(NarrowWidth);
    APInt HiBitMask = ~APInt::getLowBitsSet(WideWidth, MaxShiftAmountWidth);
    if (ShVal0 == ShVal1 || MaskedValueIsZero(L, HiBitMask))
      if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L)))))
        return L;

    // The following patterns currently only work for rotation patterns.
    // TODO: Add more general funnel-shift compatible patterns.
    if (ShVal0 != ShVal1)
      return nullptr;

    // The shift amount may be masked with negation:
    // (shl ShVal0, (X & (Width - 1))) | (lshr ShVal1, ((-X) & (Width - 1)))
    Value *X;
    unsigned Mask = Width - 1;
    if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
        match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
      return X;

    // Same as above, but the shift amount may be extended after masking:
    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
        match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
      return X;

    return nullptr;
  };

  Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, NarrowWidth);
  bool IsFshl = true; // Sub on LSHR.
  if (!ShAmt) {
    ShAmt = matchShiftAmount(ShAmt1, ShAmt0, NarrowWidth);
    IsFshl = false; // Sub on SHL.
  }
  if (!ShAmt)
    return nullptr;

  // The right-shifted value must have high zeros in the wide type (for example
  // from 'zext', 'and' or 'shift'). High bits of the left-shifted value are
  // truncated, so those do not matter.
  APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
  if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc))
    return nullptr;

  // Adjust the width of ShAmt for narrowed funnel shift operation:
  // - Zero-extend if ShAmt is narrower than the destination type.
  // - Truncate if ShAmt is wider, discarding non-significant high-order bits.
  // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal),
  // zext/trunc(ShAmt)).
  Value *NarrowShAmt = Builder.CreateZExtOrTrunc(ShAmt, DestTy);

  Value *X, *Y;
  X = Y = Builder.CreateTrunc(ShVal0, DestTy);
  if (ShVal0 != ShVal1)
    Y = Builder.CreateTrunc(ShVal1, DestTy);
  Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
  Function *F = Intrinsic::getDeclaration(Trunc.getModule(), IID, DestTy);
  return CallInst::Create(F, {X, Y, NarrowShAmt});
}

/// Try to narrow the width of math or bitwise logic instructions by pulling a
/// truncate ahead of binary operators.
Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) {
  Type *SrcTy = Trunc.getSrcTy();
  Type *DestTy = Trunc.getType();
  unsigned SrcWidth = SrcTy->getScalarSizeInBits();
  unsigned DestWidth = DestTy->getScalarSizeInBits();

  if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
    return nullptr;

  BinaryOperator *BinOp;
  if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(BinOp))))
    return nullptr;

  Value *BinOp0 = BinOp->getOperand(0);
  Value *BinOp1 = BinOp->getOperand(1);
  switch (BinOp->getOpcode()) {
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul: {
    Constant *C;
    if (match(BinOp0, m_Constant(C))) {
      // trunc (binop C, X) --> binop (trunc C', X)
      Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy);
      Value *TruncX = Builder.CreateTrunc(BinOp1, DestTy);
      return BinaryOperator::Create(BinOp->getOpcode(), NarrowC, TruncX);
    }
    if (match(BinOp1, m_Constant(C))) {
      // trunc (binop X, C) --> binop (trunc X, C')
      Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy);
      Value *TruncX = Builder.CreateTrunc(BinOp0, DestTy);
      return BinaryOperator::Create(BinOp->getOpcode(), TruncX, NarrowC);
    }
    Value *X;
    if (match(BinOp0, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) {
      // trunc (binop (ext X), Y) --> binop X, (trunc Y)
      Value *NarrowOp1 = Builder.CreateTrunc(BinOp1, DestTy);
      return BinaryOperator::Create(BinOp->getOpcode(), X, NarrowOp1);
    }
    if (match(BinOp1, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) {
      // trunc (binop Y, (ext X)) --> binop (trunc Y), X
      Value *NarrowOp0 = Builder.CreateTrunc(BinOp0, DestTy);
      return BinaryOperator::Create(BinOp->getOpcode(), NarrowOp0, X);
    }
    break;
  }
  case Instruction::LShr:
  case Instruction::AShr: {
    // trunc (*shr (trunc A), C) --> trunc(*shr A, C)
    Value *A;
    Constant *C;
    if (match(BinOp0, m_Trunc(m_Value(A))) && match(BinOp1, m_Constant(C))) {
      unsigned MaxShiftAmt = SrcWidth - DestWidth;
      // If the shift is small enough, all zero/sign bits created by the shift
      // are removed by the trunc.
      if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
                                      APInt(SrcWidth, MaxShiftAmt)))) {
        auto *OldShift = cast<Instruction>(Trunc.getOperand(0));
        bool IsExact = OldShift->isExact();
        if (Constant *ShAmt = ConstantFoldIntegerCast(C, A->getType(),
                                                      /*IsSigned*/ true, DL)) {
          ShAmt = Constant::mergeUndefsWith(ShAmt, C);
          Value *Shift =
              OldShift->getOpcode() == Instruction::AShr
                  ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact)
                  : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact);
          return CastInst::CreateTruncOrBitCast(Shift, DestTy);
        }
      }
    }
    break;
  }
  default: break;
  }

  if (Instruction *NarrowOr = narrowFunnelShift(Trunc))
    return NarrowOr;

  return nullptr;
}

/// Try to narrow the width of a splat shuffle. This could be generalized to any
/// shuffle with a constant operand, but we limit the transform to avoid
/// creating a shuffle type that targets may not be able to lower effectively.
static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
                                       InstCombiner::BuilderTy &Builder) {
  auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
  if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) &&
      all_equal(Shuf->getShuffleMask()) &&
      Shuf->getType() == Shuf->getOperand(0)->getType()) {
    // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask
    // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask
    Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType());
    return new ShuffleVectorInst(NarrowOp, Shuf->getShuffleMask());
  }

  return nullptr;
}

/// Try to narrow the width of an insert element. This could be generalized for
/// any vector constant, but we limit the transform to insertion into undef to
/// avoid potential backend problems from unsupported insertion widths. This
/// could also be extended to handle the case of inserting a scalar constant
/// into a vector variable.
static Instruction *shrinkInsertElt(CastInst &Trunc,
                                    InstCombiner::BuilderTy &Builder) {
  Instruction::CastOps Opcode = Trunc.getOpcode();
  assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
         "Unexpected instruction for shrinking");

  auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0));
  if (!InsElt || !InsElt->hasOneUse())
    return nullptr;

  Type *DestTy = Trunc.getType();
  Type *DestScalarTy = DestTy->getScalarType();
  Value *VecOp = InsElt->getOperand(0);
  Value *ScalarOp = InsElt->getOperand(1);
  Value *Index = InsElt->getOperand(2);

  if (match(VecOp, m_Undef())) {
    // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index
    // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index
    UndefValue *NarrowUndef = UndefValue::get(DestTy);
    Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy);
    return InsertElementInst::Create(NarrowUndef, NarrowOp, Index);
  }

  return nullptr;
}

Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
  if (Instruction *Result = commonCastTransforms(Trunc))
    return Result;

  Value *Src = Trunc.getOperand(0);
  Type *DestTy = Trunc.getType(), *SrcTy = Src->getType();
  unsigned DestWidth = DestTy->getScalarSizeInBits();
  unsigned SrcWidth = SrcTy->getScalarSizeInBits();

  // Attempt to truncate the entire input expression tree to the destination
  // type. Only do this if the dest type is a simple type, don't convert the
  // expression tree to something weird like i93 unless the source is also
  // strange.
  if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
      canEvaluateTruncated(Src, DestTy, *this, &Trunc)) {

    // If this cast is a truncate, evaluating in a different type always
    // eliminates the cast, so it is always a win.
    LLVM_DEBUG(
        dbgs() << "ICE: EvaluateInDifferentType converting expression type"
                  " to avoid cast: "
               << Trunc << '\n');
    Value *Res = EvaluateInDifferentType(Src, DestTy, false);
    assert(Res->getType() == DestTy);
    return replaceInstUsesWith(Trunc, Res);
  }

  // For integer types, check if we can shorten the entire input expression to
  // DestWidth * 2, which won't allow removing the truncate, but reducing the
  // width may enable further optimizations, e.g. allowing for larger
  // vectorization factors.
  if (auto *DestITy = dyn_cast<IntegerType>(DestTy)) {
    if (DestWidth * 2 < SrcWidth) {
      auto *NewDestTy = DestITy->getExtendedType();
      if (shouldChangeType(SrcTy, NewDestTy) &&
          canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) {
        LLVM_DEBUG(
            dbgs() << "ICE: EvaluateInDifferentType converting expression type"
                      " to reduce the width of operand of"
                   << Trunc << '\n');
        Value *Res = EvaluateInDifferentType(Src, NewDestTy, false);
        return new TruncInst(Res, DestTy);
      }
    }
  }

  // Test if the trunc is the user of a select which is part of a
  // minimum or maximum operation. If so, don't do any more simplification.
  // Even simplifying demanded bits can break the canonical form of a
  // min/max.
  Value *LHS, *RHS;
  if (SelectInst *Sel = dyn_cast<SelectInst>(Src))
    if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN)
      return nullptr;

  // See if we can simplify any instructions used by the input whose sole
  // purpose is to compute bits we don't care about.
  if (SimplifyDemandedInstructionBits(Trunc))
    return &Trunc;

  if (DestWidth == 1) {
    Value *Zero = Constant::getNullValue(SrcTy);

    Value *X;
    const APInt *C1;
    Constant *C2;
    if (match(Src, m_OneUse(m_Shr(m_Shl(m_Power2(C1), m_Value(X)),
                                  m_ImmConstant(C2))))) {
      // trunc ((C1 << X) >> C2) to i1 --> X == (C2-cttz(C1)), where C1 is pow2
      Constant *Log2C1 = ConstantInt::get(SrcTy, C1->exactLogBase2());
      Constant *CmpC = ConstantExpr::getSub(C2, Log2C1);
      return new ICmpInst(ICmpInst::ICMP_EQ, X, CmpC);
    }

    Constant *C;
    if (match(Src, m_OneUse(m_LShr(m_Value(X), m_ImmConstant(C))))) {
      // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0
      Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1));
      Value *MaskC = Builder.CreateShl(One, C);
      Value *And = Builder.CreateAnd(X, MaskC);
      return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
    }
    if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_ImmConstant(C)),
                                   m_Deferred(X))))) {
      // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0
      Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1));
      Value *MaskC = Builder.CreateShl(One, C);
      Value *And = Builder.CreateAnd(X, Builder.CreateOr(MaskC, One));
      return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
    }

    {
      const APInt *C;
      if (match(Src, m_Shl(m_APInt(C), m_Value(X))) && (*C)[0] == 1) {
        // trunc (C << X) to i1 --> X == 0, where C is odd
        return new ICmpInst(ICmpInst::Predicate::ICMP_EQ, X, Zero);
      }
    }

    if (Trunc.hasNoUnsignedWrap() || Trunc.hasNoSignedWrap()) {
      Value *X, *Y;
      if (match(Src, m_Xor(m_Value(X), m_Value(Y))))
        return new ICmpInst(ICmpInst::ICMP_NE, X, Y);
    }
  }

  Value *A, *B;
  Constant *C;
  if (match(Src, m_LShr(m_SExt(m_Value(A)), m_Constant(C)))) {
    unsigned AWidth = A->getType()->getScalarSizeInBits();
    unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth);
    auto *OldSh = cast<Instruction>(Src);
    bool IsExact = OldSh->isExact();

    // If the shift is small enough, all zero bits created by the shift are
    // removed by the trunc.
    if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULE,
                                    APInt(SrcWidth, MaxShiftAmt)))) {
      auto GetNewShAmt = [&](unsigned Width) {
        Constant *MaxAmt = ConstantInt::get(SrcTy, Width - 1, false);
        Constant *Cmp =
            ConstantFoldCompareInstOperands(ICmpInst::ICMP_ULT, C, MaxAmt, DL);
        Constant *ShAmt = ConstantFoldSelectInstruction(Cmp, C, MaxAmt);
        return ConstantFoldCastOperand(Instruction::Trunc, ShAmt, A->getType(),
                                       DL);
      };

      // trunc (lshr (sext A), C) --> ashr A, C
      if (A->getType() == DestTy) {
        Constant *ShAmt = GetNewShAmt(DestWidth);
        ShAmt = Constant::mergeUndefsWith(ShAmt, C);
        return IsExact ? BinaryOperator::CreateExactAShr(A, ShAmt)
                       : BinaryOperator::CreateAShr(A, ShAmt);
      }
      // The types are mismatched, so create a cast after shifting:
      // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C)
      if (Src->hasOneUse()) {
        Constant *ShAmt = GetNewShAmt(AWidth);
        Value *Shift = Builder.CreateAShr(A, ShAmt, "", IsExact);
        return CastInst::CreateIntegerCast(Shift, DestTy, true);
      }
    }
    // TODO: Mask high bits with 'and'.
  }

  if (Instruction *I = narrowBinOp(Trunc))
    return I;

  if (Instruction *I = shrinkSplatShuffle(Trunc, Builder))
    return I;

  if (Instruction *I = shrinkInsertElt(Trunc, Builder))
    return I;

  if (Src->hasOneUse() &&
      (isa<VectorType>(SrcTy) || shouldChangeType(SrcTy, DestTy))) {
    // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
    // dest type is native and cst < dest size.
    if (match(Src, m_Shl(m_Value(A), m_Constant(C))) &&
        !match(A, m_Shr(m_Value(), m_Constant()))) {
      // Skip shifts of shift by constants. It undoes a combine in
      // FoldShiftByConstant and is the extend in reg pattern.
      APInt Threshold = APInt(C->getType()->getScalarSizeInBits(), DestWidth);
      if (match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold))) {
        Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr");
        return BinaryOperator::Create(Instruction::Shl, NewTrunc,
                                      ConstantExpr::getTrunc(C, DestTy));
      }
    }
  }

  if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
    return I;

  // Whenever an element is extracted from a vector, and then truncated,
  // canonicalize by converting it to a bitcast followed by an
  // extractelement.
  //
  // Example (little endian):
  //   trunc (extractelement <4 x i64> %X, 0) to i32
  //   --->
  //   extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
  Value *VecOp;
  ConstantInt *Cst;
  if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
    auto *VecOpTy = cast<VectorType>(VecOp->getType());
    auto VecElts = VecOpTy->getElementCount();

    // A badly fit destination size would result in an invalid cast.
    if (SrcWidth % DestWidth == 0) {
      uint64_t TruncRatio = SrcWidth / DestWidth;
      uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
      uint64_t VecOpIdx = Cst->getZExtValue();
      uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
                                         : VecOpIdx * TruncRatio;
      assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
             "overflow 32-bits");

      auto *BitCastTo =
          VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
      Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
      return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
    }
  }

  // trunc (ctlz_i32(zext(A), B) --> add(ctlz_i16(A, B), C)
  if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)),
                                                       m_Value(B))))) {
    unsigned AWidth = A->getType()->getScalarSizeInBits();
    if (AWidth == DestWidth && AWidth > Log2_32(SrcWidth)) {
      Value *WidthDiff = ConstantInt::get(A->getType(), SrcWidth - AWidth);
      Value *NarrowCtlz =
          Builder.CreateIntrinsic(Intrinsic::ctlz, {Trunc.getType()}, {A, B});
      return BinaryOperator::CreateAdd(NarrowCtlz, WidthDiff);
    }
  }

  if (match(Src, m_VScale())) {
    if (Trunc.getFunction() &&
        Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
      Attribute Attr =
          Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
      if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
        if (Log2_32(*MaxVScale) < DestWidth) {
          Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
          return replaceInstUsesWith(Trunc, VScale);
        }
      }
    }
  }

  bool Changed = false;
  if (!Trunc.hasNoSignedWrap() &&
      ComputeMaxSignificantBits(Src, /*Depth=*/0, &Trunc) <= DestWidth) {
    Trunc.setHasNoSignedWrap(true);
    Changed = true;
  }
  if (!Trunc.hasNoUnsignedWrap() &&
      MaskedValueIsZero(Src, APInt::getBitsSetFrom(SrcWidth, DestWidth),
                        /*Depth=*/0, &Trunc)) {
    Trunc.setHasNoUnsignedWrap(true);
    Changed = true;
  }

  return Changed ? &Trunc : nullptr;
}

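/// Fold (zext (icmp ...)) into a shift/mask sequence that computes the compare
/// result directly from the operand bits, when a single bit decides the
/// comparison.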
Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp,
                                                 ZExtInst &Zext) {
  // If we are just checking for a icmp eq of a single bit and zext'ing it
  // to an integer, then shift the bit to the appropriate place and then
  // cast to integer to avoid the comparison.

  // FIXME: This set of transforms does not check for extra uses and/or creates
  // an extra instruction (an optional final cast is not included
  // in the transform comments). We may also want to favor icmp over
  // shifts in cases of equal instructions because icmp has better
  // analysis in general (invert the transform).

  const APInt *Op1CV;
  if (match(Cmp->getOperand(1), m_APInt(Op1CV))) {

    // zext (x <s 0) to i32 --> x>>u31, true if signbit set.
    if (Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) {
      Value *In = Cmp->getOperand(0);
      Value *Sh = ConstantInt::get(In->getType(),
                                   In->getType()->getScalarSizeInBits() - 1);
      In = Builder.CreateLShr(In, Sh, In->getName() + ".lobit");
      if (In->getType() != Zext.getType())
        In = Builder.CreateIntCast(In, Zext.getType(), false /*ZExt*/);

      return replaceInstUsesWith(Zext, In);
    }

    // zext (X == 0) to i32 --> X^1      iff X has only the low bit set.
    // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
    // zext (X != 0) to i32 --> X        iff X has only the low bit set.
    // zext (X != 0) to i32 --> X>>1     iff X has only the 2nd bit set.

    if (Op1CV->isZero() && Cmp->isEquality()) {
      // Exactly 1 possible 1? But not the high-bit because that is
      // canonicalized to this form.
      KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
      APInt KnownZeroMask(~Known.Zero);
      uint32_t ShAmt = KnownZeroMask.logBase2();
      bool IsExpectShAmt = KnownZeroMask.isPowerOf2() &&
                           (Zext.getType()->getScalarSizeInBits() != ShAmt + 1);
      if (IsExpectShAmt &&
          (Cmp->getOperand(0)->getType() == Zext.getType() ||
           Cmp->getPredicate() == ICmpInst::ICMP_NE || ShAmt == 0)) {
        Value *In = Cmp->getOperand(0);
        if (ShAmt) {
          // Perform a logical shr by shiftamt.
          // Insert the shift to put the result in the low bit.
          In = Builder.CreateLShr(In, ConstantInt::get(In->getType(), ShAmt),
                                  In->getName() + ".lobit");
        }

        // Toggle the low bit for "X == 0".
        if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
          In = Builder.CreateXor(In, ConstantInt::get(In->getType(), 1));

        if (Zext.getType() == In->getType())
          return replaceInstUsesWith(Zext, In);

        Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false);
        return replaceInstUsesWith(Zext, IntCast);
      }
    }
  }

  if (Cmp->isEquality() && Zext.getType() == Cmp->getOperand(0)->getType()) {
    // Test if a bit is clear/set using a shifted-one mask:
    // zext (icmp eq (and X, (1 << ShAmt)), 0) --> and (lshr (not X), ShAmt), 1
    // zext (icmp ne (and X, (1 << ShAmt)), 0) --> and (lshr X, ShAmt), 1
    Value *X, *ShAmt;
    if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) &&
        match(Cmp->getOperand(0),
              m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) {
      if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
        X = Builder.CreateNot(X);
      Value *Lshr = Builder.CreateLShr(X, ShAmt);
      Value *And1 = Builder.CreateAnd(Lshr, ConstantInt::get(X->getType(), 1));
      return replaceInstUsesWith(Zext, And1);
    }
  }

  return nullptr;
}

/// Determine if the specified value can be computed in the specified wider type
/// and produce the same low bits. If not, return false.
///
/// If this function returns true, it can also return a non-zero number of bits
/// (in BitsToClear) which indicates that the value it computes is correct for
/// the zero extend, but that the additional BitsToClear bits need to be zero'd
/// out. For example, to promote something like:
///
///   %B = trunc i64 %A to i32
///   %C = lshr i32 %B, 8
///   %E = zext i32 %C to i64
///
/// CanEvaluateZExtd for the 'lshr' will return true, and BitsToClear will be
/// set to 8 to indicate that the promoted value needs to have bits 24-31
/// cleared in addition to bits 32-63. Since an 'and' will be generated to
/// clear the top bits anyway, doing this has no extra cost.
///
/// This function works on both vectors and scalars.
static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
                             InstCombinerImpl &IC, Instruction *CxtI) {
  BitsToClear = 0;
  if (canAlwaysEvaluateInType(V, Ty))
    return true;
  if (canNotEvaluateInType(V, Ty))
    return false;

  auto *I = cast<Instruction>(V);
  unsigned Tmp;
  switch (I->getOpcode()) {
  case Instruction::ZExt:  // zext(zext(x)) -> zext(x).
  case Instruction::SExt:  // zext(sext(x)) -> sext(x).
  case Instruction::Trunc: // zext(trunc(x)) -> trunc(x) or zext(x)
    return true;
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
    if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI) ||
        !canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI))
      return false;
    // These can all be promoted if neither operand has 'bits to clear'.
    if (BitsToClear == 0 && Tmp == 0)
      return true;

    // If the operation is an AND/OR/XOR and the bits to clear are zero in the
    // other side, BitsToClear is ok.
    if (Tmp == 0 && I->isBitwiseLogicOp()) {
      // We use MaskedValueIsZero here for generality, but the case we care
      // about the most is constant RHS.
      unsigned VSize = V->getType()->getScalarSizeInBits();
      if (IC.MaskedValueIsZero(I->getOperand(1),
                               APInt::getHighBitsSet(VSize, BitsToClear),
                               0, CxtI)) {
        // If this is an And instruction and all of the BitsToClear are
        // known to be zero we can reset BitsToClear.
        if (I->getOpcode() == Instruction::And)
          BitsToClear = 0;
        return true;
      }
    }

    // Otherwise, we don't know how to analyze this BitsToClear case yet.
    return false;

  case Instruction::Shl: {
    // We can promote shl(x, cst) if we can promote x. Since shl overwrites the
    // upper bits we can reduce BitsToClear by the shift amount.
    const APInt *Amt;
    if (match(I->getOperand(1), m_APInt(Amt))) {
      if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
        return false;
      uint64_t ShiftAmt = Amt->getZExtValue();
      BitsToClear = ShiftAmt < BitsToClear ? BitsToClear - ShiftAmt : 0;
      return true;
    }
    return false;
  }
  case Instruction::LShr: {
    // We can promote lshr(x, cst) if we can promote x. This requires the
    // ultimate 'and' to clear out the high zero bits we're clearing out though.
    const APInt *Amt;
    if (match(I->getOperand(1), m_APInt(Amt))) {
      if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI))
        return false;
      BitsToClear += Amt->getZExtValue();
      if (BitsToClear > V->getType()->getScalarSizeInBits())
        BitsToClear = V->getType()->getScalarSizeInBits();
      return true;
    }
    // Cannot promote variable LSHR.
    return false;
  }
  case Instruction::Select:
    if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) ||
        !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) ||
        // TODO: If important, we could handle the case when the BitsToClear are
        // known zero in the disagreeing side.
        Tmp != BitsToClear)
      return false;
    return true;

  case Instruction::PHI: {
    // We can change a phi if we can change all operands. Note that we never
    // get into trouble with cyclic PHIs here because we only consider
    // instructions with a single use.
    PHINode *PN = cast<PHINode>(I);
    if (!canEvaluateZExtd(PN->getIncomingValue(0), Ty, BitsToClear, IC, CxtI))
      return false;
    for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
      if (!canEvaluateZExtd(PN->getIncomingValue(i), Ty, Tmp, IC, CxtI) ||
          // TODO: If important, we could handle the case when the BitsToClear
          // are known zero in the disagreeing input.
          Tmp != BitsToClear)
        return false;
    return true;
  }
  case Instruction::Call:
    // llvm.vscale() can always be executed in larger type, because the
    // value is automatically zero-extended.
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I))
      if (II->getIntrinsicID() == Intrinsic::vscale)
        return true;
    return false;
  default:
    // TODO: Can handle more cases here.
    return false;
  }
}
1137 | |
1138 | Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { |
1139 | // If this zero extend is only used by a truncate, let the truncate be |
1140 | // eliminated before we try to optimize this zext. |
1141 | if (Zext.hasOneUse() && isa<TruncInst>(Val: Zext.user_back()) && |
1142 | !isa<Constant>(Val: Zext.getOperand(i_nocapture: 0))) |
1143 | return nullptr; |
1144 | |
1145 | // If one of the common conversion will work, do it. |
1146 | if (Instruction *Result = commonCastTransforms(CI&: Zext)) |
1147 | return Result; |
1148 | |
1149 | Value *Src = Zext.getOperand(i_nocapture: 0); |
1150 | Type *SrcTy = Src->getType(), *DestTy = Zext.getType(); |
1151 | |
1152 | // zext nneg bool x -> 0 |
1153 | if (SrcTy->isIntOrIntVectorTy(BitWidth: 1) && Zext.hasNonNeg()) |
1154 | return replaceInstUsesWith(I&: Zext, V: Constant::getNullValue(Ty: Zext.getType())); |
1155 | |
1156 | // Try to extend the entire expression tree to the wide destination type. |
1157 | unsigned BitsToClear; |
1158 | if (shouldChangeType(From: SrcTy, To: DestTy) && |
1159 | canEvaluateZExtd(V: Src, Ty: DestTy, BitsToClear, IC&: *this, CxtI: &Zext)) { |
1160 | assert(BitsToClear <= SrcTy->getScalarSizeInBits() && |
1161 | "Can't clear more bits than in SrcTy" ); |
1162 | |
1163 | // Okay, we can transform this! Insert the new expression now. |
1164 | LLVM_DEBUG( |
1165 | dbgs() << "ICE: EvaluateInDifferentType converting expression type" |
1166 | " to avoid zero extend: " |
1167 | << Zext << '\n'); |
1168 | Value *Res = EvaluateInDifferentType(V: Src, Ty: DestTy, isSigned: false); |
1169 | assert(Res->getType() == DestTy); |
1170 | |
1171 | // Preserve debug values referring to Src if the zext is its last use. |
1172 | if (auto *SrcOp = dyn_cast<Instruction>(Val: Src)) |
1173 | if (SrcOp->hasOneUse()) |
1174 | replaceAllDbgUsesWith(From&: *SrcOp, To&: *Res, DomPoint&: Zext, DT); |
1175 | |
1176 | uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits() - BitsToClear; |
1177 | uint32_t DestBitSize = DestTy->getScalarSizeInBits(); |
1178 | |
1179 | // If the high bits are already filled with zeros, just replace this |
1180 | // cast with the result. |
1181 | if (MaskedValueIsZero(V: Res, |
1182 | Mask: APInt::getHighBitsSet(numBits: DestBitSize, |
1183 | hiBitsSet: DestBitSize - SrcBitsKept), |
1184 | Depth: 0, CxtI: &Zext)) |
1185 | return replaceInstUsesWith(I&: Zext, V: Res); |
1186 | |
1187 | // We need to emit an AND to clear the high bits. |
1188 | Constant *C = ConstantInt::get(Ty: Res->getType(), |
1189 | V: APInt::getLowBitsSet(numBits: DestBitSize, loBitsSet: SrcBitsKept)); |
1190 | return BinaryOperator::CreateAnd(V1: Res, V2: C); |
1191 | } |
1192 | |
1193 | // If this is a TRUNC followed by a ZEXT then we are dealing with integral |
1194 | // types and if the sizes are just right we can convert this into a logical |
1195 | // 'and' which will be much cheaper than the pair of casts. |
1196 | if (auto *CSrc = dyn_cast<TruncInst>(Val: Src)) { // A->B->C cast |
1197 | // TODO: Subsume this into EvaluateInDifferentType. |
1198 | |
1199 | // Get the sizes of the types involved. We know that the intermediate type |
1200 | // will be smaller than A or C, but don't know the relation between A and C. |
1201 | Value *A = CSrc->getOperand(i_nocapture: 0); |
1202 | unsigned SrcSize = A->getType()->getScalarSizeInBits(); |
1203 | unsigned MidSize = CSrc->getType()->getScalarSizeInBits(); |
1204 | unsigned DstSize = DestTy->getScalarSizeInBits(); |
1205 | // If we're actually extending zero bits, then if |
1206 | // SrcSize < DstSize: zext(a & mask) |
1207 | // SrcSize == DstSize: a & mask |
1208 | // SrcSize > DstSize: trunc(a) & mask |
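     | // For example (illustrative), for the SrcSize == DstSize case with A == i32
     | // and an i8 intermediate type:
     | //   %t = trunc i32 %a to i8
     | //   %z = zext i8 %t to i32
     | // becomes
     | //   %z = and i32 %a, 255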
1209 | if (SrcSize < DstSize) { |
1210 | APInt AndValue(APInt::getLowBitsSet(numBits: SrcSize, loBitsSet: MidSize)); |
1211 | Constant *AndConst = ConstantInt::get(Ty: A->getType(), V: AndValue); |
1212 | Value *And = Builder.CreateAnd(LHS: A, RHS: AndConst, Name: CSrc->getName() + ".mask" ); |
1213 | return new ZExtInst(And, DestTy); |
1214 | } |
1215 | |
1216 | if (SrcSize == DstSize) { |
1217 | APInt AndValue(APInt::getLowBitsSet(numBits: SrcSize, loBitsSet: MidSize)); |
1218 | return BinaryOperator::CreateAnd(V1: A, V2: ConstantInt::get(Ty: A->getType(), |
1219 | V: AndValue)); |
1220 | } |
1221 | if (SrcSize > DstSize) { |
1222 | Value *Trunc = Builder.CreateTrunc(V: A, DestTy); |
1223 | APInt AndValue(APInt::getLowBitsSet(numBits: DstSize, loBitsSet: MidSize)); |
1224 | return BinaryOperator::CreateAnd(V1: Trunc, |
1225 | V2: ConstantInt::get(Ty: Trunc->getType(), |
1226 | V: AndValue)); |
1227 | } |
1228 | } |
1229 | |
1230 | if (auto *Cmp = dyn_cast<ICmpInst>(Val: Src)) |
1231 | return transformZExtICmp(Cmp, Zext); |
1232 | |
1233 | // zext(trunc(X) & C) -> (X & zext(C)). |
1234 | Constant *C; |
1235 | Value *X; |
1236 | if (match(V: Src, P: m_OneUse(SubPattern: m_And(L: m_Trunc(Op: m_Value(V&: X)), R: m_Constant(C)))) && |
1237 | X->getType() == DestTy) |
1238 | return BinaryOperator::CreateAnd(V1: X, V2: Builder.CreateZExt(V: C, DestTy)); |
1239 | |
1240 | // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). |
1241 | Value *And; |
1242 | if (match(V: Src, P: m_OneUse(SubPattern: m_Xor(L: m_Value(V&: And), R: m_Constant(C)))) && |
1243 | match(V: And, P: m_OneUse(SubPattern: m_And(L: m_Trunc(Op: m_Value(V&: X)), R: m_Specific(V: C)))) && |
1244 | X->getType() == DestTy) { |
1245 | Value *ZC = Builder.CreateZExt(V: C, DestTy); |
1246 | return BinaryOperator::CreateXor(V1: Builder.CreateAnd(LHS: X, RHS: ZC), V2: ZC); |
1247 | } |
1248 | |
1249 | // If we are truncating, masking, and then zexting back to the original type, |
1250 | // that's just a mask. This is not handled by canEvaluateZExtd if the
1251 | // intermediate values have extra uses. This could be generalized further for |
1252 | // a non-constant mask operand. |
1253 | // zext (and (trunc X), C) --> and X, (zext C) |
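     | // For example (illustrative), with X and the destination both i64:
     | //   %t = trunc i64 %x to i32
     | //   %a = and i32 %t, 42
     | //   %z = zext i32 %a to i64
     | // becomes
     | //   %z = and i64 %x, 42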
1254 | if (match(V: Src, P: m_And(L: m_Trunc(Op: m_Value(V&: X)), R: m_Constant(C))) && |
1255 | X->getType() == DestTy) { |
1256 | Value *ZextC = Builder.CreateZExt(V: C, DestTy); |
1257 | return BinaryOperator::CreateAnd(V1: X, V2: ZextC); |
1258 | } |
1259 | |
1260 | if (match(V: Src, P: m_VScale())) { |
1261 | if (Zext.getFunction() && |
1262 | Zext.getFunction()->hasFnAttribute(Kind: Attribute::VScaleRange)) { |
1263 | Attribute Attr = |
1264 | Zext.getFunction()->getFnAttribute(Kind: Attribute::VScaleRange); |
1265 | if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { |
1266 | unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); |
1267 | if (Log2_32(Value: *MaxVScale) < TypeWidth) { |
1268 | Value *VScale = Builder.CreateVScale(Scaling: ConstantInt::get(Ty: DestTy, V: 1)); |
1269 | return replaceInstUsesWith(I&: Zext, V: VScale); |
1270 | } |
1271 | } |
1272 | } |
1273 | } |
1274 | |
1275 | if (!Zext.hasNonNeg()) { |
1276 | // If this zero extend is only used by a shift, add nneg flag. |
1277 | if (Zext.hasOneUse() && |
1278 | SrcTy->getScalarSizeInBits() > |
1279 | Log2_64_Ceil(Value: DestTy->getScalarSizeInBits()) && |
1280 | match(V: Zext.user_back(), P: m_Shift(L: m_Value(), R: m_Specific(V: &Zext)))) { |
1281 | Zext.setNonNeg(); |
1282 | return &Zext; |
1283 | } |
1284 | |
1285 | if (isKnownNonNegative(V: Src, SQ: SQ.getWithInstruction(I: &Zext))) { |
1286 | Zext.setNonNeg(); |
1287 | return &Zext; |
1288 | } |
1289 | } |
1290 | |
1291 | return nullptr; |
1292 | } |
1293 | |
1294 | /// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp. |
1295 | Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, |
1296 | SExtInst &Sext) { |
1297 | Value *Op0 = Cmp->getOperand(i_nocapture: 0), *Op1 = Cmp->getOperand(i_nocapture: 1); |
1298 | ICmpInst::Predicate Pred = Cmp->getPredicate(); |
1299 | |
1300 | // Don't bother if Op1 isn't of vector or integer type. |
1301 | if (!Op1->getType()->isIntOrIntVectorTy()) |
1302 | return nullptr; |
1303 | |
1304 | if (Pred == ICmpInst::ICMP_SLT && match(V: Op1, P: m_ZeroInt())) { |
1305 | // sext (x <s 0) --> ashr x, 31 (all ones if negative) |
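     | // For example (illustrative):
     | //   %s = sext (icmp slt i32 %x, 0) to i64
     | // becomes
     | //   %lobit = ashr i32 %x, 31
     | //   %s = sext i32 %lobit to i64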
1306 | Value *Sh = ConstantInt::get(Ty: Op0->getType(), |
1307 | V: Op0->getType()->getScalarSizeInBits() - 1); |
1308 | Value *In = Builder.CreateAShr(LHS: Op0, RHS: Sh, Name: Op0->getName() + ".lobit" ); |
1309 | if (In->getType() != Sext.getType()) |
1310 | In = Builder.CreateIntCast(V: In, DestTy: Sext.getType(), isSigned: true /*SExt*/); |
1311 | |
1312 | return replaceInstUsesWith(I&: Sext, V: In); |
1313 | } |
1314 | |
1315 | if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Val: Op1)) { |
1316 | // If we know that only one bit of the LHS of the icmp can be set and we |
1317 | // have an equality comparison with zero or a power of 2, we can transform |
1318 | // the icmp and sext into bitwise/integer operations. |
1319 | if (Cmp->hasOneUse() && |
1320 | Cmp->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ |
1321 | KnownBits Known = computeKnownBits(V: Op0, Depth: 0, CxtI: &Sext); |
1322 | |
1323 | APInt KnownZeroMask(~Known.Zero); |
1324 | if (KnownZeroMask.isPowerOf2()) { |
1325 | Value *In = Cmp->getOperand(i_nocapture: 0); |
1326 | |
1327 | // If the icmp tests for a known zero bit we can constant fold it. |
1328 | if (!Op1C->isZero() && Op1C->getValue() != KnownZeroMask) { |
1329 | Value *V = Pred == ICmpInst::ICMP_NE ? |
1330 | ConstantInt::getAllOnesValue(Ty: Sext.getType()) : |
1331 | ConstantInt::getNullValue(Ty: Sext.getType()); |
1332 | return replaceInstUsesWith(I&: Sext, V); |
1333 | } |
1334 | |
1335 | if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { |
1336 | // sext ((x & 2^n) == 0) -> (x >> n) - 1 |
1337 | // sext ((x & 2^n) != 2^n) -> (x >> n) - 1 |
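     | // For example (illustrative), if only bit 3 of x can be set: when the bit is
     | // set, (x >> 3) - 1 == 0; when it is clear, (x >> 3) - 1 == -1, matching the
     | // sign-extended compare result.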
1338 | unsigned ShiftAmt = KnownZeroMask.countr_zero(); |
1339 | // Perform a right shift to place the desired bit in the LSB. |
1340 | if (ShiftAmt) |
1341 | In = Builder.CreateLShr(LHS: In, |
1342 | RHS: ConstantInt::get(Ty: In->getType(), V: ShiftAmt)); |
1343 | |
1344 | // At this point "In" is either 1 or 0. Subtract 1 to turn |
1345 | // {1, 0} -> {0, -1}. |
1346 | In = Builder.CreateAdd(LHS: In, |
1347 | RHS: ConstantInt::getAllOnesValue(Ty: In->getType()), |
1348 | Name: "sext" ); |
1349 | } else { |
1350 | // sext ((x & 2^n) != 0) -> (x << bitwidth-n) a>> bitwidth-1 |
1351 | // sext ((x & 2^n) == 2^n) -> (x << bitwidth-n) a>> bitwidth-1 |
1352 | unsigned ShiftAmt = KnownZeroMask.countl_zero(); |
1353 | // Perform a left shift to place the desired bit in the MSB. |
1354 | if (ShiftAmt) |
1355 | In = Builder.CreateShl(LHS: In, |
1356 | RHS: ConstantInt::get(Ty: In->getType(), V: ShiftAmt)); |
1357 | |
1358 | // Distribute the bit over the whole bit width. |
1359 | In = Builder.CreateAShr(LHS: In, RHS: ConstantInt::get(Ty: In->getType(), |
1360 | V: KnownZeroMask.getBitWidth() - 1), Name: "sext" ); |
1361 | } |
1362 | |
1363 | if (Sext.getType() == In->getType()) |
1364 | return replaceInstUsesWith(I&: Sext, V: In); |
1365 | return CastInst::CreateIntegerCast(S: In, Ty: Sext.getType(), isSigned: true/*SExt*/); |
1366 | } |
1367 | } |
1368 | } |
1369 | |
1370 | return nullptr; |
1371 | } |
1372 | |
1373 | /// Return true if we can take the specified value and return it as type Ty |
1374 | /// without inserting any new casts and without changing the value of the common |
1375 | /// low bits. This is used by code that tries to promote integer operations to |
1376 | /// a wider type, which will allow us to eliminate the extension.
1377 | /// |
1378 | /// This function works on both vectors and scalars. |
1379 | /// |
1380 | static bool canEvaluateSExtd(Value *V, Type *Ty) { |
1381 | assert(V->getType()->getScalarSizeInBits() < Ty->getScalarSizeInBits() && |
1382 | "Can't sign extend type to a smaller type" ); |
1383 | if (canAlwaysEvaluateInType(V, Ty)) |
1384 | return true; |
1385 | if (canNotEvaluateInType(V, Ty)) |
1386 | return false; |
1387 | |
1388 | auto *I = cast<Instruction>(Val: V); |
1389 | switch (I->getOpcode()) { |
1390 | case Instruction::SExt: // sext(sext(x)) -> sext(x) |
1391 | case Instruction::ZExt: // sext(zext(x)) -> zext(x) |
1392 | case Instruction::Trunc: // sext(trunc(x)) -> trunc(x) or sext(x) |
1393 | return true; |
1394 | case Instruction::And: |
1395 | case Instruction::Or: |
1396 | case Instruction::Xor: |
1397 | case Instruction::Add: |
1398 | case Instruction::Sub: |
1399 | case Instruction::Mul: |
1400 | // These operators can all arbitrarily be extended if their inputs can. |
1401 | return canEvaluateSExtd(V: I->getOperand(i: 0), Ty) && |
1402 | canEvaluateSExtd(V: I->getOperand(i: 1), Ty); |
1403 | |
1404 | //case Instruction::Shl: TODO |
1405 | //case Instruction::LShr: TODO |
1406 | |
1407 | case Instruction::Select: |
1408 | return canEvaluateSExtd(V: I->getOperand(i: 1), Ty) && |
1409 | canEvaluateSExtd(V: I->getOperand(i: 2), Ty); |
1410 | |
1411 | case Instruction::PHI: { |
1412 | // We can change a phi if we can change all operands. Note that we never |
1413 | // get into trouble with cyclic PHIs here because we only consider |
1414 | // instructions with a single use. |
1415 | PHINode *PN = cast<PHINode>(Val: I); |
1416 | for (Value *IncValue : PN->incoming_values()) |
1417 | if (!canEvaluateSExtd(V: IncValue, Ty)) return false; |
1418 | return true; |
1419 | } |
1420 | default: |
1421 | // TODO: Can handle more cases here. |
1422 | break; |
1423 | } |
1424 | |
1425 | return false; |
1426 | } |
1427 | |
1428 | Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { |
1429 | // If this sign extend is only used by a truncate, let the truncate be |
1430 | // eliminated before we try to optimize this sext. |
1431 | if (Sext.hasOneUse() && isa<TruncInst>(Val: Sext.user_back())) |
1432 | return nullptr; |
1433 | |
1434 | if (Instruction *I = commonCastTransforms(CI&: Sext)) |
1435 | return I; |
1436 | |
1437 | Value *Src = Sext.getOperand(i_nocapture: 0); |
1438 | Type *SrcTy = Src->getType(), *DestTy = Sext.getType(); |
1439 | unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); |
1440 | unsigned DestBitSize = DestTy->getScalarSizeInBits(); |
1441 | |
1442 | // If the value being extended is zero or positive, use a zext instead. |
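     | // For example (illustrative), if the sign bit of %x is known to be zero:
     | //   %r = sext i8 %x to i32  -->  %r = zext nneg i8 %x to i32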
1443 | if (isKnownNonNegative(V: Src, SQ: SQ.getWithInstruction(I: &Sext))) { |
1444 | auto CI = CastInst::Create(Instruction::ZExt, S: Src, Ty: DestTy); |
1445 | CI->setNonNeg(true); |
1446 | return CI; |
1447 | } |
1448 | |
1449 | // Try to extend the entire expression tree to the wide destination type. |
1450 | if (shouldChangeType(From: SrcTy, To: DestTy) && canEvaluateSExtd(V: Src, Ty: DestTy)) { |
1451 | // Okay, we can transform this! Insert the new expression now. |
1452 | LLVM_DEBUG( |
1453 | dbgs() << "ICE: EvaluateInDifferentType converting expression type" |
1454 | " to avoid sign extend: " |
1455 | << Sext << '\n'); |
1456 | Value *Res = EvaluateInDifferentType(V: Src, Ty: DestTy, isSigned: true); |
1457 | assert(Res->getType() == DestTy); |
1458 | |
1459 | // If the high bits are already filled with sign bit, just replace this |
1460 | // cast with the result. |
1461 | if (ComputeNumSignBits(Op: Res, Depth: 0, CxtI: &Sext) > DestBitSize - SrcBitSize) |
1462 | return replaceInstUsesWith(I&: Sext, V: Res); |
1463 | |
1464 | // We need to emit a shl + ashr to do the sign extend. |
1465 | Value *ShAmt = ConstantInt::get(Ty: DestTy, V: DestBitSize-SrcBitSize); |
1466 | return BinaryOperator::CreateAShr(V1: Builder.CreateShl(LHS: Res, RHS: ShAmt, Name: "sext" ), |
1467 | V2: ShAmt); |
1468 | } |
1469 | |
1470 | Value *X; |
1471 | if (match(V: Src, P: m_Trunc(Op: m_Value(V&: X)))) { |
1472 | // If the input has more sign bits than bits truncated, then convert |
1473 | // directly to final type. |
1474 | unsigned XBitSize = X->getType()->getScalarSizeInBits(); |
1475 | if (ComputeNumSignBits(Op: X, Depth: 0, CxtI: &Sext) > XBitSize - SrcBitSize) |
1476 | return CastInst::CreateIntegerCast(S: X, Ty: DestTy, /* isSigned */ true); |
1477 | |
1478 | // If input is a trunc from the destination type, then convert into shifts. |
1479 | if (Src->hasOneUse() && X->getType() == DestTy) { |
1480 | // sext (trunc X) --> ashr (shl X, C), C |
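     | // For example (illustrative), with X and the destination both i32:
     | //   %s = sext (trunc i32 %x to i8) to i32
     | // becomes
     | //   %shl = shl i32 %x, 24
     | //   %s = ashr i32 %shl, 24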
1481 | Constant *ShAmt = ConstantInt::get(Ty: DestTy, V: DestBitSize - SrcBitSize); |
1482 | return BinaryOperator::CreateAShr(V1: Builder.CreateShl(LHS: X, RHS: ShAmt), V2: ShAmt); |
1483 | } |
1484 | |
1485 | // If we are replacing shifted-in high zero bits with sign bits, convert |
1486 | // the logic shift to arithmetic shift and eliminate the cast to |
1487 | // intermediate type: |
1488 | // sext (trunc (lshr Y, C)) --> sext/trunc (ashr Y, C) |
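     | // For example (illustrative), with Y == i32, C == 24 and an i64 destination:
     | //   %s = sext (trunc (lshr i32 %y, 24) to i8) to i64
     | // becomes
     | //   %a = ashr i32 %y, 24
     | //   %s = sext i32 %a to i64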
1489 | Value *Y; |
1490 | if (Src->hasOneUse() && |
1491 | match(V: X, P: m_LShr(L: m_Value(V&: Y), |
1492 | R: m_SpecificIntAllowPoison(V: XBitSize - SrcBitSize)))) { |
1493 | Value *Ashr = Builder.CreateAShr(LHS: Y, RHS: XBitSize - SrcBitSize); |
1494 | return CastInst::CreateIntegerCast(S: Ashr, Ty: DestTy, /* isSigned */ true); |
1495 | } |
1496 | } |
1497 | |
1498 | if (auto *Cmp = dyn_cast<ICmpInst>(Val: Src)) |
1499 | return transformSExtICmp(Cmp, Sext); |
1500 | |
1501 | // If the input is a shl/ashr pair with the same constant, then this is a sign
1502 | // extension from a smaller value. If we could trust arbitrary bitwidth |
1503 | // integers, we could turn this into a truncate to the smaller width and then
1504 | // use a sext for the whole extension. Since we don't, look deeper and check |
1505 | // for a truncate. If the source and dest are the same type, eliminate the |
1506 | // trunc and extend and just do shifts. For example, turn: |
1507 | // %a = trunc i32 %i to i8 |
1508 | // %b = shl i8 %a, C |
1509 | // %c = ashr i8 %b, C |
1510 | // %d = sext i8 %c to i32 |
1511 | // into: |
1512 | // %a = shl i32 %i, 32-(8-C) |
1513 | // %d = ashr i32 %a, 32-(8-C) |
1514 | Value *A = nullptr; |
1515 | // TODO: Eventually this could be subsumed by EvaluateInDifferentType. |
1516 | Constant *BA = nullptr, *CA = nullptr; |
1517 | if (match(V: Src, P: m_AShr(L: m_Shl(L: m_Trunc(Op: m_Value(V&: A)), R: m_Constant(C&: BA)), |
1518 | R: m_ImmConstant(C&: CA))) && |
1519 | BA->isElementWiseEqual(Y: CA) && A->getType() == DestTy) { |
1520 | Constant *WideCurrShAmt = |
1521 | ConstantFoldCastOperand(Opcode: Instruction::SExt, C: CA, DestTy, DL); |
1522 | assert(WideCurrShAmt && "Constant folding of ImmConstant cannot fail" ); |
1523 | Constant *NumLowbitsLeft = ConstantExpr::getSub( |
1524 | C1: ConstantInt::get(Ty: DestTy, V: SrcTy->getScalarSizeInBits()), C2: WideCurrShAmt); |
1525 | Constant *NewShAmt = ConstantExpr::getSub( |
1526 | C1: ConstantInt::get(Ty: DestTy, V: DestTy->getScalarSizeInBits()), |
1527 | C2: NumLowbitsLeft); |
1528 | NewShAmt = |
1529 | Constant::mergeUndefsWith(C: Constant::mergeUndefsWith(C: NewShAmt, Other: BA), Other: CA); |
1530 | A = Builder.CreateShl(LHS: A, RHS: NewShAmt, Name: Sext.getName()); |
1531 | return BinaryOperator::CreateAShr(V1: A, V2: NewShAmt); |
1532 | } |
1533 | |
1534 | // Splatting a bit of constant-index across a value: |
1535 | // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1 |
1536 | // If the dest type is different, use a cast (adjust use check). |
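     | // For example (illustrative), with N == 32 and M == 8:
     | //   %s = sext (ashr (trunc i32 %x to i8), 7) to i32
     | // becomes
     | //   %shl = shl i32 %x, 24
     | //   %s = ashr i32 %shl, 31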
1537 | if (match(V: Src, P: m_OneUse(SubPattern: m_AShr(L: m_Trunc(Op: m_Value(V&: X)), |
1538 | R: m_SpecificInt(V: SrcBitSize - 1))))) { |
1539 | Type *XTy = X->getType(); |
1540 | unsigned XBitSize = XTy->getScalarSizeInBits(); |
1541 | Constant *ShlAmtC = ConstantInt::get(Ty: XTy, V: XBitSize - SrcBitSize); |
1542 | Constant *AshrAmtC = ConstantInt::get(Ty: XTy, V: XBitSize - 1); |
1543 | if (XTy == DestTy) |
1544 | return BinaryOperator::CreateAShr(V1: Builder.CreateShl(LHS: X, RHS: ShlAmtC), |
1545 | V2: AshrAmtC); |
1546 | if (cast<BinaryOperator>(Val: Src)->getOperand(i_nocapture: 0)->hasOneUse()) { |
1547 | Value *Ashr = Builder.CreateAShr(LHS: Builder.CreateShl(LHS: X, RHS: ShlAmtC), RHS: AshrAmtC); |
1548 | return CastInst::CreateIntegerCast(S: Ashr, Ty: DestTy, /* isSigned */ true); |
1549 | } |
1550 | } |
1551 | |
1552 | if (match(V: Src, P: m_VScale())) { |
1553 | if (Sext.getFunction() && |
1554 | Sext.getFunction()->hasFnAttribute(Kind: Attribute::VScaleRange)) { |
1555 | Attribute Attr = |
1556 | Sext.getFunction()->getFnAttribute(Kind: Attribute::VScaleRange); |
1557 | if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { |
1558 | if (Log2_32(Value: *MaxVScale) < (SrcBitSize - 1)) { |
1559 | Value *VScale = Builder.CreateVScale(Scaling: ConstantInt::get(Ty: DestTy, V: 1)); |
1560 | return replaceInstUsesWith(I&: Sext, V: VScale); |
1561 | } |
1562 | } |
1563 | } |
1564 | } |
1565 | |
1566 | return nullptr; |
1567 | } |
1568 | |
1569 | /// Return true if the specified floating-point constant fits in the specified
1570 | /// FP type without changing its value.
1571 | static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { |
1572 | bool losesInfo; |
1573 | APFloat F = CFP->getValueAPF(); |
1574 | (void)F.convert(ToSemantics: Sem, RM: APFloat::rmNearestTiesToEven, losesInfo: &losesInfo); |
1575 | return !losesInfo; |
1576 | } |
1577 | |
1578 | static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) { |
1579 | if (CFP->getType() == Type::getPPC_FP128Ty(C&: CFP->getContext())) |
1580 | return nullptr; // No constant folding of this. |
1581 | // See if the value can be truncated to bfloat and then reextended. |
1582 | if (PreferBFloat && fitsInFPType(CFP, Sem: APFloat::BFloat())) |
1583 | return Type::getBFloatTy(C&: CFP->getContext()); |
1584 | // See if the value can be truncated to half and then reextended. |
1585 | if (!PreferBFloat && fitsInFPType(CFP, Sem: APFloat::IEEEhalf())) |
1586 | return Type::getHalfTy(C&: CFP->getContext()); |
1587 | // See if the value can be truncated to float and then reextended. |
1588 | if (fitsInFPType(CFP, Sem: APFloat::IEEEsingle())) |
1589 | return Type::getFloatTy(C&: CFP->getContext()); |
1590 | if (CFP->getType()->isDoubleTy()) |
1591 | return nullptr; // Won't shrink. |
1592 | if (fitsInFPType(CFP, Sem: APFloat::IEEEdouble())) |
1593 | return Type::getDoubleTy(C&: CFP->getContext()); |
1594 | // Don't try to shrink to various long double types. |
1595 | return nullptr; |
1596 | } |
1597 | |
1598 | // Determine if this is a vector of ConstantFPs and if so, return the minimal |
1599 | // type we can safely truncate all elements to. |
1600 | static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) { |
1601 | auto *CV = dyn_cast<Constant>(Val: V); |
1602 | auto *CVVTy = dyn_cast<FixedVectorType>(Val: V->getType()); |
1603 | if (!CV || !CVVTy) |
1604 | return nullptr; |
1605 | |
1606 | Type *MinType = nullptr; |
1607 | |
1608 | unsigned NumElts = CVVTy->getNumElements(); |
1609 | |
1610 | // For fixed-width vectors we find the minimal type by looking |
1611 | // through the constant values of the vector. |
1612 | for (unsigned i = 0; i != NumElts; ++i) { |
1613 | if (isa<UndefValue>(Val: CV->getAggregateElement(Elt: i))) |
1614 | continue; |
1615 | |
1616 | auto *CFP = dyn_cast_or_null<ConstantFP>(Val: CV->getAggregateElement(Elt: i)); |
1617 | if (!CFP) |
1618 | return nullptr; |
1619 | |
1620 | Type *T = shrinkFPConstant(CFP, PreferBFloat); |
1621 | if (!T) |
1622 | return nullptr; |
1623 | |
1624 | // If we haven't found a type yet or this type has a larger mantissa than |
1625 | // our previous type, this is our new minimal type. |
1626 | if (!MinType || T->getFPMantissaWidth() > MinType->getFPMantissaWidth()) |
1627 | MinType = T; |
1628 | } |
1629 | |
1630 | // Make a vector type from the minimal type. |
1631 | return MinType ? FixedVectorType::get(ElementType: MinType, NumElts) : nullptr; |
1632 | } |
1633 | |
1634 | /// Find the minimum FP type we can safely truncate to. |
1635 | static Type *getMinimumFPType(Value *V, bool PreferBFloat) { |
1636 | if (auto *FPExt = dyn_cast<FPExtInst>(Val: V)) |
1637 | return FPExt->getOperand(i_nocapture: 0)->getType(); |
1638 | |
1639 | // If this value is a constant, return the constant in the smallest FP type |
1640 | // that can accurately represent it. This allows us to turn |
1641 | // (float)((double)X+2.0) into x+2.0f. |
1642 | if (auto *CFP = dyn_cast<ConstantFP>(Val: V)) |
1643 | if (Type *T = shrinkFPConstant(CFP, PreferBFloat)) |
1644 | return T; |
1645 | |
1646 | // We can only correctly find a minimum type for a scalable vector when it is |
1647 | // a splat. For splats of constant values the fpext is wrapped up as a |
1648 | // ConstantExpr. |
1649 | if (auto *FPCExt = dyn_cast<ConstantExpr>(Val: V)) |
1650 | if (FPCExt->getOpcode() == Instruction::FPExt) |
1651 | return FPCExt->getOperand(i_nocapture: 0)->getType(); |
1652 | |
1653 | // Try to shrink a vector of FP constants. This returns nullptr on scalable |
1654 | // vectors.
1655 | if (Type *T = shrinkFPConstantVector(V, PreferBFloat)) |
1656 | return T; |
1657 | |
1658 | return V->getType(); |
1659 | } |
1660 | |
1661 | /// Return true if the cast from integer to FP can be proven to be exact for all |
1662 | /// possible inputs (the conversion does not lose any precision). |
1663 | static bool isKnownExactCastIntToFP(CastInst &I, InstCombinerImpl &IC) { |
1664 | CastInst::CastOps Opcode = I.getOpcode(); |
1665 | assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) && |
1666 | "Unexpected cast" ); |
1667 | Value *Src = I.getOperand(i_nocapture: 0); |
1668 | Type *SrcTy = Src->getType(); |
1669 | Type *FPTy = I.getType(); |
1670 | bool IsSigned = Opcode == Instruction::SIToFP; |
1671 | int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned; |
1672 | |
1673 | // Easy case - if the source integer type has fewer bits than the FP mantissa,
1674 | // then the cast must be exact. |
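     | // For example (illustrative), uitofp i24 %x to float and sitofp i25 %x to
     | // float are always exact, because float has a 24-bit significand.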
1675 | int DestNumSigBits = FPTy->getFPMantissaWidth(); |
1676 | if (SrcSize <= DestNumSigBits) |
1677 | return true; |
1678 | |
1679 | // Cast from FP to integer and back to FP is independent of the intermediate |
1680 | // integer width because of poison on overflow. |
1681 | Value *F; |
1682 | if (match(V: Src, P: m_FPToSI(Op: m_Value(V&: F))) || match(V: Src, P: m_FPToUI(Op: m_Value(V&: F)))) { |
1683 | // If this is uitofp (fptosi F), the source needs an extra bit to avoid |
1684 | // potential rounding of negative FP input values. |
1685 | int SrcNumSigBits = F->getType()->getFPMantissaWidth(); |
1686 | if (!IsSigned && match(V: Src, P: m_FPToSI(Op: m_Value()))) |
1687 | SrcNumSigBits++; |
1688 | |
1689 | // [su]itofp (fpto[su]i F) --> exact if the source type has no more
1690 | // significant bits than the destination (and make sure neither type is |
1691 | // weird -- ppc_fp128). |
1692 | if (SrcNumSigBits > 0 && DestNumSigBits > 0 && |
1693 | SrcNumSigBits <= DestNumSigBits) |
1694 | return true; |
1695 | } |
1696 | |
1697 | // TODO: |
1698 | // Try harder to find if the source integer type has fewer significant bits.
1699 | // For example, compute number of sign bits. |
1700 | KnownBits SrcKnown = IC.computeKnownBits(V: Src, Depth: 0, CxtI: &I); |
1701 | int SigBits = (int)SrcTy->getScalarSizeInBits() - |
1702 | SrcKnown.countMinLeadingZeros() - |
1703 | SrcKnown.countMinTrailingZeros(); |
1704 | if (SigBits <= DestNumSigBits) |
1705 | return true; |
1706 | |
1707 | return false; |
1708 | } |
1709 | |
1710 | Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) { |
1711 | if (Instruction *I = commonCastTransforms(CI&: FPT)) |
1712 | return I; |
1713 | |
1714 | // If we have fptrunc(OpI (fpextend x), (fpextend y)), we would like to |
1715 | // simplify this expression to avoid one or more of the trunc/extend |
1716 | // operations if we can do so without changing the numerical results. |
1717 | // |
1718 | // The exact manner in which the widths of the operands interact to limit |
1719 | // what we can and cannot do safely varies from operation to operation, and |
1720 | // is explained below in the various case statements. |
1721 | Type *Ty = FPT.getType(); |
1722 | auto *BO = dyn_cast<BinaryOperator>(Val: FPT.getOperand(i_nocapture: 0)); |
1723 | if (BO && BO->hasOneUse()) { |
1724 | Type *LHSMinType = |
1725 | getMinimumFPType(V: BO->getOperand(i_nocapture: 0), /*PreferBFloat=*/Ty->isBFloatTy()); |
1726 | Type *RHSMinType = |
1727 | getMinimumFPType(V: BO->getOperand(i_nocapture: 1), /*PreferBFloat=*/Ty->isBFloatTy()); |
1728 | unsigned OpWidth = BO->getType()->getFPMantissaWidth(); |
1729 | unsigned LHSWidth = LHSMinType->getFPMantissaWidth(); |
1730 | unsigned RHSWidth = RHSMinType->getFPMantissaWidth(); |
1731 | unsigned SrcWidth = std::max(a: LHSWidth, b: RHSWidth); |
1732 | unsigned DstWidth = Ty->getFPMantissaWidth(); |
1733 | switch (BO->getOpcode()) { |
1734 | default: break; |
1735 | case Instruction::FAdd: |
1736 | case Instruction::FSub: |
1737 | // For addition and subtraction, the infinitely precise result can |
1738 | // essentially be arbitrarily wide; proving that double rounding |
1739 | // will not occur because the result of OpI is exact (as we will for |
1740 | // FMul, for example) is hopeless. However, we *can* nonetheless |
1741 | // frequently know that double rounding cannot occur (or that it is |
1742 | // innocuous) by taking advantage of the specific structure of |
1743 | // infinitely-precise results that admit double rounding. |
1744 | // |
1745 | // Specifically, if OpWidth >= 2*DstWidth+1 and DstWidth is sufficient
1746 | // to represent both sources, we can guarantee that the double |
1747 | // rounding is innocuous (See p50 of Figueroa's 2000 PhD thesis, |
1748 | // "A Rigorous Framework for Fully Supporting the IEEE Standard ..." |
1749 | // for proof of this fact). |
1750 | // |
1751 | // Note: Figueroa does not consider the case where DstFormat != |
1752 | // SrcFormat. It's possible (likely even!) that this analysis |
1753 | // could be tightened for those cases, but they are rare (the main |
1754 | // case of interest here is (float)((double)float + float)). |
1755 | if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) { |
1756 | Value *LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: Ty); |
1757 | Value *RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: Ty); |
1758 | Instruction *RI = BinaryOperator::Create(Op: BO->getOpcode(), S1: LHS, S2: RHS); |
1759 | RI->copyFastMathFlags(I: BO); |
1760 | return RI; |
1761 | } |
1762 | break; |
1763 | case Instruction::FMul: |
1764 | // For multiplication, the infinitely precise result has at most |
1765 | // LHSWidth + RHSWidth significant bits; if OpWidth is sufficient |
1766 | // that such a value can be exactly represented, then no double |
1767 | // rounding can possibly occur; we can safely perform the operation |
1768 | // in the destination format if it can represent both sources. |
1769 | if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) { |
1770 | Value *LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: Ty); |
1771 | Value *RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: Ty); |
1772 | return BinaryOperator::CreateFMulFMF(V1: LHS, V2: RHS, FMFSource: BO); |
1773 | } |
1774 | break; |
1775 | case Instruction::FDiv: |
1776 | // For division, we again use the bound from Figueroa's
1777 | // dissertation. I am entirely certain that this bound can be |
1778 | // tightened in the unbalanced operand case by an analysis based on |
1779 | // the diophantine rational approximation bound, but the well-known |
1780 | // condition used here is a good conservative first pass. |
1781 | // TODO: Tighten bound via rigorous analysis of the unbalanced case. |
1782 | if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) { |
1783 | Value *LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: Ty); |
1784 | Value *RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: Ty); |
1785 | return BinaryOperator::CreateFDivFMF(V1: LHS, V2: RHS, FMFSource: BO); |
1786 | } |
1787 | break; |
1788 | case Instruction::FRem: { |
1789 | // Remainder is straightforward. Remainder is always exact, so the |
1790 | // type of OpI doesn't enter into things at all. We simply evaluate |
1791 | // in whichever source type is larger, then convert to the |
1792 | // destination type. |
1793 | if (SrcWidth == OpWidth) |
1794 | break; |
1795 | Value *LHS, *RHS; |
1796 | if (LHSWidth == SrcWidth) { |
1797 | LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: LHSMinType); |
1798 | RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: LHSMinType); |
1799 | } else { |
1800 | LHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 0), DestTy: RHSMinType); |
1801 | RHS = Builder.CreateFPTrunc(V: BO->getOperand(i_nocapture: 1), DestTy: RHSMinType); |
1802 | } |
1803 | |
1804 | Value *ExactResult = Builder.CreateFRemFMF(L: LHS, R: RHS, FMFSource: BO); |
1805 | return CastInst::CreateFPCast(S: ExactResult, Ty); |
1806 | } |
1807 | } |
1808 | } |
1809 | |
1810 | // (fptrunc (fneg x)) -> (fneg (fptrunc x)) |
1811 | Value *X; |
1812 | Instruction *Op = dyn_cast<Instruction>(Val: FPT.getOperand(i_nocapture: 0)); |
1813 | if (Op && Op->hasOneUse()) { |
1814 | // FIXME: The FMF should propagate from the fptrunc, not the source op. |
1815 | IRBuilder<>::FastMathFlagGuard FMFG(Builder); |
1816 | if (isa<FPMathOperator>(Val: Op)) |
1817 | Builder.setFastMathFlags(Op->getFastMathFlags()); |
1818 | |
1819 | if (match(V: Op, P: m_FNeg(X: m_Value(V&: X)))) { |
1820 | Value *InnerTrunc = Builder.CreateFPTrunc(V: X, DestTy: Ty); |
1821 | |
1822 | return UnaryOperator::CreateFNegFMF(Op: InnerTrunc, FMFSource: Op); |
1823 | } |
1824 | |
1825 | // If we are truncating a select that has an extended operand, we can |
1826 | // narrow the other operand and do the select as a narrow op. |
1827 | Value *Cond, *X, *Y; |
1828 | if (match(V: Op, P: m_Select(C: m_Value(V&: Cond), L: m_FPExt(Op: m_Value(V&: X)), R: m_Value(V&: Y))) && |
1829 | X->getType() == Ty) { |
1830 | // fptrunc (select Cond, (fpext X), Y) --> select Cond, X, (fptrunc Y)
1831 | Value *NarrowY = Builder.CreateFPTrunc(V: Y, DestTy: Ty); |
1832 | Value *Sel = Builder.CreateSelect(C: Cond, True: X, False: NarrowY, Name: "narrow.sel" , MDFrom: Op); |
1833 | return replaceInstUsesWith(I&: FPT, V: Sel); |
1834 | } |
1835 | if (match(V: Op, P: m_Select(C: m_Value(V&: Cond), L: m_Value(V&: Y), R: m_FPExt(Op: m_Value(V&: X)))) && |
1836 | X->getType() == Ty) { |
1837 | // fptrunc (select Cond, Y, (fpext X)) --> select Cond, (fptrunc Y), X
1838 | Value *NarrowY = Builder.CreateFPTrunc(V: Y, DestTy: Ty); |
1839 | Value *Sel = Builder.CreateSelect(C: Cond, True: NarrowY, False: X, Name: "narrow.sel" , MDFrom: Op); |
1840 | return replaceInstUsesWith(I&: FPT, V: Sel); |
1841 | } |
1842 | } |
1843 | |
1844 | if (auto *II = dyn_cast<IntrinsicInst>(Val: FPT.getOperand(i_nocapture: 0))) { |
1845 | switch (II->getIntrinsicID()) { |
1846 | default: break; |
1847 | case Intrinsic::ceil: |
1848 | case Intrinsic::fabs: |
1849 | case Intrinsic::floor: |
1850 | case Intrinsic::nearbyint: |
1851 | case Intrinsic::rint: |
1852 | case Intrinsic::round: |
1853 | case Intrinsic::roundeven: |
1854 | case Intrinsic::trunc: { |
1855 | Value *Src = II->getArgOperand(i: 0); |
1856 | if (!Src->hasOneUse()) |
1857 | break; |
1858 | |
1859 | // Except for fabs, this transformation requires the input of the unary FP |
1860 | // operation to be itself an fpext from the type to which we're |
1861 | // truncating. |
1862 | if (II->getIntrinsicID() != Intrinsic::fabs) { |
1863 | FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Val: Src); |
1864 | if (!FPExtSrc || FPExtSrc->getSrcTy() != Ty) |
1865 | break; |
1866 | } |
1867 | |
1868 | // Do unary FP operation on smaller type. |
1869 | // (fptrunc (fabs x)) -> (fabs (fptrunc x)) |
1870 | Value *InnerTrunc = Builder.CreateFPTrunc(V: Src, DestTy: Ty); |
1871 | Function *Overload = Intrinsic::getDeclaration(M: FPT.getModule(), |
1872 | id: II->getIntrinsicID(), Tys: Ty); |
1873 | SmallVector<OperandBundleDef, 1> OpBundles; |
1874 | II->getOperandBundlesAsDefs(Defs&: OpBundles); |
1875 | CallInst *NewCI = |
1876 | CallInst::Create(Func: Overload, Args: {InnerTrunc}, Bundles: OpBundles, NameStr: II->getName()); |
1877 | NewCI->copyFastMathFlags(I: II); |
1878 | return NewCI; |
1879 | } |
1880 | } |
1881 | } |
1882 | |
1883 | if (Instruction *I = shrinkInsertElt(Trunc&: FPT, Builder)) |
1884 | return I; |
1885 | |
1886 | Value *Src = FPT.getOperand(i_nocapture: 0); |
1887 | if (isa<SIToFPInst>(Val: Src) || isa<UIToFPInst>(Val: Src)) { |
1888 | auto *FPCast = cast<CastInst>(Val: Src); |
1889 | if (isKnownExactCastIntToFP(I&: *FPCast, IC&: *this)) |
1890 | return CastInst::Create(FPCast->getOpcode(), S: FPCast->getOperand(i_nocapture: 0), Ty); |
1891 | } |
1892 | |
1893 | return nullptr; |
1894 | } |
1895 | |
1896 | Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) { |
1897 | // If the source operand is a cast from integer to FP and known exact, then |
1898 | // cast the integer operand directly to the destination type. |
1899 | Type *Ty = FPExt.getType(); |
1900 | Value *Src = FPExt.getOperand(i_nocapture: 0); |
1901 | if (isa<SIToFPInst>(Val: Src) || isa<UIToFPInst>(Val: Src)) { |
1902 | auto *FPCast = cast<CastInst>(Val: Src); |
1903 | if (isKnownExactCastIntToFP(I&: *FPCast, IC&: *this)) |
1904 | return CastInst::Create(FPCast->getOpcode(), S: FPCast->getOperand(i_nocapture: 0), Ty); |
1905 | } |
1906 | |
1907 | return commonCastTransforms(CI&: FPExt); |
1908 | } |
1909 | |
1910 | /// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X) |
1911 | /// This is safe if the intermediate type has enough bits in its mantissa to |
1912 | /// accurately represent all values of X. For example, this won't work with |
1913 | /// i64 -> float -> i64. |
1914 | Instruction *InstCombinerImpl::foldItoFPtoI(CastInst &FI) { |
1915 | if (!isa<UIToFPInst>(Val: FI.getOperand(i_nocapture: 0)) && !isa<SIToFPInst>(Val: FI.getOperand(i_nocapture: 0))) |
1916 | return nullptr; |
1917 | |
1918 | auto *OpI = cast<CastInst>(Val: FI.getOperand(i_nocapture: 0)); |
1919 | Value *X = OpI->getOperand(i_nocapture: 0); |
1920 | Type *XType = X->getType(); |
1921 | Type *DestType = FI.getType(); |
1922 | bool IsOutputSigned = isa<FPToSIInst>(Val: FI); |
1923 | |
1924 | // Since we can assume the conversion won't overflow, our decision as to |
1925 | // whether the input will fit in the float should depend on the minimum |
1926 | // of the input range and output range. |
1927 | |
1928 | // This means this is also safe for a signed input and unsigned output, since |
1929 | // a negative input would lead to undefined behavior. |
1930 | if (!isKnownExactCastIntToFP(I&: *OpI, IC&: *this)) { |
1931 | // The first cast may not round exactly based on the source integer width |
1932 | // and FP width, but the overflow UB rules can still allow this to fold. |
1933 | // If the destination type is narrow, that means the intermediate FP value |
1934 | // must be large enough to hold the source value exactly. |
1935 | // For example, (uint8_t)(float)(uint32_t)16777217 is undefined behavior.
1936 | int OutputSize = (int)DestType->getScalarSizeInBits(); |
1937 | if (OutputSize > OpI->getType()->getFPMantissaWidth()) |
1938 | return nullptr; |
1939 | } |
1940 | |
1941 | if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) { |
1942 | bool IsInputSigned = isa<SIToFPInst>(Val: OpI); |
1943 | if (IsInputSigned && IsOutputSigned) |
1944 | return new SExtInst(X, DestType); |
1945 | return new ZExtInst(X, DestType); |
1946 | } |
1947 | if (DestType->getScalarSizeInBits() < XType->getScalarSizeInBits()) |
1948 | return new TruncInst(X, DestType); |
1949 | |
1950 | assert(XType == DestType && "Unexpected types for int to FP to int casts" ); |
1951 | return replaceInstUsesWith(I&: FI, V: X); |
1952 | } |
1953 | |
1954 | static Instruction *foldFPtoI(Instruction &FI, InstCombiner &IC) { |
1955 | // fpto{u/s}i non-norm --> 0 |
1956 | FPClassTest Mask = |
1957 | FI.getOpcode() == Instruction::FPToUI ? fcPosNormal : fcNormal; |
1958 | KnownFPClass FPClass = |
1959 | computeKnownFPClass(V: FI.getOperand(i: 0), InterestedClasses: Mask, /*Depth=*/0, |
1960 | SQ: IC.getSimplifyQuery().getWithInstruction(I: &FI)); |
1961 | if (FPClass.isKnownNever(Mask)) |
1962 | return IC.replaceInstUsesWith(I&: FI, V: ConstantInt::getNullValue(Ty: FI.getType())); |
1963 | |
1964 | return nullptr; |
1965 | } |
1966 | |
1967 | Instruction *InstCombinerImpl::visitFPToUI(FPToUIInst &FI) { |
1968 | if (Instruction *I = foldItoFPtoI(FI)) |
1969 | return I; |
1970 | |
1971 | if (Instruction *I = foldFPtoI(FI, IC&: *this)) |
1972 | return I; |
1973 | |
1974 | return commonCastTransforms(CI&: FI); |
1975 | } |
1976 | |
1977 | Instruction *InstCombinerImpl::visitFPToSI(FPToSIInst &FI) { |
1978 | if (Instruction *I = foldItoFPtoI(FI)) |
1979 | return I; |
1980 | |
1981 | if (Instruction *I = foldFPtoI(FI, IC&: *this)) |
1982 | return I; |
1983 | |
1984 | return commonCastTransforms(CI&: FI); |
1985 | } |
1986 | |
1987 | Instruction *InstCombinerImpl::visitUIToFP(CastInst &CI) { |
1988 | if (Instruction *R = commonCastTransforms(CI)) |
1989 | return R; |
1990 | if (!CI.hasNonNeg() && isKnownNonNegative(V: CI.getOperand(i_nocapture: 0), SQ)) { |
1991 | CI.setNonNeg(); |
1992 | return &CI; |
1993 | } |
1994 | return nullptr; |
1995 | } |
1996 | |
1997 | Instruction *InstCombinerImpl::visitSIToFP(CastInst &CI) { |
1998 | if (Instruction *R = commonCastTransforms(CI)) |
1999 | return R; |
2000 | if (isKnownNonNegative(V: CI.getOperand(i_nocapture: 0), SQ)) { |
2001 | auto *UI = |
2002 | CastInst::Create(Instruction::UIToFP, S: CI.getOperand(i_nocapture: 0), Ty: CI.getType()); |
2003 | UI->setNonNeg(true); |
2004 | return UI; |
2005 | } |
2006 | return nullptr; |
2007 | } |
2008 | |
2009 | Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) { |
2010 | // If the source integer type is not the intptr_t type for this target, do a |
2011 | // trunc or zext to the intptr_t type, then inttoptr of it. This allows the |
2012 | // cast to be exposed to other transforms. |
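     | // For example (illustrative), on a target with 64-bit pointers:
     | //   %p = inttoptr i32 %x to ptr
     | // becomes
     | //   %z = zext i32 %x to i64
     | //   %p = inttoptr i64 %z to ptr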
2013 | unsigned AS = CI.getAddressSpace(); |
2014 | if (CI.getOperand(i_nocapture: 0)->getType()->getScalarSizeInBits() != |
2015 | DL.getPointerSizeInBits(AS)) { |
2016 | Type *Ty = CI.getOperand(i_nocapture: 0)->getType()->getWithNewType( |
2017 | EltTy: DL.getIntPtrType(C&: CI.getContext(), AddressSpace: AS)); |
2018 | Value *P = Builder.CreateZExtOrTrunc(V: CI.getOperand(i_nocapture: 0), DestTy: Ty); |
2019 | return new IntToPtrInst(P, CI.getType()); |
2020 | } |
2021 | |
2022 | if (Instruction *I = commonCastTransforms(CI)) |
2023 | return I; |
2024 | |
2025 | return nullptr; |
2026 | } |
2027 | |
2028 | Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { |
2029 | // If the destination integer type is not the intptr_t type for this target, |
2030 | // do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast |
2031 | // to be exposed to other transforms. |
2032 | Value *SrcOp = CI.getPointerOperand(); |
2033 | Type *SrcTy = SrcOp->getType(); |
2034 | Type *Ty = CI.getType(); |
2035 | unsigned AS = CI.getPointerAddressSpace(); |
2036 | unsigned TySize = Ty->getScalarSizeInBits(); |
2037 | unsigned PtrSize = DL.getPointerSizeInBits(AS); |
2038 | if (TySize != PtrSize) { |
2039 | Type *IntPtrTy = |
2040 | SrcTy->getWithNewType(EltTy: DL.getIntPtrType(C&: CI.getContext(), AddressSpace: AS)); |
2041 | Value *P = Builder.CreatePtrToInt(V: SrcOp, DestTy: IntPtrTy); |
2042 | return CastInst::CreateIntegerCast(S: P, Ty, /*isSigned=*/false); |
2043 | } |
2044 | |
2045 | // (ptrtoint (ptrmask P, M)) |
2046 | // -> (and (ptrtoint P), M) |
2047 | // This is generally beneficial as `and` is better supported than `ptrmask`. |
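     | // For example (illustrative):
     | //   %q = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 %m)
     | //   %i = ptrtoint ptr %q to i64
     | // becomes
     | //   %pi = ptrtoint ptr %p to i64
     | //   %i = and i64 %pi, %m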
2048 | Value *Ptr, *Mask; |
2049 | if (match(V: SrcOp, P: m_OneUse(SubPattern: m_Intrinsic<Intrinsic::ptrmask>(Op0: m_Value(V&: Ptr), |
2050 | Op1: m_Value(V&: Mask)))) && |
2051 | Mask->getType() == Ty) |
2052 | return BinaryOperator::CreateAnd(V1: Builder.CreatePtrToInt(V: Ptr, DestTy: Ty), V2: Mask); |
2053 | |
2054 | if (auto *GEP = dyn_cast<GEPOperator>(Val: SrcOp)) { |
2055 | // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use. |
2056 | // While this can increase the number of instructions it doesn't actually |
2057 | // increase the overall complexity since the arithmetic is just part of |
2058 | // the GEP otherwise. |
2059 | if (GEP->hasOneUse() && |
2060 | isa<ConstantPointerNull>(Val: GEP->getPointerOperand())) { |
2061 | return replaceInstUsesWith(I&: CI, |
2062 | V: Builder.CreateIntCast(V: EmitGEPOffset(GEP), DestTy: Ty, |
2063 | /*isSigned=*/false)); |
2064 | } |
2065 | |
2066 | // (ptrtoint (gep (inttoptr Base), ...)) -> Base + Offset |
2067 | Value *Base; |
2068 | if (GEP->hasOneUse() && |
2069 | match(V: GEP->getPointerOperand(), P: m_OneUse(SubPattern: m_IntToPtr(Op: m_Value(V&: Base)))) && |
2070 | Base->getType() == Ty) { |
2071 | Value *Offset = EmitGEPOffset(GEP); |
2072 | auto *NewOp = BinaryOperator::CreateAdd(V1: Base, V2: Offset); |
2073 | if (GEP->hasNoUnsignedWrap() || |
2074 | (GEP->hasNoUnsignedSignedWrap() && |
2075 | isKnownNonNegative(V: Offset, SQ: SQ.getWithInstruction(I: &CI)))) |
2076 | NewOp->setHasNoUnsignedWrap(true); |
2077 | return NewOp; |
2078 | } |
2079 | } |
2080 | |
2081 | Value *Vec, *Scalar, *Index; |
2082 | if (match(V: SrcOp, P: m_OneUse(SubPattern: m_InsertElt(Val: m_IntToPtr(Op: m_Value(V&: Vec)), |
2083 | Elt: m_Value(V&: Scalar), Idx: m_Value(V&: Index)))) && |
2084 | Vec->getType() == Ty) { |
2085 | assert(Vec->getType()->getScalarSizeInBits() == PtrSize && "Wrong type" ); |
2086 | // Convert the scalar to int followed by insert to eliminate one cast: |
2087 | // p2i (ins (i2p Vec), Scalar, Index) --> ins Vec, (p2i Scalar), Index
2088 | Value *NewCast = Builder.CreatePtrToInt(V: Scalar, DestTy: Ty->getScalarType()); |
2089 | return InsertElementInst::Create(Vec, NewElt: NewCast, Idx: Index); |
2090 | } |
2091 | |
2092 | return commonCastTransforms(CI); |
2093 | } |
2094 | |
2095 | /// This input value (which is known to have vector type) is being zero extended |
2096 | /// or truncated to the specified vector type. Since the zext/trunc is done |
2097 | /// using an integer type, we have a (bitcast(cast(bitcast))) pattern, and
2098 | /// endianness will impact which end of the vector is extended or
2099 | /// truncated.
2100 | /// |
2101 | /// A vector is always stored with index 0 at the lowest address, which |
2102 | /// corresponds to the most significant bits for a big endian stored integer and |
2103 | /// the least significant bits for little endian. A trunc/zext of an integer |
2104 | /// impacts the big end of the integer. Thus, we need to add/remove elements at |
2105 | /// the front of the vector for big endian targets, and the back of the vector |
2106 | /// for little endian targets. |
2107 | /// |
2108 | /// Try to replace it with a shuffle (and vector/vector bitcast) if possible. |
2109 | /// |
2110 | /// The source and destination vector types may have different element types. |
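     | ///
     | /// For example (illustrative), on a little-endian target the zext'ing resize
     | ///   bitcast (zext (bitcast <2 x i32> %v to i64) to i128) to <4 x i32>
     | /// can become
     | ///   shufflevector <2 x i32> %v, <2 x i32> zeroinitializer,
     | ///                 <4 x i32> <i32 0, i32 1, i32 2, i32 2>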
2111 | static Instruction * |
2112 | optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy, |
2113 | InstCombinerImpl &IC) { |
2114 | // We can only do this optimization if the output is a multiple of the input |
2115 | // element size, or the input is a multiple of the output element size. |
2116 | // Convert the input type to have the same element type as the output. |
2117 | VectorType *SrcTy = cast<VectorType>(Val: InVal->getType()); |
2118 | |
2119 | if (SrcTy->getElementType() != DestTy->getElementType()) { |
2120 | // The input types don't need to be identical, but for now they must be the |
2121 | // same size. There is no specific reason we couldn't handle things like |
2122 | // <4 x i16> -> <4 x i32> by bitcasting to <2 x i32>, but we haven't gotten
2123 | // there yet. |
2124 | if (SrcTy->getElementType()->getPrimitiveSizeInBits() != |
2125 | DestTy->getElementType()->getPrimitiveSizeInBits()) |
2126 | return nullptr; |
2127 | |
2128 | SrcTy = |
2129 | FixedVectorType::get(ElementType: DestTy->getElementType(), |
2130 | NumElts: cast<FixedVectorType>(Val: SrcTy)->getNumElements()); |
2131 | InVal = IC.Builder.CreateBitCast(V: InVal, DestTy: SrcTy); |
2132 | } |
2133 | |
2134 | bool IsBigEndian = IC.getDataLayout().isBigEndian(); |
2135 | unsigned SrcElts = cast<FixedVectorType>(Val: SrcTy)->getNumElements(); |
2136 | unsigned DestElts = cast<FixedVectorType>(Val: DestTy)->getNumElements(); |
2137 | |
2138 | assert(SrcElts != DestElts && "Element counts should be different." ); |
2139 | |
2140 | // Now that the element types match, get the shuffle mask and RHS of the |
2141 | // shuffle to use, which depends on whether we're increasing or decreasing the |
2142 | // size of the input. |
2143 | auto ShuffleMaskStorage = llvm::to_vector<16>(Range: llvm::seq<int>(Begin: 0, End: SrcElts)); |
2144 | ArrayRef<int> ShuffleMask; |
2145 | Value *V2; |
2146 | |
2147 | if (SrcElts > DestElts) { |
2148 | // If we're shrinking the number of elements (rewriting an integer |
2149 | // truncate), just shuffle in the elements corresponding to the least |
2150 | // significant bits from the input and use poison as the second shuffle |
2151 | // input. |
2152 | V2 = PoisonValue::get(T: SrcTy); |
2153 | // Make sure the shuffle mask selects the "least significant bits" by |
2154 | // keeping elements from the back of the src vector for big endian, and from the
2155 | // front for little endian. |
2156 | ShuffleMask = ShuffleMaskStorage; |
2157 | if (IsBigEndian) |
2158 | ShuffleMask = ShuffleMask.take_back(N: DestElts); |
2159 | else |
2160 | ShuffleMask = ShuffleMask.take_front(N: DestElts); |
2161 | } else { |
2162 | // If we're increasing the number of elements (rewriting an integer zext), |
2163 | // shuffle in all of the elements from InVal. Fill the rest of the result |
2164 | // elements with zeros from a constant zero. |
2165 | V2 = Constant::getNullValue(Ty: SrcTy); |
2166 | // Use first elt from V2 when indicating zero in the shuffle mask. |
2167 | uint32_t NullElt = SrcElts; |
2168 | // Extend with null values in the "most significant bits" by adding elements |
2169 | // in front of the src vector for big endian, and at the back for little |
2170 | // endian. |
2171 | unsigned DeltaElts = DestElts - SrcElts; |
2172 | if (IsBigEndian) |
2173 | ShuffleMaskStorage.insert(I: ShuffleMaskStorage.begin(), NumToInsert: DeltaElts, Elt: NullElt); |
2174 | else |
2175 | ShuffleMaskStorage.append(NumInputs: DeltaElts, Elt: NullElt); |
2176 | ShuffleMask = ShuffleMaskStorage; |
2177 | } |
2178 | |
2179 | return new ShuffleVectorInst(InVal, V2, ShuffleMask); |
2180 | } |
2181 | |
2182 | static bool isMultipleOfTypeSize(unsigned Value, Type *Ty) { |
2183 | return Value % Ty->getPrimitiveSizeInBits() == 0; |
2184 | } |
2185 | |
2186 | static unsigned getTypeSizeIndex(unsigned Value, Type *Ty) { |
2187 | return Value / Ty->getPrimitiveSizeInBits(); |
2188 | } |
2189 | |
2190 | /// V is a value which is inserted into a vector of VecEltTy. |
2191 | /// Look through the value to see if we can decompose it into |
2192 | /// insertions into the vector. See the example in the comment for |
2193 | /// OptimizeIntegerToVectorInsertions for the pattern this handles. |
2194 | /// The size of V's type is always a non-zero multiple of VecEltTy's size.
2195 | /// Shift is the number of bits between the lsb of V and the lsb of |
2196 | /// the vector. |
2197 | /// |
2198 | /// This returns false if the pattern can't be matched or true if it can, |
2199 | /// filling in Elements with the elements found here. |
2200 | static bool collectInsertionElements(Value *V, unsigned Shift, |
2201 | SmallVectorImpl<Value *> &Elements, |
2202 | Type *VecEltTy, bool isBigEndian) { |
2203 | assert(isMultipleOfTypeSize(Shift, VecEltTy) && |
2204 | "Shift should be a multiple of the element type size" ); |
2205 | |
2206 | // Undef values never contribute useful bits to the result. |
2207 | if (isa<UndefValue>(Val: V)) return true; |
2208 | |
2209 | // If we got down to a value of the right type, we win; try inserting into the
2210 | // right element. |
2211 | if (V->getType() == VecEltTy) { |
2212 | // Inserting null doesn't actually insert any elements. |
2213 | if (Constant *C = dyn_cast<Constant>(Val: V)) |
2214 | if (C->isNullValue()) |
2215 | return true; |
2216 | |
2217 | unsigned ElementIndex = getTypeSizeIndex(Value: Shift, Ty: VecEltTy); |
2218 | if (isBigEndian) |
2219 | ElementIndex = Elements.size() - ElementIndex - 1; |
2220 | |
2221 | // Fail if multiple elements are inserted into this slot. |
2222 | if (Elements[ElementIndex]) |
2223 | return false; |
2224 | |
2225 | Elements[ElementIndex] = V; |
2226 | return true; |
2227 | } |
2228 | |
2229 | if (Constant *C = dyn_cast<Constant>(Val: V)) { |
2230 | // Figure out the # elements this provides, and bitcast it or slice it up |
2231 | // as required. |
2232 | unsigned NumElts = getTypeSizeIndex(Value: C->getType()->getPrimitiveSizeInBits(), |
2233 | Ty: VecEltTy); |
2234 | // If the constant is the size of a vector element, we just need to bitcast |
2235 | // it to the right type so it gets properly inserted. |
2236 | if (NumElts == 1) |
2237 | return collectInsertionElements(V: ConstantExpr::getBitCast(C, Ty: VecEltTy), |
2238 | Shift, Elements, VecEltTy, isBigEndian); |
2239 | |
2240 | // Okay, this is a constant that covers multiple elements. Slice it up into |
2241 | // pieces and insert each element-sized piece into the vector. |
2242 | if (!isa<IntegerType>(Val: C->getType())) |
2243 | C = ConstantExpr::getBitCast(C, Ty: IntegerType::get(C&: V->getContext(), |
2244 | NumBits: C->getType()->getPrimitiveSizeInBits())); |
2245 | unsigned ElementSize = VecEltTy->getPrimitiveSizeInBits(); |
2246 | Type *ElementIntTy = IntegerType::get(C&: C->getContext(), NumBits: ElementSize); |
2247 | |
2248 | for (unsigned i = 0; i != NumElts; ++i) { |
2249 | unsigned ShiftI = i * ElementSize; |
2250 | Constant *Piece = ConstantFoldBinaryInstruction( |
2251 | Opcode: Instruction::LShr, V1: C, V2: ConstantInt::get(Ty: C->getType(), V: ShiftI)); |
2252 | if (!Piece) |
2253 | return false; |
2254 | |
2255 | Piece = ConstantExpr::getTrunc(C: Piece, Ty: ElementIntTy); |
2256 | if (!collectInsertionElements(V: Piece, Shift: ShiftI + Shift, Elements, VecEltTy, |
2257 | isBigEndian)) |
2258 | return false; |
2259 | } |
2260 | return true; |
2261 | } |
2262 | |
2263 | if (!V->hasOneUse()) return false; |
2264 | |
2265 | Instruction *I = dyn_cast<Instruction>(Val: V); |
2266 | if (!I) return false; |
2267 | switch (I->getOpcode()) { |
2268 | default: return false; // Unhandled case. |
2269 | case Instruction::BitCast: |
2270 | if (I->getOperand(i: 0)->getType()->isVectorTy()) |
2271 | return false; |
2272 | return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy, |
2273 | isBigEndian); |
2274 | case Instruction::ZExt: |
2275 | if (!isMultipleOfTypeSize( |
2276 | Value: I->getOperand(i: 0)->getType()->getPrimitiveSizeInBits(), |
2277 | Ty: VecEltTy)) |
2278 | return false; |
2279 | return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy, |
2280 | isBigEndian); |
2281 | case Instruction::Or: |
2282 | return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy, |
2283 | isBigEndian) && |
2284 | collectInsertionElements(V: I->getOperand(i: 1), Shift, Elements, VecEltTy, |
2285 | isBigEndian); |
2286 | case Instruction::Shl: { |
2287 | // Must be shifting by a constant that is a multiple of the element size. |
2288 | ConstantInt *CI = dyn_cast<ConstantInt>(Val: I->getOperand(i: 1)); |
2289 | if (!CI) return false; |
2290 | Shift += CI->getZExtValue(); |
2291 | if (!isMultipleOfTypeSize(Value: Shift, Ty: VecEltTy)) return false; |
2292 | return collectInsertionElements(V: I->getOperand(i: 0), Shift, Elements, VecEltTy, |
2293 | isBigEndian); |
2294 | } |
2295 | |
2296 | } |
2297 | } |
2298 | |
2299 | |
2300 | /// If the input is an 'or' instruction, we may be doing shifts and ors to |
2301 | /// assemble the elements of the vector manually. |
2302 | /// Try to rip the code out and replace it with insertelements. This is to |
2303 | /// optimize code like this: |
2304 | /// |
2305 | /// %tmp37 = bitcast float %inc to i32 |
2306 | /// %tmp38 = zext i32 %tmp37 to i64 |
2307 | /// %tmp31 = bitcast float %inc5 to i32 |
2308 | /// %tmp32 = zext i32 %tmp31 to i64 |
2309 | /// %tmp33 = shl i64 %tmp32, 32 |
2310 | /// %ins35 = or i64 %tmp33, %tmp38 |
2311 | /// %tmp43 = bitcast i64 %ins35 to <2 x float> |
2312 | /// |
2313 | /// Into two insertelements that do "buildvector{%inc, %inc5}". |
2314 | static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI, |
2315 | InstCombinerImpl &IC) { |
2316 | auto *DestVecTy = cast<FixedVectorType>(Val: CI.getType()); |
2317 | Value *IntInput = CI.getOperand(i_nocapture: 0); |
2318 | |
2319 | SmallVector<Value*, 8> Elements(DestVecTy->getNumElements()); |
2320 | if (!collectInsertionElements(V: IntInput, Shift: 0, Elements, |
2321 | VecEltTy: DestVecTy->getElementType(), |
2322 | isBigEndian: IC.getDataLayout().isBigEndian())) |
2323 | return nullptr; |
2324 | |
2325 | // If we succeeded, we know that all of the element are specified by Elements |
2326 | // or are zero if Elements has a null entry. Recast this as a set of |
2327 | // insertions. |
2328 | Value *Result = Constant::getNullValue(Ty: CI.getType()); |
2329 | for (unsigned i = 0, e = Elements.size(); i != e; ++i) { |
2330 | if (!Elements[i]) continue; // Unset element. |
2331 | |
2332 | Result = IC.Builder.CreateInsertElement(Vec: Result, NewElt: Elements[i], |
2333 | Idx: IC.Builder.getInt32(C: i)); |
2334 | } |
2335 | |
2336 | return Result; |
2337 | } |
2338 | |
2339 | /// Canonicalize scalar bitcasts of extracted elements into a bitcast of the |
2340 | /// vector followed by extract element. The backend tends to handle bitcasts of |
2341 | /// vectors better than bitcasts of scalars because vector registers are |
2342 | /// usually not type-specific like scalar integer or scalar floating-point. |
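     | ///
     | /// For example (illustrative):
     | ///   %e = extractelement <4 x i32> %v, i64 %i
     | ///   %f = bitcast i32 %e to float
     | /// becomes
     | ///   %bc = bitcast <4 x i32> %v to <4 x float>
     | ///   %f = extractelement <4 x float> %bc, i64 %i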
2343 | static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, |
2344 | InstCombinerImpl &IC) { |
2345 | Value *VecOp, *Index; |
2346 | if (!match(V: BitCast.getOperand(i_nocapture: 0), |
2347 | P: m_OneUse(SubPattern: m_ExtractElt(Val: m_Value(V&: VecOp), Idx: m_Value(V&: Index))))) |
2348 | return nullptr; |
2349 | |
2350 | // The bitcast must be to a vectorizable type, otherwise we can't make a new |
2351 | // type to extract from. |
2352 | Type *DestType = BitCast.getType(); |
2353 | VectorType *VecType = cast<VectorType>(Val: VecOp->getType()); |
2354 | if (VectorType::isValidElementType(ElemTy: DestType)) { |
2355 | auto *NewVecType = VectorType::get(ElementType: DestType, Other: VecType); |
2356 | auto *NewBC = IC.Builder.CreateBitCast(V: VecOp, DestTy: NewVecType, Name: "bc" ); |
2357 | return ExtractElementInst::Create(Vec: NewBC, Idx: Index); |
2358 | } |
2359 | |
2360 | // Only handle a vector DestType, to avoid the inverse transform in visitBitCast.
2361 | // bitcast (extractelement <1 x elt>, dest) -> bitcast(<1 x elt>, dest) |
2362 | auto *FixedVType = dyn_cast<FixedVectorType>(Val: VecType); |
2363 | if (DestType->isVectorTy() && FixedVType && FixedVType->getNumElements() == 1) |
2364 | return CastInst::Create(Instruction::BitCast, S: VecOp, Ty: DestType); |
2365 | |
2366 | return nullptr; |
2367 | } |
2368 | |
2369 | /// Change the type of a bitwise logic operation if we can eliminate a bitcast. |
2370 | static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, |
2371 | InstCombiner::BuilderTy &Builder) { |
2372 | Type *DestTy = BitCast.getType(); |
2373 | BinaryOperator *BO; |
2374 | |
2375 | if (!match(V: BitCast.getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_BinOp(I&: BO))) || |
2376 | !BO->isBitwiseLogicOp()) |
2377 | return nullptr; |
2378 | |
2379 | // FIXME: This transform is restricted to vector types to avoid backend |
2380 | // problems caused by creating potentially illegal operations. If a fix-up is |
2381 | // added to handle that situation, we can remove this check. |
2382 | if (!DestTy->isVectorTy() || !BO->getType()->isVectorTy()) |
2383 | return nullptr; |
2384 | |
2385 | if (DestTy->isFPOrFPVectorTy()) { |
2386 | Value *X, *Y; |
2387 | // bitcast(logic(bitcast(X), bitcast(Y))) -> bitcast'(logic(bitcast'(X), Y)) |
2388 | if (match(V: BO->getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) && |
2389 | match(V: BO->getOperand(i_nocapture: 1), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: Y))))) { |
2390 | if (X->getType()->isFPOrFPVectorTy() && |
2391 | Y->getType()->isIntOrIntVectorTy()) { |
2392 | Value *CastedOp = |
2393 | Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 0), DestTy: Y->getType()); |
2394 | Value *NewBO = Builder.CreateBinOp(Opc: BO->getOpcode(), LHS: CastedOp, RHS: Y); |
2395 | return CastInst::CreateBitOrPointerCast(S: NewBO, Ty: DestTy); |
2396 | } |
2397 | if (X->getType()->isIntOrIntVectorTy() && |
2398 | Y->getType()->isFPOrFPVectorTy()) { |
2399 | Value *CastedOp = |
2400 | Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 1), DestTy: X->getType()); |
2401 | Value *NewBO = Builder.CreateBinOp(Opc: BO->getOpcode(), LHS: CastedOp, RHS: X); |
2402 | return CastInst::CreateBitOrPointerCast(S: NewBO, Ty: DestTy); |
2403 | } |
2404 | } |
2405 | return nullptr; |
2406 | } |
2407 | |
2408 | if (!DestTy->isIntOrIntVectorTy()) |
2409 | return nullptr; |
2410 | |
2411 | Value *X; |
2412 | if (match(V: BO->getOperand(i_nocapture: 0), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) && |
2413 | X->getType() == DestTy && !isa<Constant>(Val: X)) { |
2414 | // bitcast(logic(bitcast(X), Y)) --> logic'(X, bitcast(Y)) |
2415 | Value *CastedOp1 = Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 1), DestTy); |
2416 | return BinaryOperator::Create(Op: BO->getOpcode(), S1: X, S2: CastedOp1); |
2417 | } |
2418 | |
2419 | if (match(V: BO->getOperand(i_nocapture: 1), P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) && |
2420 | X->getType() == DestTy && !isa<Constant>(Val: X)) { |
2421 | // bitcast(logic(Y, bitcast(X))) --> logic'(bitcast(Y), X) |
2422 | Value *CastedOp0 = Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 0), DestTy); |
2423 | return BinaryOperator::Create(Op: BO->getOpcode(), S1: CastedOp0, S2: X); |
2424 | } |
2425 | |
2426 | // Canonicalize vector bitcasts to come before vector bitwise logic with a |
2427 | // constant. This eases recognition of special constants for later ops. |
2428 | // Example: |
2429 | // icmp u/s (a ^ signmask), (b ^ signmask) --> icmp s/u a, b |
2430 | Constant *C; |
2431 | if (match(V: BO->getOperand(i_nocapture: 1), P: m_Constant(C))) { |
2432 | // bitcast (logic X, C) --> logic (bitcast X, C') |
2433 | Value *CastedOp0 = Builder.CreateBitCast(V: BO->getOperand(i_nocapture: 0), DestTy); |
2434 | Value *CastedC = Builder.CreateBitCast(V: C, DestTy); |
2435 | return BinaryOperator::Create(Op: BO->getOpcode(), S1: CastedOp0, S2: CastedC); |
2436 | } |
2437 | |
2438 | return nullptr; |
2439 | } |
2440 | |
2441 | /// Change the type of a select if we can eliminate a bitcast. |
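     | /// For example (an illustrative sketch; the value names are made up):
     | ///   %bcx = bitcast <4 x i32> %X to <2 x i64>
     | ///   %sel = select i1 %cond, <2 x i64> %bcx, <2 x i64> %Y
     | ///   %res = bitcast <2 x i64> %sel to <4 x i32>
     | /// becomes:
     | ///   %bcy = bitcast <2 x i64> %Y to <4 x i32>
     | ///   %res = select i1 %cond, <4 x i32> %X, <4 x i32> %bcy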
2442 | static Instruction *foldBitCastSelect(BitCastInst &BitCast, |
2443 | InstCombiner::BuilderTy &Builder) { |
2444 | Value *Cond, *TVal, *FVal; |
2445 | if (!match(V: BitCast.getOperand(i_nocapture: 0), |
2446 | P: m_OneUse(SubPattern: m_Select(C: m_Value(V&: Cond), L: m_Value(V&: TVal), R: m_Value(V&: FVal))))) |
2447 | return nullptr; |
2448 | |
2449 | // A vector select must maintain the same number of elements in its operands. |
2450 | Type *CondTy = Cond->getType(); |
2451 | Type *DestTy = BitCast.getType(); |
2452 | if (auto *CondVTy = dyn_cast<VectorType>(Val: CondTy)) |
2453 | if (!DestTy->isVectorTy() || |
2454 | CondVTy->getElementCount() != |
2455 | cast<VectorType>(Val: DestTy)->getElementCount()) |
2456 | return nullptr; |
2457 | |
2458 | // FIXME: This transform is restricted from changing the select between |
2459 | // scalars and vectors to avoid backend problems caused by creating |
2460 | // potentially illegal operations. If a fix-up is added to handle that |
2461 | // situation, we can remove this check. |
2462 | if (DestTy->isVectorTy() != TVal->getType()->isVectorTy()) |
2463 | return nullptr; |
2464 | |
2465 | auto *Sel = cast<Instruction>(Val: BitCast.getOperand(i_nocapture: 0)); |
2466 | Value *X; |
2467 | if (match(V: TVal, P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) && X->getType() == DestTy && |
2468 | !isa<Constant>(Val: X)) { |
2469 | // bitcast(select(Cond, bitcast(X), Y)) --> select'(Cond, X, bitcast(Y)) |
2470 | Value *CastedVal = Builder.CreateBitCast(V: FVal, DestTy); |
2471 |     return SelectInst::Create(C: Cond, S1: X, S2: CastedVal, NameStr: "", InsertBefore: nullptr, MDFrom: Sel);
2472 | } |
2473 | |
2474 | if (match(V: FVal, P: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X)))) && X->getType() == DestTy && |
2475 | !isa<Constant>(Val: X)) { |
2476 | // bitcast(select(Cond, Y, bitcast(X))) --> select'(Cond, bitcast(Y), X) |
2477 | Value *CastedVal = Builder.CreateBitCast(V: TVal, DestTy); |
2478 |     return SelectInst::Create(C: Cond, S1: CastedVal, S2: X, NameStr: "", InsertBefore: nullptr, MDFrom: Sel);
2479 | } |
2480 | |
2481 | return nullptr; |
2482 | } |
2483 | |
2484 | /// Check if all users of CI are StoreInsts. |
2485 | static bool hasStoreUsersOnly(CastInst &CI) { |
2486 | for (User *U : CI.users()) { |
2487 | if (!isa<StoreInst>(Val: U)) |
2488 | return false; |
2489 | } |
2490 | return true; |
2491 | } |
2492 | |
2493 | /// This function handles the following case:
2494 | /// |
2495 | /// A -> B cast |
2496 | /// PHI |
2497 | /// B -> A cast |
2498 | /// |
2499 | /// All the related PHI nodes can be replaced by new PHI nodes with type A. |
2500 | /// The uses of \p CI can be changed to the new PHI node corresponding to \p PN. |
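     | ///
     | /// For example (an illustrative sketch; the value names are made up):
     | ///   bb0:  %a.c = bitcast double %a to i64           ; A -> B
     | ///   bb1:  %b.c = bitcast double %b to i64           ; A -> B
     | ///   bb2:  %p   = phi i64 [ %a.c, %bb0 ], [ %b.c, %bb1 ]
     | ///         %r   = bitcast i64 %p to double            ; B -> A
     | /// can be rewritten as a PHI over double, making all three bitcasts dead.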
2501 | Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI, |
2502 | PHINode *PN) { |
2503 | // BitCast used by Store can be handled in InstCombineLoadStoreAlloca.cpp. |
2504 | if (hasStoreUsersOnly(CI)) |
2505 | return nullptr; |
2506 | |
2507 | Value *Src = CI.getOperand(i_nocapture: 0); |
2508 | Type *SrcTy = Src->getType(); // Type B |
2509 | Type *DestTy = CI.getType(); // Type A |
2510 | |
2511 | SmallVector<PHINode *, 4> PhiWorklist; |
2512 | SmallSetVector<PHINode *, 4> OldPhiNodes; |
2513 | |
2514 | // Find all of the A->B casts and PHI nodes. |
2515 | // We need to inspect all related PHI nodes, but PHIs can be cyclic, so |
2516 |   // OldPhiNodes is used to track all known PHI nodes; before a new PHI is
2517 |   // added to PhiWorklist, it is checked against and added to OldPhiNodes first.
2518 | PhiWorklist.push_back(Elt: PN); |
2519 | OldPhiNodes.insert(X: PN); |
2520 | while (!PhiWorklist.empty()) { |
2521 | auto *OldPN = PhiWorklist.pop_back_val(); |
2522 | for (Value *IncValue : OldPN->incoming_values()) { |
2523 | if (isa<Constant>(Val: IncValue)) |
2524 | continue; |
2525 | |
2526 | if (auto *LI = dyn_cast<LoadInst>(Val: IncValue)) { |
2527 |         // If there is a sequence of one or more load instructions, where each
2528 |         // loaded value is used as the address of a later load, the bitcast is
2529 |         // necessary to change the value type, so don't optimize it. For
2530 |         // simplicity we give up if the load address comes from another load.
2531 | Value *Addr = LI->getOperand(i_nocapture: 0); |
2532 | if (Addr == &CI || isa<LoadInst>(Val: Addr)) |
2533 | return nullptr; |
2534 |         // Don't transform "load <256 x i32>, <256 x i32>*" to
2535 | // "load x86_amx, x86_amx*", because x86_amx* is invalid. |
2536 | // TODO: Remove this check when bitcast between vector and x86_amx |
2537 | // is replaced with a specific intrinsic. |
2538 | if (DestTy->isX86_AMXTy()) |
2539 | return nullptr; |
2540 | if (LI->hasOneUse() && LI->isSimple()) |
2541 | continue; |
2542 |         // If a LoadInst has more than one use, changing the type of the
2543 |         // loaded value may create another bitcast.
2544 | return nullptr; |
2545 | } |
2546 | |
2547 | if (auto *PNode = dyn_cast<PHINode>(Val: IncValue)) { |
2548 | if (OldPhiNodes.insert(X: PNode)) |
2549 | PhiWorklist.push_back(Elt: PNode); |
2550 | continue; |
2551 | } |
2552 | |
2553 | auto *BCI = dyn_cast<BitCastInst>(Val: IncValue); |
2554 | // We can't handle other instructions. |
2555 | if (!BCI) |
2556 | return nullptr; |
2557 | |
2558 |       // Verify it's an A->B cast.
2559 | Type *TyA = BCI->getOperand(i_nocapture: 0)->getType(); |
2560 | Type *TyB = BCI->getType(); |
2561 | if (TyA != DestTy || TyB != SrcTy) |
2562 | return nullptr; |
2563 | } |
2564 | } |
2565 | |
2566 | // Check that each user of each old PHI node is something that we can |
2567 | // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. |
2568 | for (auto *OldPN : OldPhiNodes) { |
2569 | for (User *V : OldPN->users()) { |
2570 | if (auto *SI = dyn_cast<StoreInst>(Val: V)) { |
2571 | if (!SI->isSimple() || SI->getOperand(i_nocapture: 0) != OldPN) |
2572 | return nullptr; |
2573 | } else if (auto *BCI = dyn_cast<BitCastInst>(Val: V)) { |
2574 | // Verify it's a B->A cast. |
2575 | Type *TyB = BCI->getOperand(i_nocapture: 0)->getType(); |
2576 | Type *TyA = BCI->getType(); |
2577 | if (TyA != DestTy || TyB != SrcTy) |
2578 | return nullptr; |
2579 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
2580 |         // If the user is another old PHI node, then even if we don't
2581 | // rewrite it, the PHI web we're considering won't have any users |
2582 | // outside itself, so it'll be dead. |
2583 | if (!OldPhiNodes.contains(key: PHI)) |
2584 | return nullptr; |
2585 | } else { |
2586 | return nullptr; |
2587 | } |
2588 | } |
2589 | } |
2590 | |
2591 |   // For each old PHI node, create a corresponding new PHI node with type A.
2592 | SmallDenseMap<PHINode *, PHINode *> NewPNodes; |
2593 | for (auto *OldPN : OldPhiNodes) { |
2594 | Builder.SetInsertPoint(OldPN); |
2595 | PHINode *NewPN = Builder.CreatePHI(Ty: DestTy, NumReservedValues: OldPN->getNumOperands()); |
2596 | NewPNodes[OldPN] = NewPN; |
2597 | } |
2598 | |
2599 | // Fill in the operands of new PHI nodes. |
2600 | for (auto *OldPN : OldPhiNodes) { |
2601 | PHINode *NewPN = NewPNodes[OldPN]; |
2602 | for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) { |
2603 | Value *V = OldPN->getOperand(i_nocapture: j); |
2604 | Value *NewV = nullptr; |
2605 | if (auto *C = dyn_cast<Constant>(Val: V)) { |
2606 | NewV = ConstantExpr::getBitCast(C, Ty: DestTy); |
2607 | } else if (auto *LI = dyn_cast<LoadInst>(Val: V)) { |
2608 | // Explicitly perform load combine to make sure no opposing transform |
2609 | // can remove the bitcast in the meantime and trigger an infinite loop. |
2610 | Builder.SetInsertPoint(LI); |
2611 | NewV = combineLoadToNewType(LI&: *LI, NewTy: DestTy); |
2612 | // Remove the old load and its use in the old phi, which itself becomes |
2613 | // dead once the whole transform finishes. |
2614 | replaceInstUsesWith(I&: *LI, V: PoisonValue::get(T: LI->getType())); |
2615 | eraseInstFromFunction(I&: *LI); |
2616 | } else if (auto *BCI = dyn_cast<BitCastInst>(Val: V)) { |
2617 | NewV = BCI->getOperand(i_nocapture: 0); |
2618 | } else if (auto *PrevPN = dyn_cast<PHINode>(Val: V)) { |
2619 | NewV = NewPNodes[PrevPN]; |
2620 | } |
2621 | assert(NewV); |
2622 | NewPN->addIncoming(V: NewV, BB: OldPN->getIncomingBlock(i: j)); |
2623 | } |
2624 | } |
2625 | |
2626 |   // Traverse all accumulated PHI nodes and process their users,
2627 |   // which are Stores and BitCasts. Without this processing,
2628 |   // NewPHI nodes could be replicated and could lead to extra
2629 |   // moves generated after DeSSA.
2630 |   // If there is a store with type B, change it to type A.
2631 |
2632 |
2633 |   // Replace users of BitCast B->A with NewPHI. These will help
2634 |   // later to get rid of the closure formed by the OldPHI nodes.
2635 | Instruction *RetVal = nullptr; |
2636 | for (auto *OldPN : OldPhiNodes) { |
2637 | PHINode *NewPN = NewPNodes[OldPN]; |
2638 | for (User *V : make_early_inc_range(Range: OldPN->users())) { |
2639 | if (auto *SI = dyn_cast<StoreInst>(Val: V)) { |
2640 | assert(SI->isSimple() && SI->getOperand(0) == OldPN); |
2641 | Builder.SetInsertPoint(SI); |
2642 | auto *NewBC = |
2643 | cast<BitCastInst>(Val: Builder.CreateBitCast(V: NewPN, DestTy: SrcTy)); |
2644 | SI->setOperand(i_nocapture: 0, Val_nocapture: NewBC); |
2645 | Worklist.push(I: SI); |
2646 | assert(hasStoreUsersOnly(*NewBC)); |
2647 | } |
2648 | else if (auto *BCI = dyn_cast<BitCastInst>(Val: V)) { |
2649 | Type *TyB = BCI->getOperand(i_nocapture: 0)->getType(); |
2650 | Type *TyA = BCI->getType(); |
2651 | assert(TyA == DestTy && TyB == SrcTy); |
2652 | (void) TyA; |
2653 | (void) TyB; |
2654 | Instruction *I = replaceInstUsesWith(I&: *BCI, V: NewPN); |
2655 | if (BCI == &CI) |
2656 | RetVal = I; |
2657 | } else if (auto *PHI = dyn_cast<PHINode>(Val: V)) { |
2658 | assert(OldPhiNodes.contains(PHI)); |
2659 | (void) PHI; |
2660 | } else { |
2661 |         llvm_unreachable("all uses should be handled");
2662 | } |
2663 | } |
2664 | } |
2665 | |
2666 | return RetVal; |
2667 | } |
2668 | |
2669 | Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { |
2670 | // If the operands are integer typed then apply the integer transforms, |
2671 | // otherwise just apply the common ones. |
2672 | Value *Src = CI.getOperand(i_nocapture: 0); |
2673 | Type *SrcTy = Src->getType(); |
2674 | Type *DestTy = CI.getType(); |
2675 | |
2676 | // Get rid of casts from one type to the same type. These are useless and can |
2677 | // be replaced by the operand. |
2678 | if (DestTy == Src->getType()) |
2679 | return replaceInstUsesWith(I&: CI, V: Src); |
2680 | |
2681 | if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(Val: DestTy)) { |
2682 | // Beware: messing with this target-specific oddity may cause trouble. |
2683 | if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) { |
2684 | Value *Elem = Builder.CreateBitCast(V: Src, DestTy: DestVTy->getElementType()); |
2685 | return InsertElementInst::Create(Vec: PoisonValue::get(T: DestTy), NewElt: Elem, |
2686 | Idx: Constant::getNullValue(Ty: Type::getInt32Ty(C&: CI.getContext()))); |
2687 | } |
2688 | |
2689 | if (isa<IntegerType>(Val: SrcTy)) { |
2690 | // If this is a cast from an integer to vector, check to see if the input |
2691 | // is a trunc or zext of a bitcast from vector. If so, we can replace all |
2692 | // the casts with a shuffle and (potentially) a bitcast. |
2693 | if (isa<TruncInst>(Val: Src) || isa<ZExtInst>(Val: Src)) { |
2694 | CastInst *SrcCast = cast<CastInst>(Val: Src); |
2695 | if (BitCastInst *BCIn = dyn_cast<BitCastInst>(Val: SrcCast->getOperand(i_nocapture: 0))) |
2696 | if (isa<VectorType>(Val: BCIn->getOperand(i_nocapture: 0)->getType())) |
2697 | if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts( |
2698 | InVal: BCIn->getOperand(i_nocapture: 0), DestTy: cast<VectorType>(Val: DestTy), IC&: *this)) |
2699 | return I; |
2700 | } |
2701 | |
2702 | // If the input is an 'or' instruction, we may be doing shifts and ors to |
2703 | // assemble the elements of the vector manually. Try to rip the code out |
2704 | // and replace it with insertelements. |
2705 | if (Value *V = optimizeIntegerToVectorInsertions(CI, IC&: *this)) |
2706 | return replaceInstUsesWith(I&: CI, V); |
2707 | } |
2708 | } |
2709 | |
2710 | if (FixedVectorType *SrcVTy = dyn_cast<FixedVectorType>(Val: SrcTy)) { |
2711 | if (SrcVTy->getNumElements() == 1) { |
2712 | // If our destination is not a vector, then make this a straight |
2713 | // scalar-scalar cast. |
2714 | if (!DestTy->isVectorTy()) { |
2715 | Value *Elem = |
2716 | Builder.CreateExtractElement(Vec: Src, |
2717 | Idx: Constant::getNullValue(Ty: Type::getInt32Ty(C&: CI.getContext()))); |
2718 | return CastInst::Create(Instruction::BitCast, S: Elem, Ty: DestTy); |
2719 | } |
2720 | |
2721 | // Otherwise, see if our source is an insert. If so, then use the scalar |
2722 | // component directly: |
2723 | // bitcast (inselt <1 x elt> V, X, 0) to <n x m> --> bitcast X to <n x m> |
2724 | if (auto *InsElt = dyn_cast<InsertElementInst>(Val: Src)) |
2725 | return new BitCastInst(InsElt->getOperand(i_nocapture: 1), DestTy); |
2726 | } |
2727 | |
2728 | // Convert an artificial vector insert into more analyzable bitwise logic. |
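     |     // For example (an illustrative sketch, assuming a little-endian target;
     |     // the value names are made up):
     |     //   %v   = bitcast i32 %X to <4 x i8>
     |     //   %ins = insertelement <4 x i8> %v, i8 %Y, i64 0
     |     //   %res = bitcast <4 x i8> %ins to i32
     |     // becomes:
     |     //   %and = and i32 %X, -256
     |     //   %z   = zext i8 %Y to i32
     |     //   %res = or i32 %and, %z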
2729 | unsigned BitWidth = DestTy->getScalarSizeInBits(); |
2730 | Value *X, *Y; |
2731 | uint64_t IndexC; |
2732 | if (match(V: Src, P: m_OneUse(SubPattern: m_InsertElt(Val: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: X))), |
2733 | Elt: m_Value(V&: Y), Idx: m_ConstantInt(V&: IndexC)))) && |
2734 | DestTy->isIntegerTy() && X->getType() == DestTy && |
2735 | Y->getType()->isIntegerTy() && isDesirableIntType(BitWidth)) { |
2736 | // Adjust for big endian - the LSBs are at the high index. |
2737 | if (DL.isBigEndian()) |
2738 | IndexC = SrcVTy->getNumElements() - 1 - IndexC; |
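     |       // E.g. (illustrative): for a <4 x i8> source and an i32 destination on
     |       // a big-endian target, an insert at index 3 holds the low byte, so the
     |       // adjusted index is 0.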
2739 | |
2740 | // We only handle (endian-normalized) insert to index 0. Any other insert |
2741 | // would require a left-shift, so that is an extra instruction. |
2742 | if (IndexC == 0) { |
2743 | // bitcast (inselt (bitcast X), Y, 0) --> or (and X, MaskC), (zext Y) |
2744 | unsigned EltWidth = Y->getType()->getScalarSizeInBits(); |
2745 | APInt MaskC = APInt::getHighBitsSet(numBits: BitWidth, hiBitsSet: BitWidth - EltWidth); |
2746 | Value *AndX = Builder.CreateAnd(LHS: X, RHS: MaskC); |
2747 | Value *ZextY = Builder.CreateZExt(V: Y, DestTy); |
2748 | return BinaryOperator::CreateOr(V1: AndX, V2: ZextY); |
2749 | } |
2750 | } |
2751 | } |
2752 | |
2753 | if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Val: Src)) { |
2754 | // Okay, we have (bitcast (shuffle ..)). Check to see if this is |
2755 | // a bitcast to a vector with the same # elts. |
2756 | Value *ShufOp0 = Shuf->getOperand(i_nocapture: 0); |
2757 | Value *ShufOp1 = Shuf->getOperand(i_nocapture: 1); |
2758 | auto ShufElts = cast<VectorType>(Val: Shuf->getType())->getElementCount(); |
2759 | auto SrcVecElts = cast<VectorType>(Val: ShufOp0->getType())->getElementCount(); |
2760 | if (Shuf->hasOneUse() && DestTy->isVectorTy() && |
2761 | cast<VectorType>(Val: DestTy)->getElementCount() == ShufElts && |
2762 | ShufElts == SrcVecElts) { |
2763 | BitCastInst *Tmp; |
2764 | // If either of the operands is a cast from CI.getType(), then |
2765 | // evaluating the shuffle in the casted destination's type will allow |
2766 | // us to eliminate at least one cast. |
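     |       // For example (an illustrative sketch; the value names are made up):
     |       //   %bc0  = bitcast <4 x float> %A to <4 x i32>
     |       //   %shuf = shufflevector <4 x i32> %bc0, <4 x i32> %B,
     |       //                         <4 x i32> <i32 0, i32 5, i32 2, i32 7>
     |       //   %res  = bitcast <4 x i32> %shuf to <4 x float>
     |       // can be rewritten as a shuffle of <4 x float> values, eliminating the
     |       // casts around the shuffle.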
2767 | if (((Tmp = dyn_cast<BitCastInst>(Val: ShufOp0)) && |
2768 | Tmp->getOperand(i_nocapture: 0)->getType() == DestTy) || |
2769 | ((Tmp = dyn_cast<BitCastInst>(Val: ShufOp1)) && |
2770 | Tmp->getOperand(i_nocapture: 0)->getType() == DestTy)) { |
2771 | Value *LHS = Builder.CreateBitCast(V: ShufOp0, DestTy); |
2772 | Value *RHS = Builder.CreateBitCast(V: ShufOp1, DestTy); |
2773 |         // Return a new shuffle vector. Use the same element IDs, as we
2774 |         // know the vector types have matching element counts.
2775 | return new ShuffleVectorInst(LHS, RHS, Shuf->getShuffleMask()); |
2776 | } |
2777 | } |
2778 | |
2779 |     // A byte/bit-reversing shuffle whose result is bitcast to a scalar is
2780 |     // better recognized as a byte/bit swap:
2781 | // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) -> bswap (bitcast X) |
2782 | // bitcast <N x i1> (shuf X, undef, <N, N-1,...0>) -> bitreverse (bitcast X) |
2783 | if (DestTy->isIntegerTy() && ShufElts.getKnownMinValue() % 2 == 0 && |
2784 | Shuf->hasOneUse() && Shuf->isReverse()) { |
2785 | unsigned IntrinsicNum = 0; |
2786 | if (DL.isLegalInteger(Width: DestTy->getScalarSizeInBits()) && |
2787 | SrcTy->getScalarSizeInBits() == 8) { |
2788 | IntrinsicNum = Intrinsic::bswap; |
2789 | } else if (SrcTy->getScalarSizeInBits() == 1) { |
2790 | IntrinsicNum = Intrinsic::bitreverse; |
2791 | } |
2792 | if (IntrinsicNum != 0) { |
2793 |         assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
2794 |         assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op");
2795 | Function *BswapOrBitreverse = |
2796 | Intrinsic::getDeclaration(M: CI.getModule(), id: IntrinsicNum, Tys: DestTy); |
2797 | Value *ScalarX = Builder.CreateBitCast(V: ShufOp0, DestTy); |
2798 | return CallInst::Create(Func: BswapOrBitreverse, Args: {ScalarX}); |
2799 | } |
2800 | } |
2801 | } |
2802 | |
2803 |   // Handle the A->B->A cast when there is an intervening PHI node.
2804 | if (PHINode *PN = dyn_cast<PHINode>(Val: Src)) |
2805 | if (Instruction *I = optimizeBitCastFromPhi(CI, PN)) |
2806 | return I; |
2807 | |
2808 | if (Instruction *I = canonicalizeBitCastExtElt(BitCast&: CI, IC&: *this)) |
2809 | return I; |
2810 | |
2811 | if (Instruction *I = foldBitCastBitwiseLogic(BitCast&: CI, Builder)) |
2812 | return I; |
2813 | |
2814 | if (Instruction *I = foldBitCastSelect(BitCast&: CI, Builder)) |
2815 | return I; |
2816 | |
2817 | return commonCastTransforms(CI); |
2818 | } |
2819 | |
2820 | Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) { |
2821 | return commonCastTransforms(CI); |
2822 | } |
2823 | |