1//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass optimizes scalar/vector interactions using target cost models. The
10// transforms implemented here may not fit in traditional loop-based or SLP
11// vectorization passes.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/Vectorize/VectorCombine.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/Statistic.h"
21#include "llvm/Analysis/AssumptionCache.h"
22#include "llvm/Analysis/BasicAliasAnalysis.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/InstSimplifyFolder.h"
25#include "llvm/Analysis/Loads.h"
26#include "llvm/Analysis/TargetFolder.h"
27#include "llvm/Analysis/TargetTransformInfo.h"
28#include "llvm/Analysis/ValueTracking.h"
29#include "llvm/Analysis/VectorUtils.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Instructions.h"
34#include "llvm/IR/PatternMatch.h"
35#include "llvm/Support/CommandLine.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Transforms/Utils/Local.h"
38#include "llvm/Transforms/Utils/LoopUtils.h"
39#include <numeric>
40#include <optional>
41#include <queue>
42#include <set>
43
44#define DEBUG_TYPE "vector-combine"
45#include "llvm/Transforms/Utils/InstructionWorklist.h"
46
47using namespace llvm;
48using namespace llvm::PatternMatch;
49
50STATISTIC(NumVecLoad, "Number of vector loads formed");
51STATISTIC(NumVecCmp, "Number of vector compares formed");
52STATISTIC(NumVecBO, "Number of vector binops formed");
53STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
54STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
55STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
56STATISTIC(NumScalarCmp, "Number of scalar compares formed");
57STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
58
59static cl::opt<bool> DisableVectorCombine(
60 "disable-vector-combine", cl::init(Val: false), cl::Hidden,
61 cl::desc("Disable all vector combine transforms"));
62
63static cl::opt<bool> DisableBinopExtractShuffle(
64 "disable-binop-extract-shuffle", cl::init(Val: false), cl::Hidden,
65 cl::desc("Disable binop extract to shuffle transforms"));
66
67static cl::opt<unsigned> MaxInstrsToScan(
68 "vector-combine-max-scan-instrs", cl::init(Val: 30), cl::Hidden,
69 cl::desc("Max number of instructions to scan for vector combining."));
70
71static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
72
73namespace {
74class VectorCombine {
75public:
76 VectorCombine(Function &F, const TargetTransformInfo &TTI,
77 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
78 const DataLayout *DL, TTI::TargetCostKind CostKind,
79 bool TryEarlyFoldsOnly)
80 : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
81 DT(DT), AA(AA), AC(AC), DL(DL), CostKind(CostKind), SQ(*DL),
82 TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
83
84 bool run();
85
86private:
87 Function &F;
88 IRBuilder<InstSimplifyFolder> Builder;
89 const TargetTransformInfo &TTI;
90 const DominatorTree &DT;
91 AAResults &AA;
92 AssumptionCache &AC;
93 const DataLayout *DL;
94 TTI::TargetCostKind CostKind;
95 const SimplifyQuery SQ;
96
97 /// If true, only perform beneficial early IR transforms. Do not introduce new
98 /// vector operations.
99 bool TryEarlyFoldsOnly;
100
101 InstructionWorklist Worklist;
102
  /// The next instruction to visit in the run loop. Updated if the
  /// instruction it points to is erased by
  /// RecursivelyDeleteTriviallyDeadInstructions.
  Instruction *NextInst;
106
107 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
108 // parameter. That should be updated to specific sub-classes because the
109 // run loop was changed to dispatch on opcode.
110 bool vectorizeLoadInsert(Instruction &I);
111 bool widenSubvectorLoad(Instruction &I);
112 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
113 ExtractElementInst *Ext1,
114 unsigned PreferredExtractIndex) const;
115 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
116 const Instruction &I,
117 ExtractElementInst *&ConvertToShuffle,
118 unsigned PreferredExtractIndex);
119 Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
120 Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
121 bool foldExtractExtract(Instruction &I);
122 bool foldInsExtFNeg(Instruction &I);
123 bool foldInsExtBinop(Instruction &I);
124 bool foldInsExtVectorToShuffle(Instruction &I);
125 bool foldBitOpOfCastops(Instruction &I);
126 bool foldBitOpOfCastConstant(Instruction &I);
127 bool foldBitcastShuffle(Instruction &I);
128 bool scalarizeOpOrCmp(Instruction &I);
129 bool scalarizeVPIntrinsic(Instruction &I);
130 bool foldExtractedCmps(Instruction &I);
131 bool foldSelectsFromBitcast(Instruction &I);
132 bool foldBinopOfReductions(Instruction &I);
133 bool foldSingleElementStore(Instruction &I);
134 bool scalarizeLoad(Instruction &I);
135 bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
136 bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
137 bool scalarizeExtExtract(Instruction &I);
138 bool foldConcatOfBoolMasks(Instruction &I);
139 bool foldPermuteOfBinops(Instruction &I);
140 bool foldShuffleOfBinops(Instruction &I);
141 bool foldShuffleOfSelects(Instruction &I);
142 bool foldShuffleOfCastops(Instruction &I);
143 bool foldShuffleOfShuffles(Instruction &I);
144 bool foldPermuteOfIntrinsic(Instruction &I);
145 bool foldShufflesOfLengthChangingShuffles(Instruction &I);
146 bool foldShuffleOfIntrinsics(Instruction &I);
147 bool foldShuffleToIdentity(Instruction &I);
148 bool foldShuffleFromReductions(Instruction &I);
149 bool foldShuffleChainsToReduce(Instruction &I);
150 bool foldCastFromReductions(Instruction &I);
151 bool foldSignBitReductionCmp(Instruction &I);
152 bool foldICmpEqZeroVectorReduce(Instruction &I);
153 bool foldEquivalentReductionCmp(Instruction &I);
154 bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
155 bool foldInterleaveIntrinsics(Instruction &I);
156 bool shrinkType(Instruction &I);
157 bool shrinkLoadForShuffles(Instruction &I);
158 bool shrinkPhiOfShuffles(Instruction &I);
159
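  /// Replace all uses of \p Old with \p New, queue the users of \p New for
  /// revisiting, and either erase \p Old (if \p Erase is set and it became
  /// trivially dead) or push it back onto the worklist.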
160 void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
161 LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
162 LLVM_DEBUG(dbgs() << " With: " << New << '\n');
163 Old.replaceAllUsesWith(V: &New);
164 if (auto *NewI = dyn_cast<Instruction>(Val: &New)) {
165 New.takeName(V: &Old);
166 Worklist.pushUsersToWorkList(I&: *NewI);
167 Worklist.pushValue(V: NewI);
168 }
169 if (Erase && isInstructionTriviallyDead(I: &Old)) {
170 eraseInstruction(I&: Old);
171 } else {
172 Worklist.push(I: &Old);
173 }
174 }
175
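  /// Remove \p I from the worklist and the IR, then recursively delete any
  /// operands that became trivially dead, re-queueing the surviving operands
  /// and their users so folds previously blocked by one-use limits can fire.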
176 void eraseInstruction(Instruction &I) {
177 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
178 SmallVector<Value *> Ops(I.operands());
179 Worklist.remove(I: &I);
180 I.eraseFromParent();
181
    // Push the remaining users of each operand and then the operand itself.
    // This allows further folds that were hindered by OneUse limits.
184 SmallPtrSet<Value *, 4> Visited;
185 for (Value *Op : Ops) {
186 if (!Visited.contains(Ptr: Op)) {
187 if (auto *OpI = dyn_cast<Instruction>(Val: Op)) {
188 if (RecursivelyDeleteTriviallyDeadInstructions(
189 V: OpI, TLI: nullptr, MSSAU: nullptr, AboutToDeleteCallback: [&](Value *V) {
190 if (auto *I = dyn_cast<Instruction>(Val: V)) {
191 LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
192 Worklist.remove(I);
193 if (I == NextInst)
194 NextInst = NextInst->getNextNode();
195 Visited.insert(Ptr: I);
196 }
197 }))
198 continue;
199 Worklist.pushUsersToWorkList(I&: *OpI);
200 Worklist.pushValue(V: OpI);
201 }
202 }
203 }
204 }
205};
206} // namespace
207
208/// Return the source operand of a potentially bitcasted value. If there is no
209/// bitcast, return the input value itself.
210static Value *peekThroughBitcasts(Value *V) {
211 while (auto *BitCast = dyn_cast<BitCastInst>(Val: V))
212 V = BitCast->getOperand(i_nocapture: 0);
213 return V;
214}
215
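/// Return true if \p Load is a simple, single-use load that is safe to
/// speculate and whose byte-sized scalar type evenly divides the target's
/// minimum vector register width.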
216static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
217 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
218 // The widened load may load data from dirty regions or create data races
219 // non-existent in the source.
220 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
221 Load->getFunction()->hasFnAttribute(Kind: Attribute::SanitizeMemTag) ||
222 mustSuppressSpeculation(LI: *Load))
223 return false;
224
225 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
226 // sure we have all of our type-based constraints in place for this target.
227 Type *ScalarTy = Load->getType()->getScalarType();
228 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
229 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
230 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
231 ScalarSize % 8 != 0)
232 return false;
233
234 return true;
235}
236
237bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
238 // Match insert into fixed vector of scalar value.
239 // TODO: Handle non-zero insert index.
240 Value *Scalar;
241 if (!match(V: &I,
242 P: m_InsertElt(Val: m_Poison(), Elt: m_OneUse(SubPattern: m_Value(V&: Scalar)), Idx: m_ZeroInt())))
243 return false;
244
245 // Optionally match an extract from another vector.
246 Value *X;
247 bool HasExtract = match(V: Scalar, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ZeroInt()));
248 if (!HasExtract)
249 X = Scalar;
250
251 auto *Load = dyn_cast<LoadInst>(Val: X);
252 if (!canWidenLoad(Load, TTI))
253 return false;
254
255 Type *ScalarTy = Scalar->getType();
256 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
257 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
258
259 // Check safety of replacing the scalar load with a larger vector load.
260 // We use minimal alignment (maximum flexibility) because we only care about
261 // the dereferenceable region. When calculating cost and creating a new op,
262 // we may use a larger value based on alignment attributes.
263 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
264 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
265
266 unsigned MinVecNumElts = MinVectorSize / ScalarSize;
267 auto *MinVecTy = VectorType::get(ElementType: ScalarTy, NumElements: MinVecNumElts, Scalable: false);
268 unsigned OffsetEltIndex = 0;
269 Align Alignment = Load->getAlign();
270 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty: MinVecTy, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: &AC,
271 DT: &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
276 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(Ty: SrcPtr->getType());
277 APInt Offset(OffsetBitWidth, 0);
278 SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL: *DL, Offset);
279
280 // We want to shuffle the result down from a high element of a vector, so
281 // the offset must be positive.
282 if (Offset.isNegative())
283 return false;
284
    // The offset must be a multiple of the scalar element size to allow
    // shuffling cleanly at element granularity.
287 uint64_t ScalarSizeInBytes = ScalarSize / 8;
288 if (Offset.urem(RHS: ScalarSizeInBytes) != 0)
289 return false;
290
291 // If we load MinVecNumElts, will our target element still be loaded?
292 OffsetEltIndex = Offset.udiv(RHS: ScalarSizeInBytes).getZExtValue();
293 if (OffsetEltIndex >= MinVecNumElts)
294 return false;
295
296 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty: MinVecTy, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: &AC,
297 DT: &DT))
298 return false;
299
300 // Update alignment with offset value. Note that the offset could be negated
301 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
302 // negation does not change the result of the alignment calculation.
303 Alignment = commonAlignment(A: Alignment, Offset: Offset.getZExtValue());
304 }
305
306 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
307 // Use the greater of the alignment on the load or its source pointer.
308 Alignment = std::max(a: SrcPtr->getPointerAlignment(DL: *DL), b: Alignment);
309 Type *LoadTy = Load->getType();
310 unsigned AS = Load->getPointerAddressSpace();
311 InstructionCost OldCost =
312 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: LoadTy, Alignment, AddressSpace: AS, CostKind);
313 APInt DemandedElts = APInt::getOneBitSet(numBits: MinVecNumElts, BitNo: 0);
314 OldCost +=
315 TTI.getScalarizationOverhead(Ty: MinVecTy, DemandedElts,
316 /* Insert */ true, Extract: HasExtract, CostKind);
317
318 // New pattern: load VecPtr
319 InstructionCost NewCost =
320 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: MinVecTy, Alignment, AddressSpace: AS, CostKind);
321 // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask, set everything but element 0 to poison to prevent poison
  // from propagating from the extra loaded memory. This will also optionally
324 // shrink/grow the vector from the loaded size to the output size.
325 // We assume this operation has no cost in codegen if there was no offset.
326 // Note that we could use freeze to avoid poison problems, but then we might
327 // still need a shuffle to change the vector size.
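  // For example, if the target element ended up two elements above the new
  // base pointer, OffsetEltIndex is 2 and the mask becomes
  // { 2, poison, poison, ... }, pulling that element back down to lane 0.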
328 auto *Ty = cast<FixedVectorType>(Val: I.getType());
329 unsigned OutputNumElts = Ty->getNumElements();
330 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
331 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
332 Mask[0] = OffsetEltIndex;
333 if (OffsetEltIndex)
334 NewCost += TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Ty, SrcTy: MinVecTy, Mask,
335 CostKind);
336
337 // We can aggressively convert to the vector form because the backend can
338 // invert this transform if it does not result in a performance win.
339 if (OldCost < NewCost || !NewCost.isValid())
340 return false;
341
342 // It is safe and potentially profitable to load a vector directly:
343 // inselt undef, load Scalar, 0 --> load VecPtr
344 IRBuilder<> Builder(Load);
345 Value *CastedPtr =
346 Builder.CreatePointerBitCastOrAddrSpaceCast(V: SrcPtr, DestTy: Builder.getPtrTy(AddrSpace: AS));
347 Value *VecLd = Builder.CreateAlignedLoad(Ty: MinVecTy, Ptr: CastedPtr, Align: Alignment);
348 VecLd = Builder.CreateShuffleVector(V: VecLd, Mask);
349
350 replaceValue(Old&: I, New&: *VecLd);
351 ++NumVecLoad;
352 return true;
353}
354
355/// If we are loading a vector and then inserting it into a larger vector with
356/// undefined elements, try to load the larger vector and eliminate the insert.
357/// This removes a shuffle in IR and may allow combining of other loaded values.
358bool VectorCombine::widenSubvectorLoad(Instruction &I) {
359 // Match subvector insert of fixed vector.
360 auto *Shuf = cast<ShuffleVectorInst>(Val: &I);
361 if (!Shuf->isIdentityWithPadding())
362 return false;
363
364 // Allow a non-canonical shuffle mask that is choosing elements from op1.
365 unsigned NumOpElts =
366 cast<FixedVectorType>(Val: Shuf->getOperand(i_nocapture: 0)->getType())->getNumElements();
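  // any_of returns true exactly when some mask element selects from
  // operand 1, so the boolean result doubles as the index of the loaded
  // operand.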
367 unsigned OpIndex = any_of(Range: Shuf->getShuffleMask(), P: [&NumOpElts](int M) {
368 return M >= (int)(NumOpElts);
369 });
370
371 auto *Load = dyn_cast<LoadInst>(Val: Shuf->getOperand(i_nocapture: OpIndex));
372 if (!canWidenLoad(Load, TTI))
373 return false;
374
375 // We use minimal alignment (maximum flexibility) because we only care about
376 // the dereferenceable region. When calculating cost and creating a new op,
377 // we may use a larger value based on alignment attributes.
378 auto *Ty = cast<FixedVectorType>(Val: I.getType());
379 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
380 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
381 Align Alignment = Load->getAlign();
382 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: &AC, DT: &DT))
383 return false;
384
385 Alignment = std::max(a: SrcPtr->getPointerAlignment(DL: *DL), b: Alignment);
386 Type *LoadTy = Load->getType();
387 unsigned AS = Load->getPointerAddressSpace();
388
389 // Original pattern: insert_subvector (load PtrOp)
390 // This conservatively assumes that the cost of a subvector insert into an
391 // undef value is 0. We could add that cost if the cost model accurately
392 // reflects the real cost of that operation.
393 InstructionCost OldCost =
394 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: LoadTy, Alignment, AddressSpace: AS, CostKind);
395
396 // New pattern: load PtrOp
397 InstructionCost NewCost =
398 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: Ty, Alignment, AddressSpace: AS, CostKind);
399
400 // We can aggressively convert to the vector form because the backend can
401 // invert this transform if it does not result in a performance win.
402 if (OldCost < NewCost || !NewCost.isValid())
403 return false;
404
405 IRBuilder<> Builder(Load);
406 Value *CastedPtr =
407 Builder.CreatePointerBitCastOrAddrSpaceCast(V: SrcPtr, DestTy: Builder.getPtrTy(AddrSpace: AS));
408 Value *VecLd = Builder.CreateAlignedLoad(Ty, Ptr: CastedPtr, Align: Alignment);
409 replaceValue(Old&: I, New&: *VecLd);
410 ++NumVecLoad;
411 return true;
412}
413
414/// Determine which, if any, of the inputs should be replaced by a shuffle
415/// followed by extract from a different index.
416ExtractElementInst *VectorCombine::getShuffleExtract(
417 ExtractElementInst *Ext0, ExtractElementInst *Ext1,
418 unsigned PreferredExtractIndex = InvalidIndex) const {
419 auto *Index0C = dyn_cast<ConstantInt>(Val: Ext0->getIndexOperand());
420 auto *Index1C = dyn_cast<ConstantInt>(Val: Ext1->getIndexOperand());
421 assert(Index0C && Index1C && "Expected constant extract indexes");
422
423 unsigned Index0 = Index0C->getZExtValue();
424 unsigned Index1 = Index1C->getZExtValue();
425
426 // If the extract indexes are identical, no shuffle is needed.
427 if (Index0 == Index1)
428 return nullptr;
429
430 Type *VecTy = Ext0->getVectorOperand()->getType();
431 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
432 InstructionCost Cost0 =
433 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
434 InstructionCost Cost1 =
435 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
436
  // If both costs are invalid, no shuffle is needed.
438 if (!Cost0.isValid() && !Cost1.isValid())
439 return nullptr;
440
441 // We are extracting from 2 different indexes, so one operand must be shuffled
442 // before performing a vector operation and/or extract. The more expensive
443 // extract will be replaced by a shuffle.
444 if (Cost0 > Cost1)
445 return Ext0;
446 if (Cost1 > Cost0)
447 return Ext1;
448
449 // If the costs are equal and there is a preferred extract index, shuffle the
450 // opposite operand.
451 if (PreferredExtractIndex == Index0)
452 return Ext1;
453 if (PreferredExtractIndex == Index1)
454 return Ext0;
455
456 // Otherwise, replace the extract with the higher index.
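  // Keeping the lower-index extract is preferred since low-lane (often
  // lane 0) extracts tend to be the cheapest on most targets.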
457 return Index0 > Index1 ? Ext0 : Ext1;
458}
459
460/// Compare the relative costs of 2 extracts followed by scalar operation vs.
461/// vector operation(s) followed by extract. Return true if the existing
462/// instructions are cheaper than a vector alternative. Otherwise, return false
463/// and if one of the extracts should be transformed to a shufflevector, set
464/// \p ConvertToShuffle to that extract instruction.
465bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
466 ExtractElementInst *Ext1,
467 const Instruction &I,
468 ExtractElementInst *&ConvertToShuffle,
469 unsigned PreferredExtractIndex) {
470 auto *Ext0IndexC = dyn_cast<ConstantInt>(Val: Ext0->getIndexOperand());
471 auto *Ext1IndexC = dyn_cast<ConstantInt>(Val: Ext1->getIndexOperand());
472 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
473
474 unsigned Opcode = I.getOpcode();
475 Value *Ext0Src = Ext0->getVectorOperand();
476 Value *Ext1Src = Ext1->getVectorOperand();
477 Type *ScalarTy = Ext0->getType();
478 auto *VecTy = cast<VectorType>(Val: Ext0Src->getType());
479 InstructionCost ScalarOpCost, VectorOpCost;
480
481 // Get cost estimates for scalar and vector versions of the operation.
482 bool IsBinOp = Instruction::isBinaryOp(Opcode);
483 if (IsBinOp) {
484 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
485 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
486 } else {
487 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
488 "Expected a compare");
489 CmpInst::Predicate Pred = cast<CmpInst>(Val: I).getPredicate();
490 ScalarOpCost = TTI.getCmpSelInstrCost(
491 Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
492 VectorOpCost = TTI.getCmpSelInstrCost(
493 Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
494 }
495
496 // Get cost estimates for the extract elements. These costs will factor into
497 // both sequences.
498 unsigned Ext0Index = Ext0IndexC->getZExtValue();
499 unsigned Ext1Index = Ext1IndexC->getZExtValue();
500
501 InstructionCost Extract0Cost =
502 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Ext0Index);
503 InstructionCost Extract1Cost =
504 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Ext1Index);
505
506 // A more expensive extract will always be replaced by a splat shuffle.
507 // For example, if Ext0 is more expensive:
508 // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
509 // extelt (opcode (splat V0, Ext0), V1), Ext1
510 // TODO: Evaluate whether that always results in lowest cost. Alternatively,
511 // check the cost of creating a broadcast shuffle and shuffling both
512 // operands to element 0.
513 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
514 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
515 InstructionCost CheapExtractCost = std::min(a: Extract0Cost, b: Extract1Cost);
516
517 // Extra uses of the extracts mean that we include those costs in the
518 // vector total because those instructions will not be eliminated.
519 InstructionCost OldCost, NewCost;
520 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
521 // Handle a special case. If the 2 extracts are identical, adjust the
522 // formulas to account for that. The extra use charge allows for either the
523 // CSE'd pattern or an unoptimized form with identical values:
524 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
525 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(N: 2)
526 : !Ext0->hasOneUse() || !Ext1->hasOneUse();
527 OldCost = CheapExtractCost + ScalarOpCost;
528 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
529 } else {
530 // Handle the general case. Each extract is actually a different value:
531 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
532 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
533 NewCost = VectorOpCost + CheapExtractCost +
534 !Ext0->hasOneUse() * Extract0Cost +
535 !Ext1->hasOneUse() * Extract1Cost;
536 }
537
538 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
539 if (ConvertToShuffle) {
540 if (IsBinOp && DisableBinopExtractShuffle)
541 return true;
542
543 // If we are extracting from 2 different indexes, then one operand must be
544 // shuffled before performing the vector operation. The shuffle mask is
545 // poison except for 1 lane that is being translated to the remaining
546 // extraction lane. Therefore, it is a splat shuffle. Ex:
547 // ShufMask = { poison, poison, 0, poison }
548 // TODO: The cost model has an option for a "broadcast" shuffle
549 // (splat-from-element-0), but no option for a more general splat.
550 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(Val: VecTy)) {
551 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
552 PoisonMaskElem);
553 ShuffleMask[BestInsIndex] = BestExtIndex;
554 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
555 DstTy: VecTy, SrcTy: VecTy, Mask: ShuffleMask, CostKind, Index: 0,
556 SubTp: nullptr, Args: {ConvertToShuffle});
557 } else {
558 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
559 DstTy: VecTy, SrcTy: VecTy, Mask: {}, CostKind, Index: 0, SubTp: nullptr,
560 Args: {ConvertToShuffle});
561 }
562 }
563
564 // Aggressively form a vector op if the cost is equal because the transform
565 // may enable further optimization.
566 // Codegen can reverse this transform (scalarize) if it was not profitable.
567 return OldCost < NewCost;
568}
569
570/// Create a shuffle that translates (shifts) 1 element from the input vector
571/// to a new element location.
572static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
573 unsigned NewIndex, IRBuilderBase &Builder) {
574 // The shuffle mask is poison except for 1 lane that is being translated
575 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
576 // ShufMask = { 2, poison, poison, poison }
577 auto *VecTy = cast<FixedVectorType>(Val: Vec->getType());
578 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
579 ShufMask[NewIndex] = OldIndex;
580 return Builder.CreateShuffleVector(V: Vec, Mask: ShufMask, Name: "shift");
581}
582
583/// Given an extract element instruction with constant index operand, shuffle
584/// the source vector (shift the scalar element) to a NewIndex for extraction.
585/// Return null if the input can be constant folded, so that we are not creating
586/// unnecessary instructions.
587static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
588 IRBuilderBase &Builder) {
589 // Shufflevectors can only be created for fixed-width vectors.
590 Value *X = ExtElt->getVectorOperand();
591 if (!isa<FixedVectorType>(Val: X->getType()))
592 return nullptr;
593
594 // If the extract can be constant-folded, this code is unsimplified. Defer
595 // to other passes to handle that.
596 Value *C = ExtElt->getIndexOperand();
597 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
598 if (isa<Constant>(Val: X))
599 return nullptr;
600
601 Value *Shuf = createShiftShuffle(Vec: X, OldIndex: cast<ConstantInt>(Val: C)->getZExtValue(),
602 NewIndex, Builder);
603 return Shuf;
604}
605
606/// Try to reduce extract element costs by converting scalar compares to vector
607/// compares followed by extract.
608/// cmp (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
609Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
610 Instruction &I) {
611 assert(isa<CmpInst>(&I) && "Expected a compare");
612
613 // cmp Pred (extelt V0, ExtIndex), (extelt V1, ExtIndex)
614 // --> extelt (cmp Pred V0, V1), ExtIndex
615 ++NumVecCmp;
616 CmpInst::Predicate Pred = cast<CmpInst>(Val: &I)->getPredicate();
617 Value *VecCmp = Builder.CreateCmp(Pred, LHS: V0, RHS: V1);
618 return Builder.CreateExtractElement(Vec: VecCmp, Idx: ExtIndex, Name: "foldExtExtCmp");
619}
620
621/// Try to reduce extract element costs by converting scalar binops to vector
622/// binops followed by extract.
623/// bo (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
624Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
625 Instruction &I) {
626 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
627
628 // bo (extelt V0, ExtIndex), (extelt V1, ExtIndex)
629 // --> extelt (bo V0, V1), ExtIndex
630 ++NumVecBO;
631 Value *VecBO = Builder.CreateBinOp(Opc: cast<BinaryOperator>(Val: &I)->getOpcode(), LHS: V0,
632 RHS: V1, Name: "foldExtExtBinop");
633
634 // All IR flags are safe to back-propagate because any potential poison
635 // created in unused vector elements is discarded by the extract.
636 if (auto *VecBOInst = dyn_cast<Instruction>(Val: VecBO))
637 VecBOInst->copyIRFlags(V: &I);
638
639 return Builder.CreateExtractElement(Vec: VecBO, Idx: ExtIndex, Name: "foldExtExtBinop");
640}
641
642/// Match an instruction with extracted vector operands.
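/// For example, with equal extract indexes:
///   binop (extractelement %v0, C), (extractelement %v1, C)
///   --> extractelement (binop %v0, %v1), C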
643bool VectorCombine::foldExtractExtract(Instruction &I) {
644 // It is not safe to transform things like div, urem, etc. because we may
645 // create undefined behavior when executing those on unknown vector elements.
646 if (!isSafeToSpeculativelyExecute(I: &I))
647 return false;
648
649 Instruction *I0, *I1;
650 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
651 if (!match(V: &I, P: m_Cmp(Pred, L: m_Instruction(I&: I0), R: m_Instruction(I&: I1))) &&
652 !match(V: &I, P: m_BinOp(L: m_Instruction(I&: I0), R: m_Instruction(I&: I1))))
653 return false;
654
655 Value *V0, *V1;
656 uint64_t C0, C1;
657 if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: V0), Idx: m_ConstantInt(V&: C0))) ||
658 !match(V: I1, P: m_ExtractElt(Val: m_Value(V&: V1), Idx: m_ConstantInt(V&: C1))) ||
659 V0->getType() != V1->getType())
660 return false;
661
662 // If the scalar value 'I' is going to be re-inserted into a vector, then try
663 // to create an extract to that same element. The extract/insert can be
664 // reduced to a "select shuffle".
665 // TODO: If we add a larger pattern match that starts from an insert, this
666 // probably becomes unnecessary.
667 auto *Ext0 = cast<ExtractElementInst>(Val: I0);
668 auto *Ext1 = cast<ExtractElementInst>(Val: I1);
669 uint64_t InsertIndex = InvalidIndex;
670 if (I.hasOneUse())
671 match(V: I.user_back(),
672 P: m_InsertElt(Val: m_Value(), Elt: m_Value(), Idx: m_ConstantInt(V&: InsertIndex)));
673
674 ExtractElementInst *ExtractToChange;
675 if (isExtractExtractCheap(Ext0, Ext1, I, ConvertToShuffle&: ExtractToChange, PreferredExtractIndex: InsertIndex))
676 return false;
677
678 Value *ExtOp0 = Ext0->getVectorOperand();
679 Value *ExtOp1 = Ext1->getVectorOperand();
680
681 if (ExtractToChange) {
682 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
683 Value *NewExtOp =
684 translateExtract(ExtElt: ExtractToChange, NewIndex: CheapExtractIdx, Builder);
685 if (!NewExtOp)
686 return false;
687 if (ExtractToChange == Ext0)
688 ExtOp0 = NewExtOp;
689 else
690 ExtOp1 = NewExtOp;
691 }
692
693 Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
694 : Ext0->getIndexOperand();
695 Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
696 ? foldExtExtCmp(V0: ExtOp0, V1: ExtOp1, ExtIndex, I)
697 : foldExtExtBinop(V0: ExtOp0, V1: ExtOp1, ExtIndex, I);
698 Worklist.push(I: Ext0);
699 Worklist.push(I: Ext1);
700 replaceValue(Old&: I, New&: *NewExt);
701 return true;
702}
703
704/// Try to replace an extract + scalar fneg + insert with a vector fneg +
705/// shuffle.
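/// For example, with equal-length vectors:
///   insertelement %dst, (fneg (extractelement %src, C)), C
///   --> shufflevector %dst, (fneg %src), Mask, where Mask takes lane C from
///       the negated source and all other lanes from %dst.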
706bool VectorCombine::foldInsExtFNeg(Instruction &I) {
707 // Match an insert (op (extract)) pattern.
708 Value *DstVec;
709 uint64_t ExtIdx, InsIdx;
710 Instruction *FNeg;
711 if (!match(V: &I, P: m_InsertElt(Val: m_Value(V&: DstVec), Elt: m_OneUse(SubPattern: m_Instruction(I&: FNeg)),
712 Idx: m_ConstantInt(V&: InsIdx))))
713 return false;
714
715 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
716 Value *SrcVec;
717 Instruction *Extract;
718 if (!match(V: FNeg, P: m_FNeg(X: m_CombineAnd(
719 L: m_Instruction(I&: Extract),
720 R: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx))))))
721 return false;
722
723 auto *DstVecTy = cast<FixedVectorType>(Val: DstVec->getType());
724 auto *DstVecScalarTy = DstVecTy->getScalarType();
725 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
726 if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
727 return false;
728
  // Bail out if the insert/extract index is out of bounds or the destination
  // vector has only one element.
731 unsigned NumDstElts = DstVecTy->getNumElements();
732 unsigned NumSrcElts = SrcVecTy->getNumElements();
733 if (ExtIdx > NumSrcElts || InsIdx >= NumDstElts || NumDstElts == 1)
734 return false;
735
736 // We are inserting the negated element into the same lane that we extracted
737 // from. This is equivalent to a select-shuffle that chooses all but the
738 // negated element from the destination vector.
739 SmallVector<int> Mask(NumDstElts);
740 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
741 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
742 InstructionCost OldCost =
743 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: DstVecScalarTy, CostKind) +
744 TTI.getVectorInstrCost(I, Val: DstVecTy, CostKind, Index: InsIdx);
745
746 // If the extract has one use, it will be eliminated, so count it in the
747 // original cost. If it has more than one use, ignore the cost because it will
748 // be the same before/after.
749 if (Extract->hasOneUse())
750 OldCost += TTI.getVectorInstrCost(I: *Extract, Val: SrcVecTy, CostKind, Index: ExtIdx);
751
752 InstructionCost NewCost =
753 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: SrcVecTy, CostKind) +
754 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: DstVecTy,
755 SrcTy: DstVecTy, Mask, CostKind);
756
757 bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
  // If the lengths of the two vectors are not equal, we also need a
  // length-changing shuffle of the source. Add this cost.
760 SmallVector<int> SrcMask;
761 if (NeedLenChg) {
762 SrcMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
763 SrcMask[ExtIdx % NumDstElts] = ExtIdx;
764 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
765 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: SrcMask, CostKind);
766 }
767
768 LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
769 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
770 << "\n");
771 if (NewCost > OldCost)
772 return false;
773
774 Value *NewShuf, *LenChgShuf = nullptr;
775 // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
776 Value *VecFNeg = Builder.CreateFNegFMF(V: SrcVec, FMFSource: FNeg);
777 if (NeedLenChg) {
778 // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
779 LenChgShuf = Builder.CreateShuffleVector(V: VecFNeg, Mask: SrcMask);
780 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: LenChgShuf, Mask);
781 Worklist.pushValue(V: LenChgShuf);
782 } else {
783 // shuffle DstVec, (fneg SrcVec), Mask
784 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: VecFNeg, Mask);
785 }
786
787 Worklist.pushValue(V: VecFNeg);
788 replaceValue(Old&: I, New&: *NewShuf);
789 return true;
790}
791
792/// Try to fold insert(binop(x,y),binop(a,b),idx)
793/// --> binop(insert(x,a,idx),insert(y,b,idx))
794bool VectorCombine::foldInsExtBinop(Instruction &I) {
795 BinaryOperator *VecBinOp, *SclBinOp;
796 uint64_t Index;
797 if (!match(V: &I,
798 P: m_InsertElt(Val: m_OneUse(SubPattern: m_BinOp(I&: VecBinOp)),
799 Elt: m_OneUse(SubPattern: m_BinOp(I&: SclBinOp)), Idx: m_ConstantInt(V&: Index))))
800 return false;
801
802 // TODO: Add support for addlike etc.
803 Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
804 if (BinOpcode != SclBinOp->getOpcode())
805 return false;
806
807 auto *ResultTy = dyn_cast<FixedVectorType>(Val: I.getType());
808 if (!ResultTy)
809 return false;
810
811 // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
812 // shuffle?
813
814 InstructionCost OldCost = TTI.getInstructionCost(U: &I, CostKind) +
815 TTI.getInstructionCost(U: VecBinOp, CostKind) +
816 TTI.getInstructionCost(U: SclBinOp, CostKind);
817 InstructionCost NewCost =
818 TTI.getArithmeticInstrCost(Opcode: BinOpcode, Ty: ResultTy, CostKind) +
819 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
820 Index, Op0: VecBinOp->getOperand(i_nocapture: 0),
821 Op1: SclBinOp->getOperand(i_nocapture: 0)) +
822 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
823 Index, Op0: VecBinOp->getOperand(i_nocapture: 1),
824 Op1: SclBinOp->getOperand(i_nocapture: 1));
825
826 LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
827 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
828 << "\n");
829 if (NewCost > OldCost)
830 return false;
831
832 Value *NewIns0 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 0),
833 NewElt: SclBinOp->getOperand(i_nocapture: 0), Idx: Index);
834 Value *NewIns1 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 1),
835 NewElt: SclBinOp->getOperand(i_nocapture: 1), Idx: Index);
836 Value *NewBO = Builder.CreateBinOp(Opc: BinOpcode, LHS: NewIns0, RHS: NewIns1);
837
838 // Intersect flags from the old binops.
839 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
840 NewInst->copyIRFlags(V: VecBinOp);
841 NewInst->andIRFlags(V: SclBinOp);
842 }
843
844 Worklist.pushValue(V: NewIns0);
845 Worklist.pushValue(V: NewIns1);
846 replaceValue(Old&: I, New&: *NewBO);
847 return true;
848}
849
850/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
851/// Supports: bitcast, trunc, sext, zext
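/// For example:
///   and (zext <4 x i16> %x to <4 x i32>), (zext <4 x i16> %y to <4 x i32>)
///   --> zext (and <4 x i16> %x, %y) to <4 x i32>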
852bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
853 // Check if this is a bitwise logic operation
854 auto *BinOp = dyn_cast<BinaryOperator>(Val: &I);
855 if (!BinOp || !BinOp->isBitwiseLogicOp())
856 return false;
857
858 // Get the cast instructions
859 auto *LHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 0));
860 auto *RHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 1));
861 if (!LHSCast || !RHSCast) {
862 LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
863 return false;
864 }
865
866 // Both casts must be the same type
867 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
868 if (CastOpcode != RHSCast->getOpcode())
869 return false;
870
871 // Only handle supported cast operations
872 switch (CastOpcode) {
873 case Instruction::BitCast:
874 case Instruction::Trunc:
875 case Instruction::SExt:
876 case Instruction::ZExt:
877 break;
878 default:
879 return false;
880 }
881
882 Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);
883 Value *RHSSrc = RHSCast->getOperand(i_nocapture: 0);
884
885 // Source types must match
886 if (LHSSrc->getType() != RHSSrc->getType())
887 return false;
888
889 auto *SrcTy = LHSSrc->getType();
890 auto *DstTy = I.getType();
891 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
892 // Other casts only handle vector types with integer elements.
893 if (CastOpcode != Instruction::BitCast &&
894 (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
895 return false;
896
897 // Only integer scalar/vector values are legal for bitwise logic operations.
898 if (!SrcTy->getScalarType()->isIntegerTy() ||
899 !DstTy->getScalarType()->isIntegerTy())
900 return false;
901
  // Cost check:
  //   OldCost = bitlogic on DstTy + 2 casts
  //   NewCost = bitlogic on SrcTy + 1 cast
905
906 // Calculate specific costs for each cast with instruction context
907 InstructionCost LHSCastCost = TTI.getCastInstrCost(
908 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);
909 InstructionCost RHSCastCost = TTI.getCastInstrCost(
910 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: RHSCast);
911
912 InstructionCost OldCost =
913 TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: DstTy, CostKind) +
914 LHSCastCost + RHSCastCost;
915
916 // For new cost, we can't provide an instruction (it doesn't exist yet)
917 InstructionCost GenericCastCost = TTI.getCastInstrCost(
918 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);
919
920 InstructionCost NewCost =
921 TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: SrcTy, CostKind) +
922 GenericCastCost;
923
924 // Account for multi-use casts using specific costs
925 if (!LHSCast->hasOneUse())
926 NewCost += LHSCastCost;
927 if (!RHSCast->hasOneUse())
928 NewCost += RHSCastCost;
929
930 LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
931 << " NewCost=" << NewCost << "\n");
932
933 if (NewCost > OldCost)
934 return false;
935
936 // Create the operation on the source type
937 Value *NewOp = Builder.CreateBinOp(Opc: BinOp->getOpcode(), LHS: LHSSrc, RHS: RHSSrc,
938 Name: BinOp->getName() + ".inner");
939 if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
940 NewBinOp->copyIRFlags(V: BinOp);
941
942 Worklist.pushValue(V: NewOp);
943
944 // Create the cast operation directly to ensure we get a new instruction
945 Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());
946
947 // Preserve cast instruction flags
948 NewCast->copyIRFlags(V: LHSCast);
949 NewCast->andIRFlags(V: RHSCast);
950
951 // Insert the new instruction
952 Value *Result = Builder.Insert(I: NewCast);
953
954 replaceValue(Old&: I, New&: *Result);
955 return true;
956}
957
/// Match:
///   bitop(castop(x), C) ->
///   bitop(castop(x), castop(InvC)) ->
///   castop(bitop(x, InvC))
/// Supports: bitcast, zext, sext, trunc
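/// For example:
///   and (zext <4 x i16> %x to <4 x i32>), <4 x i32> C
///   --> zext (and <4 x i16> %x, InvC) to <4 x i32>, where zext(InvC) == C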
963bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
964 Instruction *LHS;
965 Constant *C;
966
967 // Check if this is a bitwise logic operation
968 if (!match(V: &I, P: m_c_BitwiseLogic(L: m_Instruction(I&: LHS), R: m_Constant(C))))
969 return false;
970
971 // Get the cast instructions
972 auto *LHSCast = dyn_cast<CastInst>(Val: LHS);
973 if (!LHSCast)
974 return false;
975
976 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
977
978 // Only handle supported cast operations
979 switch (CastOpcode) {
980 case Instruction::BitCast:
981 case Instruction::ZExt:
982 case Instruction::SExt:
983 case Instruction::Trunc:
984 break;
985 default:
986 return false;
987 }
988
989 Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);
990
991 auto *SrcTy = LHSSrc->getType();
992 auto *DstTy = I.getType();
993 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
994 // Other casts only handle vector types with integer elements.
995 if (CastOpcode != Instruction::BitCast &&
996 (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
997 return false;
998
999 // Only integer scalar/vector values are legal for bitwise logic operations.
1000 if (!SrcTy->getScalarType()->isIntegerTy() ||
1001 !DstTy->getScalarType()->isIntegerTy())
1002 return false;
1003
1004 // Find the constant InvC, such that castop(InvC) equals to C.
1005 PreservedCastFlags RHSFlags;
1006 Constant *InvC = getLosslessInvCast(C, InvCastTo: SrcTy, CastOp: CastOpcode, DL: *DL, Flags: &RHSFlags);
1007 if (!InvC)
1008 return false;
1009
  // Cost check:
  //   OldCost = bitlogic on DstTy + cast
  //   NewCost = bitlogic on SrcTy + cast
1013
1014 // Calculate specific costs for each cast with instruction context
1015 InstructionCost LHSCastCost = TTI.getCastInstrCost(
1016 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);
1017
1018 InstructionCost OldCost =
1019 TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: DstTy, CostKind) + LHSCastCost;
1020
1021 // For new cost, we can't provide an instruction (it doesn't exist yet)
1022 InstructionCost GenericCastCost = TTI.getCastInstrCost(
1023 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);
1024
1025 InstructionCost NewCost =
1026 TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: SrcTy, CostKind) +
1027 GenericCastCost;
1028
1029 // Account for multi-use casts using specific costs
1030 if (!LHSCast->hasOneUse())
1031 NewCost += LHSCastCost;
1032
1033 LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
1034 << " NewCost=" << NewCost << "\n");
1035
1036 if (NewCost > OldCost)
1037 return false;
1038
1039 // Create the operation on the source type
1040 Value *NewOp = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(),
1041 LHS: LHSSrc, RHS: InvC, Name: I.getName() + ".inner");
1042 if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
1043 NewBinOp->copyIRFlags(V: &I);
1044
1045 Worklist.pushValue(V: NewOp);
1046
1047 // Create the cast operation directly to ensure we get a new instruction
1048 Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());
1049
1050 // Preserve cast instruction flags
1051 if (RHSFlags.NNeg)
1052 NewCast->setNonNeg();
1053 if (RHSFlags.NUW)
1054 NewCast->setHasNoUnsignedWrap();
1055 if (RHSFlags.NSW)
1056 NewCast->setHasNoSignedWrap();
1057
1058 NewCast->andIRFlags(V: LHSCast);
1059
1060 // Insert the new instruction
1061 Value *Result = Builder.Insert(I: NewCast);
1062
1063 replaceValue(Old&: I, New&: *Result);
1064 return true;
1065}
1066
1067/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
1068/// destination type followed by shuffle. This can enable further transforms by
1069/// moving bitcasts or shuffles together.
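/// For example:
///   bitcast (shuffle <4 x i32> %x, poison, <1, 0, 3, 2>) to <8 x i16>
///   --> shuffle (bitcast %x to <8 x i16>), poison, <2, 3, 0, 1, 6, 7, 4, 5>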
1070bool VectorCombine::foldBitcastShuffle(Instruction &I) {
1071 Value *V0, *V1;
1072 ArrayRef<int> Mask;
1073 if (!match(V: &I, P: m_BitCast(Op: m_OneUse(
1074 SubPattern: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(Mask))))))
1075 return false;
1076
1077 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
1078 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
1079 // mask for scalable type is a splat or not.
1080 // 2) Disallow non-vector casts.
1081 // TODO: We could allow any shuffle.
1082 auto *DestTy = dyn_cast<FixedVectorType>(Val: I.getType());
1083 auto *SrcTy = dyn_cast<FixedVectorType>(Val: V0->getType());
1084 if (!DestTy || !SrcTy)
1085 return false;
1086
1087 unsigned DestEltSize = DestTy->getScalarSizeInBits();
1088 unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
1089 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
1090 return false;
1091
1092 bool IsUnary = isa<UndefValue>(Val: V1);
1093
1094 // For binary shuffles, only fold bitcast(shuffle(X,Y))
1095 // if it won't increase the number of bitcasts.
1096 if (!IsUnary) {
1097 auto *BCTy0 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V0)->getType());
1098 auto *BCTy1 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V1)->getType());
1099 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
1100 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
1101 return false;
1102 }
1103
1104 SmallVector<int, 16> NewMask;
1105 if (DestEltSize <= SrcEltSize) {
1106 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
1107 // always be expanded to the equivalent form choosing narrower elements.
1108 assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
1109 unsigned ScaleFactor = SrcEltSize / DestEltSize;
1110 narrowShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask);
1111 } else {
1112 // The bitcast is from narrow elements to wide elements. The shuffle mask
1113 // must choose consecutive elements to allow casting first.
1114 assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
1115 unsigned ScaleFactor = DestEltSize / SrcEltSize;
1116 if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask))
1117 return false;
1118 }
1119
  // Bitcast the shuffle src - keep its original width but use the destination
  // scalar type.
1122 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
1123 auto *NewShuffleTy =
1124 FixedVectorType::get(ElementType: DestTy->getScalarType(), NumElts: NumSrcElts);
1125 auto *OldShuffleTy =
1126 FixedVectorType::get(ElementType: SrcTy->getScalarType(), NumElts: Mask.size());
1127 unsigned NumOps = IsUnary ? 1 : 2;
1128
1129 // The new shuffle must not cost more than the old shuffle.
1130 TargetTransformInfo::ShuffleKind SK =
1131 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
1132 : TargetTransformInfo::SK_PermuteTwoSrc;
1133
1134 InstructionCost NewCost =
1135 TTI.getShuffleCost(Kind: SK, DstTy: DestTy, SrcTy: NewShuffleTy, Mask: NewMask, CostKind) +
1136 (NumOps * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: NewShuffleTy, Src: SrcTy,
1137 CCH: TargetTransformInfo::CastContextHint::None,
1138 CostKind));
1139 InstructionCost OldCost =
1140 TTI.getShuffleCost(Kind: SK, DstTy: OldShuffleTy, SrcTy, Mask, CostKind) +
1141 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DestTy, Src: OldShuffleTy,
1142 CCH: TargetTransformInfo::CastContextHint::None,
1143 CostKind);
1144
1145 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
1146 << OldCost << " vs NewCost: " << NewCost << "\n");
1147
1148 if (NewCost > OldCost || !NewCost.isValid())
1149 return false;
1150
1151 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
1152 ++NumShufOfBitcast;
1153 Value *CastV0 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V0), DestTy: NewShuffleTy);
1154 Value *CastV1 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V1), DestTy: NewShuffleTy);
1155 Value *Shuf = Builder.CreateShuffleVector(V1: CastV0, V2: CastV1, Mask: NewMask);
1156 replaceValue(Old&: I, New&: *Shuf);
1157 return true;
1158}
1159
1160/// VP Intrinsics whose vector operands are both splat values may be simplified
1161/// into the scalar version of the operation and the result splatted. This
1162/// can lead to scalarization down the line.
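/// For example:
///   llvm.vp.add (splat %x), (splat %y), (splat true), %evl
///   --> splat (add %x, %y)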
1163bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
1164 if (!isa<VPIntrinsic>(Val: I))
1165 return false;
1166 VPIntrinsic &VPI = cast<VPIntrinsic>(Val&: I);
1167 Value *Op0 = VPI.getArgOperand(i: 0);
1168 Value *Op1 = VPI.getArgOperand(i: 1);
1169
1170 if (!isSplatValue(V: Op0) || !isSplatValue(V: Op1))
1171 return false;
1172
1173 // Check getSplatValue early in this function, to avoid doing unnecessary
1174 // work.
1175 Value *ScalarOp0 = getSplatValue(V: Op0);
1176 Value *ScalarOp1 = getSplatValue(V: Op1);
1177 if (!ScalarOp0 || !ScalarOp1)
1178 return false;
1179
1180 // For the binary VP intrinsics supported here, the result on disabled lanes
1181 // is a poison value. For now, only do this simplification if all lanes
1182 // are active.
1183 // TODO: Relax the condition that all lanes are active by using insertelement
1184 // on inactive lanes.
1185 auto IsAllTrueMask = [](Value *MaskVal) {
1186 if (Value *SplattedVal = getSplatValue(V: MaskVal))
1187 if (auto *ConstValue = dyn_cast<Constant>(Val: SplattedVal))
1188 return ConstValue->isAllOnesValue();
1189 return false;
1190 };
1191 if (!IsAllTrueMask(VPI.getArgOperand(i: 2)))
1192 return false;
1193
1194 // Check to make sure we support scalarization of the intrinsic
1195 Intrinsic::ID IntrID = VPI.getIntrinsicID();
1196 if (!VPBinOpIntrinsic::isVPBinOp(ID: IntrID))
1197 return false;
1198
1199 // Calculate cost of splatting both operands into vectors and the vector
1200 // intrinsic
1201 VectorType *VecTy = cast<VectorType>(Val: VPI.getType());
1202 SmallVector<int> Mask;
1203 if (auto *FVTy = dyn_cast<FixedVectorType>(Val: VecTy))
1204 Mask.resize(N: FVTy->getNumElements(), NV: 0);
1205 InstructionCost SplatCost =
1206 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: 0) +
1207 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_Broadcast, DstTy: VecTy, SrcTy: VecTy, Mask,
1208 CostKind);
1209
1210 // Calculate the cost of the VP Intrinsic
1211 SmallVector<Type *, 4> Args;
1212 for (Value *V : VPI.args())
1213 Args.push_back(Elt: V->getType());
1214 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
1215 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
1216 InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
1217
1218 // Determine scalar opcode
1219 std::optional<unsigned> FunctionalOpcode =
1220 VPI.getFunctionalOpcode();
1221 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
1222 if (!FunctionalOpcode) {
1223 ScalarIntrID = VPI.getFunctionalIntrinsicID();
1224 if (!ScalarIntrID)
1225 return false;
1226 }
1227
1228 // Calculate cost of scalarizing
1229 InstructionCost ScalarOpCost = 0;
1230 if (ScalarIntrID) {
1231 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
1232 ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
1233 } else {
1234 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode: *FunctionalOpcode,
1235 Ty: VecTy->getScalarType(), CostKind);
1236 }
1237
1238 // The existing splats may be kept around if other instructions use them.
1239 InstructionCost CostToKeepSplats =
1240 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
1241 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
1242
1243 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
1244 << "\n");
1245 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
1246 << ", Cost of scalarizing:" << NewCost << "\n");
1247
1248 // We want to scalarize unless the vector variant actually has lower cost.
1249 if (OldCost < NewCost || !NewCost.isValid())
1250 return false;
1251
1252 // Scalarize the intrinsic
1253 ElementCount EC = cast<VectorType>(Val: Op0->getType())->getElementCount();
1254 Value *EVL = VPI.getArgOperand(i: 3);
1255
1256 // If the VP op might introduce UB or poison, we can scalarize it provided
1257 // that we know the EVL > 0: If the EVL is zero, then the original VP op
1258 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
1259 // scalarizing it.
1260 bool SafeToSpeculate;
1261 if (ScalarIntrID)
1262 SafeToSpeculate = Intrinsic::getFnAttributes(C&: I.getContext(), id: *ScalarIntrID)
1263 .hasAttribute(Kind: Attribute::AttrKind::Speculatable);
1264 else
1265 SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
1266 Opcode: *FunctionalOpcode, Inst: &VPI, CtxI: nullptr, AC: &AC, DT: &DT);
1267 if (!SafeToSpeculate &&
1268 !isKnownNonZero(V: EVL, Q: SimplifyQuery(*DL, &DT, &AC, &VPI)))
1269 return false;
1270
1271 Value *ScalarVal =
1272 ScalarIntrID
1273 ? Builder.CreateIntrinsic(RetTy: VecTy->getScalarType(), ID: *ScalarIntrID,
1274 Args: {ScalarOp0, ScalarOp1})
1275 : Builder.CreateBinOp(Opc: (Instruction::BinaryOps)(*FunctionalOpcode),
1276 LHS: ScalarOp0, RHS: ScalarOp1);
1277
1278 replaceValue(Old&: VPI, New&: *Builder.CreateVectorSplat(EC, V: ScalarVal));
1279 return true;
1280}
1281
1282/// Match a vector op/compare/intrinsic with at least one
1283/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
1284/// by insertelement.
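/// For example:
///   add (insertelement VecC0, %x, Idx), (insertelement VecC1, %y, Idx)
///   --> insertelement (add VecC0, VecC1), (add %x, %y), Idx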
1285bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
1286 auto *UO = dyn_cast<UnaryOperator>(Val: &I);
1287 auto *BO = dyn_cast<BinaryOperator>(Val: &I);
1288 auto *CI = dyn_cast<CmpInst>(Val: &I);
1289 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
1290 if (!UO && !BO && !CI && !II)
1291 return false;
1292
1293 // TODO: Allow intrinsics with different argument types
1294 if (II) {
1295 if (!isTriviallyVectorizable(ID: II->getIntrinsicID()))
1296 return false;
1297 for (auto [Idx, Arg] : enumerate(First: II->args()))
1298 if (Arg->getType() != II->getType() &&
1299 !isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI))
1300 return false;
1301 }
1302
1303 // Do not convert the vector condition of a vector select into a scalar
1304 // condition. That may cause problems for codegen because of differences in
1305 // boolean formats and register-file transfers.
1306 // TODO: Can we account for that in the cost model?
1307 if (CI)
1308 for (User *U : I.users())
1309 if (match(V: U, P: m_Select(C: m_Specific(V: &I), L: m_Value(), R: m_Value())))
1310 return false;
1311
1312 // Match constant vectors or scalars being inserted into constant vectors:
1313 // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1314 SmallVector<Value *> VecCs, ScalarOps;
1315 std::optional<uint64_t> Index;
1316
1317 auto Ops = II ? II->args() : I.operands();
1318 for (auto [OpNum, Op] : enumerate(First&: Ops)) {
1319 Constant *VecC;
1320 Value *V;
1321 uint64_t InsIdx = 0;
1322 if (match(V: Op.get(), P: m_InsertElt(Val: m_Constant(C&: VecC), Elt: m_Value(V),
1323 Idx: m_ConstantInt(V&: InsIdx)))) {
1324 // Bail if any inserts are out of bounds.
1325 VectorType *OpTy = cast<VectorType>(Val: Op->getType());
1326 if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
1327 return false;
1328 // All inserts must have the same index.
1329 // TODO: Deal with mismatched index constants and variable indexes?
1330 if (!Index)
1331 Index = InsIdx;
1332 else if (InsIdx != *Index)
1333 return false;
1334 VecCs.push_back(Elt: VecC);
1335 ScalarOps.push_back(Elt: V);
1336 } else if (II && isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(),
1337 ScalarOpdIdx: OpNum, TTI: &TTI)) {
1338 VecCs.push_back(Elt: Op.get());
1339 ScalarOps.push_back(Elt: Op.get());
1340 } else if (match(V: Op.get(), P: m_Constant(C&: VecC))) {
1341 VecCs.push_back(Elt: VecC);
1342 ScalarOps.push_back(Elt: nullptr);
1343 } else {
1344 return false;
1345 }
1346 }
1347
1348 // Bail if all operands are constant.
1349 if (!Index.has_value())
1350 return false;
1351
1352 VectorType *VecTy = cast<VectorType>(Val: I.getType());
1353 Type *ScalarTy = VecTy->getScalarType();
1354 assert(VecTy->isVectorTy() &&
1355 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1356 ScalarTy->isPointerTy()) &&
1357 "Unexpected types for insert element into binop or cmp");
1358
1359 unsigned Opcode = I.getOpcode();
1360 InstructionCost ScalarOpCost, VectorOpCost;
1361 if (CI) {
1362 CmpInst::Predicate Pred = CI->getPredicate();
1363 ScalarOpCost = TTI.getCmpSelInstrCost(
1364 Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
1365 VectorOpCost = TTI.getCmpSelInstrCost(
1366 Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
1367 } else if (UO || BO) {
1368 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
1369 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
1370 } else {
1371 IntrinsicCostAttributes ScalarICA(
1372 II->getIntrinsicID(), ScalarTy,
1373 SmallVector<Type *>(II->arg_size(), ScalarTy));
1374 ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: ScalarICA, CostKind);
1375 IntrinsicCostAttributes VectorICA(
1376 II->getIntrinsicID(), VecTy,
1377 SmallVector<Type *>(II->arg_size(), VecTy));
1378 VectorOpCost = TTI.getIntrinsicInstrCost(ICA: VectorICA, CostKind);
1379 }
1380
1381 // Fold the vector constants in the original vectors into a new base vector to
1382 // get more accurate cost modelling.
1383 Value *NewVecC = nullptr;
1384 if (CI)
1385 NewVecC = simplifyCmpInst(Predicate: CI->getPredicate(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
1386 else if (UO)
1387 NewVecC =
1388 simplifyUnOp(Opcode: UO->getOpcode(), Op: VecCs[0], FMF: UO->getFastMathFlags(), Q: SQ);
1389 else if (BO)
1390 NewVecC = simplifyBinOp(Opcode: BO->getOpcode(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
1391 else if (II)
1392 NewVecC = simplifyCall(Call: II, Callee: II->getCalledOperand(), Args: VecCs, Q: SQ);
1393
1394 if (!NewVecC)
1395 return false;
1396
1397 // Get cost estimate for the insert element. This cost will factor into
1398 // both sequences.
1399 InstructionCost OldCost = VectorOpCost;
1400 InstructionCost NewCost =
1401 ScalarOpCost + TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy,
1402 CostKind, Index: *Index, Op0: NewVecC);
1403
1404 for (auto [Idx, Op, VecC, Scalar] : enumerate(First&: Ops, Rest&: VecCs, Rest&: ScalarOps)) {
1405 if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
1406 ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI)))
1407 continue;
1408 InstructionCost InsertCost = TTI.getVectorInstrCost(
1409 Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: *Index, Op0: VecC, Op1: Scalar);
1410 OldCost += InsertCost;
1411 NewCost += !Op->hasOneUse() * InsertCost;
1412 }
1413
1414 // We want to scalarize unless the vector variant actually has lower cost.
1415 if (OldCost < NewCost || !NewCost.isValid())
1416 return false;
1417
1418 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1419 // inselt NewVecC, (scalar_op V0, V1), Index
1420 if (CI)
1421 ++NumScalarCmp;
1422 else if (UO || BO)
1423 ++NumScalarOps;
1424 else
1425 ++NumScalarIntrinsic;
1426
1427  // For constant cases, extract the scalar element; this should constant fold.
1428 for (auto [OpIdx, Scalar, VecC] : enumerate(First&: ScalarOps, Rest&: VecCs))
1429 if (!Scalar)
1430 ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
1431 Vec: cast<Constant>(Val: VecC), Idx: Builder.getInt64(C: *Index));
1432
1433 Value *Scalar;
1434 if (CI)
1435 Scalar = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: ScalarOps[0], RHS: ScalarOps[1]);
1436 else if (UO || BO)
1437 Scalar = Builder.CreateNAryOp(Opc: Opcode, Ops: ScalarOps);
1438 else
1439 Scalar = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: II->getIntrinsicID(), Args: ScalarOps);
1440
1441 Scalar->setName(I.getName() + ".scalar");
1442
1443 // All IR flags are safe to back-propagate. There is no potential for extra
1444 // poison to be created by the scalar instruction.
1445 if (auto *ScalarInst = dyn_cast<Instruction>(Val: Scalar))
1446 ScalarInst->copyIRFlags(V: &I);
1447
1448 Value *Insert = Builder.CreateInsertElement(Vec: NewVecC, NewElt: Scalar, Idx: *Index);
1449 replaceValue(Old&: I, New&: *Insert);
1450 return true;
1451}
1452
1453/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1454/// a vector into vector operations followed by extract. Note: The SLP pass
1455/// may miss this pattern because of implementation problems.
1456bool VectorCombine::foldExtractedCmps(Instruction &I) {
1457 auto *BI = dyn_cast<BinaryOperator>(Val: &I);
1458
1459 // We are looking for a scalar binop of booleans.
1460 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1461 if (!BI || !I.getType()->isIntegerTy(Bitwidth: 1))
1462 return false;
1463
1464 // The compare predicates should match, and each compare should have a
1465 // constant operand.
1466 Value *B0 = I.getOperand(i: 0), *B1 = I.getOperand(i: 1);
1467 Instruction *I0, *I1;
1468 Constant *C0, *C1;
1469 CmpPredicate P0, P1;
1470 if (!match(V: B0, P: m_Cmp(Pred&: P0, L: m_Instruction(I&: I0), R: m_Constant(C&: C0))) ||
1471 !match(V: B1, P: m_Cmp(Pred&: P1, L: m_Instruction(I&: I1), R: m_Constant(C&: C1))))
1472 return false;
1473
1474 auto MatchingPred = CmpPredicate::getMatching(A: P0, B: P1);
1475 if (!MatchingPred)
1476 return false;
1477
1478 // The compare operands must be extracts of the same vector with constant
1479 // extract indexes.
1480 Value *X;
1481 uint64_t Index0, Index1;
1482 if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ConstantInt(V&: Index0))) ||
1483 !match(V: I1, P: m_ExtractElt(Val: m_Specific(V: X), Idx: m_ConstantInt(V&: Index1))))
1484 return false;
1485
1486 auto *Ext0 = cast<ExtractElementInst>(Val: I0);
1487 auto *Ext1 = cast<ExtractElementInst>(Val: I1);
1488 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex: CostKind);
1489 if (!ConvertToShuf)
1490 return false;
1491 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1492 "Unknown ExtractElementInst");
1493
1494 // The original scalar pattern is:
1495 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1496 CmpInst::Predicate Pred = *MatchingPred;
1497 unsigned CmpOpcode =
1498 CmpInst::isFPPredicate(P: Pred) ? Instruction::FCmp : Instruction::ICmp;
1499 auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
1500 if (!VecTy)
1501 return false;
1502
1503 InstructionCost Ext0Cost =
1504 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
1505 InstructionCost Ext1Cost =
1506 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
1507 InstructionCost CmpCost = TTI.getCmpSelInstrCost(
1508 Opcode: CmpOpcode, ValTy: I0->getType(), CondTy: CmpInst::makeCmpResultType(opnd_type: I0->getType()), VecPred: Pred,
1509 CostKind);
1510
1511 InstructionCost OldCost =
1512 Ext0Cost + Ext1Cost + CmpCost * 2 +
1513 TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: I.getType(), CostKind);
1514
1515 // The proposed vector pattern is:
1516 // vcmp = cmp Pred X, VecC
1517 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1518 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1519 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1520 auto *CmpTy = cast<FixedVectorType>(Val: CmpInst::makeCmpResultType(opnd_type: VecTy));
1521 InstructionCost NewCost = TTI.getCmpSelInstrCost(
1522 Opcode: CmpOpcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
1523 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1524 ShufMask[CheapIndex] = ExpensiveIndex;
1525 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: CmpTy,
1526 SrcTy: CmpTy, Mask: ShufMask, CostKind);
1527 NewCost += TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: CmpTy, CostKind);
1528 NewCost += TTI.getVectorInstrCost(I: *Ext0, Val: CmpTy, CostKind, Index: CheapIndex);
1529 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1530 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1531
1532 // Aggressively form vector ops if the cost is equal because the transform
1533 // may enable further optimization.
1534 // Codegen can reverse this transform (scalarize) if it was not profitable.
1535 if (OldCost < NewCost || !NewCost.isValid())
1536 return false;
1537
1538 // Create a vector constant from the 2 scalar constants.
1539 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1540 PoisonValue::get(T: VecTy->getElementType()));
1541 CmpC[Index0] = C0;
1542 CmpC[Index1] = C1;
1543 Value *VCmp = Builder.CreateCmp(Pred, LHS: X, RHS: ConstantVector::get(V: CmpC));
1544 Value *Shuf = createShiftShuffle(Vec: VCmp, OldIndex: ExpensiveIndex, NewIndex: CheapIndex, Builder);
1545 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1546 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1547 Value *VecLogic = Builder.CreateBinOp(Opc: BI->getOpcode(), LHS, RHS);
1548 Value *NewExt = Builder.CreateExtractElement(Vec: VecLogic, Idx: CheapIndex);
1549 replaceValue(Old&: I, New&: *NewExt);
1550 ++NumVecCmpBO;
1551 return true;
1552}
1553
1554/// Try to fold scalar selects that select between extracted elements and zero
1555/// into extracting from a vector select. This is rooted at the bitcast.
1556///
1557/// This pattern arises when a vector is bitcast to a smaller element type,
1558/// elements are extracted, and then conditionally selected with zero:
1559///
1560/// %bc = bitcast <4 x i32> %src to <16 x i8>
1561/// %e0 = extractelement <16 x i8> %bc, i32 0
1562/// %s0 = select i1 %cond, i8 %e0, i8 0
1563/// %e1 = extractelement <16 x i8> %bc, i32 1
1564/// %s1 = select i1 %cond, i8 %e1, i8 0
1565/// ...
1566///
1567/// Transforms to:
1568/// %sel = select i1 %cond, <4 x i32> %src, <4 x i32> zeroinitializer
1569/// %bc = bitcast <4 x i32> %sel to <16 x i8>
1570/// %e0 = extractelement <16 x i8> %bc, i32 0
1571/// %e1 = extractelement <16 x i8> %bc, i32 1
1572/// ...
1573///
1574/// This is profitable because vector select on wider types produces fewer
1575/// select/cndmask instructions than scalar selects on each element.
1576bool VectorCombine::foldSelectsFromBitcast(Instruction &I) {
1577 auto *BC = dyn_cast<BitCastInst>(Val: &I);
1578 if (!BC)
1579 return false;
1580
1581 FixedVectorType *SrcVecTy = dyn_cast<FixedVectorType>(Val: BC->getSrcTy());
1582 FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(Val: BC->getDestTy());
1583 if (!SrcVecTy || !DstVecTy)
1584 return false;
1585
1586  // The source must have 32-bit or 64-bit elements and the destination must
1587  // have smaller integer elements. Zero in all of these types is all-bits-zero.
1588 Type *SrcEltTy = SrcVecTy->getElementType();
1589 Type *DstEltTy = DstVecTy->getElementType();
1590 unsigned SrcEltBits = SrcEltTy->getPrimitiveSizeInBits();
1591 unsigned DstEltBits = DstEltTy->getPrimitiveSizeInBits();
1592
1593 if (SrcEltBits != 32 && SrcEltBits != 64)
1594 return false;
1595
1596 if (!DstEltTy->isIntegerTy() || DstEltBits >= SrcEltBits)
1597 return false;
1598
1599 // Check profitability using TTI before collecting users.
1600 Type *CondTy = CmpInst::makeCmpResultType(opnd_type: DstEltTy);
1601 Type *VecCondTy = CmpInst::makeCmpResultType(opnd_type: SrcVecTy);
1602
1603 InstructionCost ScalarSelCost =
1604 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: DstEltTy, CondTy,
1605 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1606 InstructionCost VecSelCost =
1607 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: SrcVecTy, CondTy: VecCondTy,
1608 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1609
1610 // We need at least this many selects for vectorization to be profitable.
1611 // VecSelCost < ScalarSelCost * NumSelects => NumSelects > VecSelCost /
1612 // ScalarSelCost
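  // For example, with hypothetical costs VecSelCost = 4 and ScalarSelCost = 1,
  // MinSelects below becomes 4 / 1 + 1 = 5, so at least five scalar selects
  // sharing one condition are needed before the fold pays off.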
1613 if (!ScalarSelCost.isValid() || ScalarSelCost == 0)
1614 return false;
1615
1616 unsigned MinSelects = (VecSelCost.getValue() / ScalarSelCost.getValue()) + 1;
1617
1618 // Quick check: if bitcast doesn't have enough users, bail early.
1619 if (!BC->hasNUsesOrMore(N: MinSelects))
1620 return false;
1621
1622 // Collect all select users that match the pattern, grouped by condition.
1623 // Pattern: select i1 %cond, (extractelement %bc, idx), 0
1624 DenseMap<Value *, SmallVector<SelectInst *, 8>> CondToSelects;
1625
1626 for (User *U : BC->users()) {
1627 auto *Ext = dyn_cast<ExtractElementInst>(Val: U);
1628 if (!Ext)
1629 continue;
1630
1631 for (User *ExtUser : Ext->users()) {
1632 Value *Cond;
1633 // Match: select i1 %cond, %ext, 0
1634 if (match(V: ExtUser, P: m_Select(C: m_Value(V&: Cond), L: m_Specific(V: Ext), R: m_Zero())) &&
1635 Cond->getType()->isIntegerTy(Bitwidth: 1))
1636 CondToSelects[Cond].push_back(Elt: cast<SelectInst>(Val: ExtUser));
1637 }
1638 }
1639
1640 if (CondToSelects.empty())
1641 return false;
1642
1643 bool MadeChange = false;
1644 Value *SrcVec = BC->getOperand(i_nocapture: 0);
1645
1646 // Process each group of selects with the same condition.
1647 for (auto [Cond, Selects] : CondToSelects) {
1648 // Only profitable if vector select cost < total scalar select cost.
1649 if (Selects.size() < MinSelects) {
1650 LLVM_DEBUG(dbgs() << "VectorCombine: foldSelectsFromBitcast not "
1651 << "profitable (VecCost=" << VecSelCost
1652 << ", ScalarCost=" << ScalarSelCost
1653 << ", NumSelects=" << Selects.size() << ")\n");
1654 continue;
1655 }
1656
1657 // Create the vector select and bitcast once for this condition.
1658 auto InsertPt = std::next(x: BC->getIterator());
1659
1660 if (auto *CondInst = dyn_cast<Instruction>(Val: Cond))
1661 if (DT.dominates(Def: BC, User: CondInst))
1662 InsertPt = std::next(x: CondInst->getIterator());
1663
1664 Builder.SetInsertPoint(InsertPt);
1665 Value *VecSel =
1666 Builder.CreateSelect(C: Cond, True: SrcVec, False: Constant::getNullValue(Ty: SrcVecTy));
1667 Value *NewBC = Builder.CreateBitCast(V: VecSel, DestTy: DstVecTy);
1668
1669 // Replace each scalar select with an extract from the new bitcast.
1670 for (SelectInst *Sel : Selects) {
1671 auto *Ext = cast<ExtractElementInst>(Val: Sel->getTrueValue());
1672 Value *Idx = Ext->getIndexOperand();
1673
1674 Builder.SetInsertPoint(Sel);
1675 Value *NewExt = Builder.CreateExtractElement(Vec: NewBC, Idx);
1676 replaceValue(Old&: *Sel, New&: *NewExt);
1677 MadeChange = true;
1678 }
1679
1680 LLVM_DEBUG(dbgs() << "VectorCombine: folded " << Selects.size()
1681 << " selects into vector select\n");
1682 }
1683
1684 return MadeChange;
1685}
1686
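/// Split the cost of the vector reduction intrinsic \p II into the cost of
/// producing its vector operand (\p CostBeforeReduction) and the cost of the
/// reduction itself (\p CostAfterReduction). Extended reductions
/// (reduce(ext(X))) and multiply-accumulate patterns
/// (reduce.add(ext(mul(ext(A), ext(B))))) are recognized so that the target's
/// combined reduction costs can be queried instead.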
1687static void analyzeCostOfVecReduction(const IntrinsicInst &II,
1688 TTI::TargetCostKind CostKind,
1689 const TargetTransformInfo &TTI,
1690 InstructionCost &CostBeforeReduction,
1691 InstructionCost &CostAfterReduction) {
1692 Instruction *Op0, *Op1;
1693 auto *RedOp = dyn_cast<Instruction>(Val: II.getOperand(i_nocapture: 0));
1694 auto *VecRedTy = cast<VectorType>(Val: II.getOperand(i_nocapture: 0)->getType());
1695 unsigned ReductionOpc =
1696 getArithmeticReductionInstruction(RdxID: II.getIntrinsicID());
1697 if (RedOp && match(V: RedOp, P: m_ZExtOrSExt(Op: m_Value()))) {
1698 bool IsUnsigned = isa<ZExtInst>(Val: RedOp);
1699 auto *ExtType = cast<VectorType>(Val: RedOp->getOperand(i: 0)->getType());
1700
1701 CostBeforeReduction =
1702 TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: ExtType,
1703 CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
1704 CostAfterReduction =
1705 TTI.getExtendedReductionCost(Opcode: ReductionOpc, IsUnsigned, ResTy: II.getType(),
1706 Ty: ExtType, FMF: FastMathFlags(), CostKind);
1707 return;
1708 }
1709 if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1710 match(V: RedOp,
1711 P: m_ZExtOrSExt(Op: m_Mul(L: m_Instruction(I&: Op0), R: m_Instruction(I&: Op1)))) &&
1712 match(V: Op0, P: m_ZExtOrSExt(Op: m_Value())) &&
1713 Op0->getOpcode() == Op1->getOpcode() &&
1714 Op0->getOperand(i: 0)->getType() == Op1->getOperand(i: 0)->getType() &&
1715 (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
1716    // Matched reduce.add(ext(mul(ext(A), ext(B))))
1717 bool IsUnsigned = isa<ZExtInst>(Val: Op0);
1718 auto *ExtType = cast<VectorType>(Val: Op0->getOperand(i: 0)->getType());
1719 VectorType *MulType = VectorType::get(ElementType: Op0->getType(), Other: VecRedTy);
1720
1721 InstructionCost ExtCost =
1722 TTI.getCastInstrCost(Opcode: Op0->getOpcode(), Dst: MulType, Src: ExtType,
1723 CCH: TTI::CastContextHint::None, CostKind, I: Op0);
1724 InstructionCost MulCost =
1725 TTI.getArithmeticInstrCost(Opcode: Instruction::Mul, Ty: MulType, CostKind);
1726 InstructionCost Ext2Cost =
1727 TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: MulType,
1728 CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
1729
1730 CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1731 CostAfterReduction = TTI.getMulAccReductionCost(
1732 IsUnsigned, RedOpcode: ReductionOpc, ResTy: II.getType(), Ty: ExtType, CostKind);
1733 return;
1734 }
1735 CostAfterReduction = TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecRedTy,
1736 FMF: std::nullopt, CostKind);
1737}
1738
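/// Try to fold a binary op of two like reductions into a single reduction over
/// a vector binary op, e.g. (an illustrative sketch, subject to the cost
/// model):
///   add (vector_reduce_add(%v0)), (vector_reduce_add(%v1))
///   --> vector_reduce_add(add %v0, %v1)
/// A sub of two reduce.add calls is handled the same way, becoming
/// vector_reduce_add(sub %v0, %v1).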
1739bool VectorCombine::foldBinopOfReductions(Instruction &I) {
1740 Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(Val: &I)->getOpcode();
1741 Intrinsic::ID ReductionIID = getReductionForBinop(Opc: BinOpOpc);
1742 if (BinOpOpc == Instruction::Sub)
1743 ReductionIID = Intrinsic::vector_reduce_add;
1744 if (ReductionIID == Intrinsic::not_intrinsic)
1745 return false;
1746
1747 auto checkIntrinsicAndGetItsArgument = [](Value *V,
1748 Intrinsic::ID IID) -> Value * {
1749 auto *II = dyn_cast<IntrinsicInst>(Val: V);
1750 if (!II)
1751 return nullptr;
1752 if (II->getIntrinsicID() == IID && II->hasOneUse())
1753 return II->getArgOperand(i: 0);
1754 return nullptr;
1755 };
1756
1757 Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 0), ReductionIID);
1758 if (!V0)
1759 return false;
1760 Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 1), ReductionIID);
1761 if (!V1)
1762 return false;
1763
1764 auto *VTy = cast<VectorType>(Val: V0->getType());
1765 if (V1->getType() != VTy)
1766 return false;
1767 const auto &II0 = *cast<IntrinsicInst>(Val: I.getOperand(i: 0));
1768 const auto &II1 = *cast<IntrinsicInst>(Val: I.getOperand(i: 1));
1769 unsigned ReductionOpc =
1770 getArithmeticReductionInstruction(RdxID: II0.getIntrinsicID());
1771
1772 InstructionCost OldCost = 0;
1773 InstructionCost NewCost = 0;
1774 InstructionCost CostOfRedOperand0 = 0;
1775 InstructionCost CostOfRed0 = 0;
1776 InstructionCost CostOfRedOperand1 = 0;
1777 InstructionCost CostOfRed1 = 0;
1778 analyzeCostOfVecReduction(II: II0, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand0, CostAfterReduction&: CostOfRed0);
1779 analyzeCostOfVecReduction(II: II1, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand1, CostAfterReduction&: CostOfRed1);
1780 OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(U: &I, CostKind);
1781 NewCost =
1782 CostOfRedOperand0 + CostOfRedOperand1 +
1783 TTI.getArithmeticInstrCost(Opcode: BinOpOpc, Ty: VTy, CostKind) +
1784 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VTy, FMF: std::nullopt, CostKind);
1785 if (NewCost >= OldCost || !NewCost.isValid())
1786 return false;
1787
1788 LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
1789 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1790 << "\n");
1791 Value *VectorBO;
1792 if (BinOpOpc == Instruction::Or)
1793 VectorBO = Builder.CreateOr(LHS: V0, RHS: V1, Name: "",
1794 IsDisjoint: cast<PossiblyDisjointInst>(Val&: I).isDisjoint());
1795 else
1796 VectorBO = Builder.CreateBinOp(Opc: BinOpOpc, LHS: V0, RHS: V1);
1797
1798 Instruction *Rdx = Builder.CreateIntrinsic(ID: ReductionIID, Types: {VTy}, Args: {VectorBO});
1799 replaceValue(Old&: I, New&: *Rdx);
1800 return true;
1801}
1802
1803// Check if the memory location is modified between two instructions in the same BB.
1804static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1805 BasicBlock::iterator End,
1806 const MemoryLocation &Loc, AAResults &AA) {
1807 unsigned NumScanned = 0;
1808 return std::any_of(first: Begin, last: End, pred: [&](const Instruction &Instr) {
1809 return isModSet(MRI: AA.getModRefInfo(I: &Instr, OptLoc: Loc)) ||
1810 ++NumScanned > MaxInstrsToScan;
1811 });
1812}
1813
1814namespace {
1815/// Helper class to indicate whether a vector index can be safely scalarized and
1816/// if a freeze needs to be inserted.
1817class ScalarizationResult {
1818 enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1819
1820 StatusTy Status;
1821 Value *ToFreeze;
1822
1823 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1824 : Status(Status), ToFreeze(ToFreeze) {}
1825
1826public:
1827 ScalarizationResult(const ScalarizationResult &Other) = default;
1828 ~ScalarizationResult() {
1829 assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1830 }
1831
1832 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1833 static ScalarizationResult safe() { return {StatusTy::Safe}; }
1834 static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1835 return {StatusTy::SafeWithFreeze, ToFreeze};
1836 }
1837
1838  /// Returns true if the index can be scalarized without requiring a freeze.
1839 bool isSafe() const { return Status == StatusTy::Safe; }
1840 /// Returns true if the index cannot be scalarized.
1841 bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1842  /// Returns true if the index can be scalarized, but requires inserting a
1843  /// freeze.
1844 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1845
1846  /// Reset the state to Unsafe and clear ToFreeze if set.
1847 void discard() {
1848 ToFreeze = nullptr;
1849 Status = StatusTy::Unsafe;
1850 }
1851
1852  /// Freeze ToFreeze and update the use in \p UserI to use the frozen value.
1853 void freeze(IRBuilderBase &Builder, Instruction &UserI) {
1854 assert(isSafeWithFreeze() &&
1855 "should only be used when freezing is required");
1856 assert(is_contained(ToFreeze->users(), &UserI) &&
1857 "UserI must be a user of ToFreeze");
1858 IRBuilder<>::InsertPointGuard Guard(Builder);
1859 Builder.SetInsertPoint(cast<Instruction>(Val: &UserI));
1860 Value *Frozen =
1861 Builder.CreateFreeze(V: ToFreeze, Name: ToFreeze->getName() + ".frozen");
1862 for (Use &U : make_early_inc_range(Range: (UserI.operands())))
1863 if (U.get() == ToFreeze)
1864 U.set(Frozen);
1865
1866 ToFreeze = nullptr;
1867 }
1868};
1869} // namespace
1870
1871/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1872/// Idx. \p Idx must access a valid vector element.
1873static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1874 Instruction *CtxI,
1875 AssumptionCache &AC,
1876 const DominatorTree &DT) {
1877 // We do checks for both fixed vector types and scalable vector types.
1878 // This is the number of elements of fixed vector types,
1879 // or the minimum number of elements of scalable vector types.
1880 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1881 unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1882
1883 if (auto *C = dyn_cast<ConstantInt>(Val: Idx)) {
1884 if (C->getValue().ult(RHS: NumElements))
1885 return ScalarizationResult::safe();
1886 return ScalarizationResult::unsafe();
1887 }
1888
1889 // Always unsafe if the index type can't handle all inbound values.
1890 if (!llvm::isUIntN(N: IntWidth, x: NumElements))
1891 return ScalarizationResult::unsafe();
1892
1893 APInt Zero(IntWidth, 0);
1894 APInt MaxElts(IntWidth, NumElements);
1895 ConstantRange ValidIndices(Zero, MaxElts);
1896 ConstantRange IdxRange(IntWidth, true);
1897
1898 if (isGuaranteedNotToBePoison(V: Idx, AC: &AC)) {
1899 if (ValidIndices.contains(CR: computeConstantRange(V: Idx, /* ForSigned */ false,
1900 UseInstrInfo: true, AC: &AC, CtxI, DT: &DT)))
1901 return ScalarizationResult::safe();
1902 return ScalarizationResult::unsafe();
1903 }
1904
1905 // If the index may be poison, check if we can insert a freeze before the
1906 // range of the index is restricted.
1907 Value *IdxBase;
1908 ConstantInt *CI;
1909 if (match(V: Idx, P: m_And(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
1910 IdxRange = IdxRange.binaryAnd(Other: CI->getValue());
1911 } else if (match(V: Idx, P: m_URem(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
1912 IdxRange = IdxRange.urem(Other: CI->getValue());
1913 }
1914
1915 if (ValidIndices.contains(CR: IdxRange))
1916 return ScalarizationResult::safeWithFreeze(ToFreeze: IdxBase);
1917 return ScalarizationResult::unsafe();
1918}
1919
1920/// The memory operation on a vector of \p ScalarType had alignment of
1921/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1922/// alignment that will be valid for the memory operation on a single scalar
1923/// element of the same type with index \p Idx.
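///
/// For example (illustrative numbers): for a <4 x i32> access with alignment 16
/// and constant index 1, the scalar element is at byte offset 4, so the scalar
/// access gets commonAlignment(16, 4) = 4; with a variable index we
/// conservatively fall back to commonAlignment(16, 4) = 4 as well (one
/// element's store size).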
1924static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1925 Type *ScalarType, Value *Idx,
1926 const DataLayout &DL) {
1927 if (auto *C = dyn_cast<ConstantInt>(Val: Idx))
1928 return commonAlignment(A: VectorAlignment,
1929 Offset: C->getZExtValue() * DL.getTypeStoreSize(Ty: ScalarType));
1930 return commonAlignment(A: VectorAlignment, Offset: DL.getTypeStoreSize(Ty: ScalarType));
1931}
1932
1933// Combine patterns like:
1934// %0 = load <4 x i32>, <4 x i32>* %a
1935// %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1936// store <4 x i32> %1, <4 x i32>* %a
1937// to:
1938//  %0 = getelementptr inbounds <4 x i32>, <4 x i32>* %a, i64 0, i64 1
1939//  store i32 %b, i32* %0
1940// (the GEP indexes the vector type directly, so no pointer bitcast is needed)
1941bool VectorCombine::foldSingleElementStore(Instruction &I) {
1942 if (!TTI.allowVectorElementIndexingUsingGEP())
1943 return false;
1944 auto *SI = cast<StoreInst>(Val: &I);
1945 if (!SI->isSimple() || !isa<VectorType>(Val: SI->getValueOperand()->getType()))
1946 return false;
1947
1948 // TODO: Combine more complicated patterns (multiple insert) by referencing
1949 // TargetTransformInfo.
1950 Instruction *Source;
1951 Value *NewElement;
1952 Value *Idx;
1953 if (!match(V: SI->getValueOperand(),
1954 P: m_InsertElt(Val: m_Instruction(I&: Source), Elt: m_Value(V&: NewElement),
1955 Idx: m_Value(V&: Idx))))
1956 return false;
1957
1958 if (auto *Load = dyn_cast<LoadInst>(Val: Source)) {
1959 auto VecTy = cast<VectorType>(Val: SI->getValueOperand()->getType());
1960 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1961    // Don't optimize for atomic/volatile loads or stores. Ensure memory is not
1962    // modified in between, the element type's size matches its store size, and the index is in bounds.
1963 if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1964 !DL->typeSizeEqualsStoreSize(Ty: Load->getType()->getScalarType()) ||
1965 SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1966 return false;
1967
1968 auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, CtxI: Load, AC, DT);
1969 if (ScalarizableIdx.isUnsafe() ||
1970 isMemModifiedBetween(Begin: Load->getIterator(), End: SI->getIterator(),
1971 Loc: MemoryLocation::get(SI), AA))
1972 return false;
1973
1974    // Ensure we add the load back to the worklist BEFORE its users so they can
1975    // be erased in the correct order.
1976 Worklist.push(I: Load);
1977
1978 if (ScalarizableIdx.isSafeWithFreeze())
1979 ScalarizableIdx.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
1980 Value *GEP = Builder.CreateInBoundsGEP(
1981 Ty: SI->getValueOperand()->getType(), Ptr: SI->getPointerOperand(),
1982 IdxList: {ConstantInt::get(Ty: Idx->getType(), V: 0), Idx});
1983 StoreInst *NSI = Builder.CreateStore(Val: NewElement, Ptr: GEP);
1984 NSI->copyMetadata(SrcInst: *SI);
1985 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1986 VectorAlignment: std::max(a: SI->getAlign(), b: Load->getAlign()), ScalarType: NewElement->getType(), Idx,
1987 DL: *DL);
1988 NSI->setAlignment(ScalarOpAlignment);
1989 replaceValue(Old&: I, New&: *NSI);
1990 eraseInstruction(I);
1991 return true;
1992 }
1993
1994 return false;
1995}
1996
1997/// Try to scalarize vector loads feeding extractelement or bitcast
1998/// instructions.
1999bool VectorCombine::scalarizeLoad(Instruction &I) {
2000 Value *Ptr;
2001 if (!match(V: &I, P: m_Load(Op: m_Value(V&: Ptr))))
2002 return false;
2003
2004 auto *LI = cast<LoadInst>(Val: &I);
2005 auto *VecTy = cast<VectorType>(Val: LI->getType());
2006 if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(Ty: VecTy->getScalarType()))
2007 return false;
2008
2009 bool AllExtracts = true;
2010 bool AllBitcasts = true;
2011 Instruction *LastCheckedInst = LI;
2012 unsigned NumInstChecked = 0;
2013
2014  // Check what type of users we have (they must either all be extracts or all
2015  // be bitcasts) and ensure no memory modifications between the load and its
2016  // users.
2017 for (User *U : LI->users()) {
2018 auto *UI = dyn_cast<Instruction>(Val: U);
2019 if (!UI || UI->getParent() != LI->getParent())
2020 return false;
2021
2022 // If any user is waiting to be erased, then bail out as this will
2023 // distort the cost calculation and possibly lead to infinite loops.
2024 if (UI->use_empty())
2025 return false;
2026
2027 if (!isa<ExtractElementInst>(Val: UI))
2028 AllExtracts = false;
2029 if (!isa<BitCastInst>(Val: UI))
2030 AllBitcasts = false;
2031
2032 // Check if any instruction between the load and the user may modify memory.
2033 if (LastCheckedInst->comesBefore(Other: UI)) {
2034 for (Instruction &I :
2035 make_range(x: std::next(x: LI->getIterator()), y: UI->getIterator())) {
2036 // Bail out if we reached the check limit or the instruction may write
2037 // to memory.
2038 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
2039 return false;
2040 NumInstChecked++;
2041 }
2042 LastCheckedInst = UI;
2043 }
2044 }
2045
2046 if (AllExtracts)
2047 return scalarizeLoadExtract(LI, VecTy, Ptr);
2048 if (AllBitcasts)
2049 return scalarizeLoadBitcast(LI, VecTy, Ptr);
2050 return false;
2051}
2052
2053/// Try to scalarize vector loads feeding extractelement instructions.
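///
/// For example (an illustrative sketch, subject to the cost model):
///   %lv = load <4 x i32>, ptr %p
///   %e1 = extractelement <4 x i32> %lv, i64 1
/// -->
///   %gep = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 1
///   %e1.scalar = load i32, ptr %gep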
2054bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
2055 Value *Ptr) {
2056 if (!TTI.allowVectorElementIndexingUsingGEP())
2057 return false;
2058
2059 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
2060 llvm::scope_exit FailureGuard([&]() {
2061 // If the transform is aborted, discard the ScalarizationResults.
2062 for (auto &Pair : NeedFreeze)
2063 Pair.second.discard();
2064 });
2065
2066 InstructionCost OriginalCost =
2067 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
2068 AddressSpace: LI->getPointerAddressSpace(), CostKind);
2069 InstructionCost ScalarizedCost = 0;
2070
2071 for (User *U : LI->users()) {
2072 auto *UI = cast<ExtractElementInst>(Val: U);
2073
2074 auto ScalarIdx =
2075 canScalarizeAccess(VecTy, Idx: UI->getIndexOperand(), CtxI: LI, AC, DT);
2076 if (ScalarIdx.isUnsafe())
2077 return false;
2078 if (ScalarIdx.isSafeWithFreeze()) {
2079 NeedFreeze.try_emplace(Key: UI, Args&: ScalarIdx);
2080 ScalarIdx.discard();
2081 }
2082
2083 auto *Index = dyn_cast<ConstantInt>(Val: UI->getIndexOperand());
2084 OriginalCost +=
2085 TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: VecTy, CostKind,
2086 Index: Index ? Index->getZExtValue() : -1);
2087 ScalarizedCost +=
2088 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy->getElementType(),
2089 Alignment: Align(1), AddressSpace: LI->getPointerAddressSpace(), CostKind);
2090 ScalarizedCost += TTI.getAddressComputationCost(PtrTy: LI->getPointerOperandType(),
2091 SE: nullptr, Ptr: nullptr, CostKind);
2092 }
2093
2094 LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
2095 << "\n LoadExtractCost: " << OriginalCost
2096 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2097
2098 if (ScalarizedCost >= OriginalCost)
2099 return false;
2100
2101  // Ensure we add the load back to the worklist BEFORE its users so they can
2102  // be erased in the correct order.
2103 Worklist.push(I: LI);
2104
2105 Type *ElemType = VecTy->getElementType();
2106
2107 // Replace extracts with narrow scalar loads.
2108 for (User *U : LI->users()) {
2109 auto *EI = cast<ExtractElementInst>(Val: U);
2110 Value *Idx = EI->getIndexOperand();
2111
2112 // Insert 'freeze' for poison indexes.
2113 auto It = NeedFreeze.find(Val: EI);
2114 if (It != NeedFreeze.end())
2115 It->second.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
2116
2117 Builder.SetInsertPoint(EI);
2118 Value *GEP =
2119 Builder.CreateInBoundsGEP(Ty: VecTy, Ptr, IdxList: {Builder.getInt32(C: 0), Idx});
2120 auto *NewLoad = cast<LoadInst>(
2121 Val: Builder.CreateLoad(Ty: ElemType, Ptr: GEP, Name: EI->getName() + ".scalar"));
2122
2123 Align ScalarOpAlignment =
2124 computeAlignmentAfterScalarization(VectorAlignment: LI->getAlign(), ScalarType: ElemType, Idx, DL: *DL);
2125 NewLoad->setAlignment(ScalarOpAlignment);
2126
2127 if (auto *ConstIdx = dyn_cast<ConstantInt>(Val: Idx)) {
2128 size_t Offset = ConstIdx->getZExtValue() * DL->getTypeStoreSize(Ty: ElemType);
2129 AAMDNodes OldAAMD = LI->getAAMetadata();
2130 NewLoad->setAAMetadata(OldAAMD.adjustForAccess(Offset, AccessTy: ElemType, DL: *DL));
2131 }
2132
2133 replaceValue(Old&: *EI, New&: *NewLoad, Erase: false);
2134 }
2135
2136 FailureGuard.release();
2137 return true;
2138}
2139
2140/// Try to scalarize vector loads feeding bitcast instructions.
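///
/// For example (an illustrative sketch, subject to the cost model):
///   %lv = load <4 x i32>, ptr %p
///   %i = bitcast <4 x i32> %lv to i128
/// -->
///   %lv.scalar = load i128, ptr %p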
2141bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
2142 Value *Ptr) {
2143 InstructionCost OriginalCost =
2144 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
2145 AddressSpace: LI->getPointerAddressSpace(), CostKind);
2146
2147 Type *TargetScalarType = nullptr;
2148 unsigned VecBitWidth = DL->getTypeSizeInBits(Ty: VecTy);
2149
2150 for (User *U : LI->users()) {
2151 auto *BC = cast<BitCastInst>(Val: U);
2152
2153 Type *DestTy = BC->getDestTy();
2154 if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
2155 return false;
2156
2157 unsigned DestBitWidth = DL->getTypeSizeInBits(Ty: DestTy);
2158 if (DestBitWidth != VecBitWidth)
2159 return false;
2160
2161 // All bitcasts must target the same scalar type.
2162 if (!TargetScalarType)
2163 TargetScalarType = DestTy;
2164 else if (TargetScalarType != DestTy)
2165 return false;
2166
2167 OriginalCost +=
2168 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: TargetScalarType, Src: VecTy,
2169 CCH: TTI.getCastContextHint(I: BC), CostKind, I: BC);
2170 }
2171
2172 if (!TargetScalarType)
2173 return false;
2174
2175 assert(!LI->user_empty() && "Unexpected load without bitcast users");
2176 InstructionCost ScalarizedCost =
2177 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: TargetScalarType, Alignment: LI->getAlign(),
2178 AddressSpace: LI->getPointerAddressSpace(), CostKind);
2179
2180 LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
2181 << "\n OriginalCost: " << OriginalCost
2182 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2183
2184 if (ScalarizedCost >= OriginalCost)
2185 return false;
2186
2187  // Ensure we add the load back to the worklist BEFORE its users so they can
2188  // be erased in the correct order.
2189 Worklist.push(I: LI);
2190
2191 Builder.SetInsertPoint(LI);
2192 auto *ScalarLoad =
2193 Builder.CreateLoad(Ty: TargetScalarType, Ptr, Name: LI->getName() + ".scalar");
2194 ScalarLoad->setAlignment(LI->getAlign());
2195 ScalarLoad->copyMetadata(SrcInst: *LI);
2196
2197 // Replace all bitcast users with the scalar load.
2198 for (User *U : LI->users()) {
2199 auto *BC = cast<BitCastInst>(Val: U);
2200 replaceValue(Old&: *BC, New&: *ScalarLoad, Erase: false);
2201 }
2202
2203 return true;
2204}
2205
2206bool VectorCombine::scalarizeExtExtract(Instruction &I) {
2207 if (!TTI.allowVectorElementIndexingUsingGEP())
2208 return false;
2209 auto *Ext = dyn_cast<ZExtInst>(Val: &I);
2210 if (!Ext)
2211 return false;
2212
2213  // Try to convert a vector zext feeding only extracts into a set of scalar
2214  //   (Src >> (ExtIdx * EltBitWidth)) & ((1 << EltBitWidth) - 1)
2215  // operations, if profitable (the shift amount is mirrored on big-endian).
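  //
  // For example (an illustrative little-endian sketch, subject to the cost
  // model and to the poison checks below):
  //   %z  = zext <4 x i8> %src to <4 x i32>
  //   %e1 = extractelement <4 x i32> %z, i64 1
  // -->
  //   %s  = bitcast <4 x i8> %src to i32
  //   %sh = lshr i32 %s, 8
  //   %e1 = and i32 %sh, 255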
2216 auto *SrcTy = dyn_cast<FixedVectorType>(Val: Ext->getOperand(i_nocapture: 0)->getType());
2217 if (!SrcTy)
2218 return false;
2219 auto *DstTy = cast<FixedVectorType>(Val: Ext->getType());
2220
2221 Type *ScalarDstTy = DstTy->getElementType();
2222 if (DL->getTypeSizeInBits(Ty: SrcTy) != DL->getTypeSizeInBits(Ty: ScalarDstTy))
2223 return false;
2224
2225 InstructionCost VectorCost =
2226 TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: DstTy, Src: SrcTy,
2227 CCH: TTI::CastContextHint::None, CostKind, I: Ext);
2228 unsigned ExtCnt = 0;
2229 bool ExtLane0 = false;
2230 for (User *U : Ext->users()) {
2231 uint64_t Idx;
2232 if (!match(V: U, P: m_ExtractElt(Val: m_Value(), Idx: m_ConstantInt(V&: Idx))))
2233 return false;
2234 if (cast<Instruction>(Val: U)->use_empty())
2235 continue;
2236 ExtCnt += 1;
2237 ExtLane0 |= !Idx;
2238 VectorCost += TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: DstTy,
2239 CostKind, Index: Idx, Op0: U);
2240 }
2241
2242 InstructionCost ScalarCost =
2243 ExtCnt * TTI.getArithmeticInstrCost(
2244 Opcode: Instruction::And, Ty: ScalarDstTy, CostKind,
2245 Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
2246 Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None}) +
2247 (ExtCnt - ExtLane0) *
2248 TTI.getArithmeticInstrCost(
2249 Opcode: Instruction::LShr, Ty: ScalarDstTy, CostKind,
2250 Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
2251 Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None});
2252 if (ScalarCost > VectorCost)
2253 return false;
2254
2255 Value *ScalarV = Ext->getOperand(i_nocapture: 0);
2256 if (!isGuaranteedNotToBePoison(V: ScalarV, AC: &AC, CtxI: dyn_cast<Instruction>(Val: ScalarV),
2257 DT: &DT)) {
2258    // Check whether all lanes are extracted, all extracts trigger UB on
2259    // poison, and the last extract (and hence all previous ones) is
2260    // guaranteed to execute if Ext executes. If so, we do not need to
2261    // insert a freeze.
2262 SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
2263 bool AllExtractsTriggerUB = true;
2264 ExtractElementInst *LastExtract = nullptr;
2265 BasicBlock *ExtBB = Ext->getParent();
2266 for (User *U : Ext->users()) {
2267 auto *Extract = cast<ExtractElementInst>(Val: U);
2268 if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Inst: Extract)) {
2269 AllExtractsTriggerUB = false;
2270 break;
2271 }
2272 ExtractedLanes.insert(V: cast<ConstantInt>(Val: Extract->getIndexOperand()));
2273 if (!LastExtract || LastExtract->comesBefore(Other: Extract))
2274 LastExtract = Extract;
2275 }
2276 if (ExtractedLanes.size() != DstTy->getNumElements() ||
2277 !AllExtractsTriggerUB ||
2278 !isGuaranteedToTransferExecutionToSuccessor(Begin: Ext->getIterator(),
2279 End: LastExtract->getIterator()))
2280 ScalarV = Builder.CreateFreeze(V: ScalarV);
2281 }
2282 ScalarV = Builder.CreateBitCast(
2283 V: ScalarV,
2284 DestTy: IntegerType::get(C&: SrcTy->getContext(), NumBits: DL->getTypeSizeInBits(Ty: SrcTy)));
2285 uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(Ty: SrcTy->getElementType());
2286 uint64_t TotalBits = DL->getTypeSizeInBits(Ty: SrcTy);
2287 APInt EltBitMask = APInt::getLowBitsSet(numBits: TotalBits, loBitsSet: SrcEltSizeInBits);
2288 Type *PackedTy = IntegerType::get(C&: SrcTy->getContext(), NumBits: TotalBits);
2289 Value *Mask = ConstantInt::get(Ty: PackedTy, V: EltBitMask);
2290 for (User *U : Ext->users()) {
2291 auto *Extract = cast<ExtractElementInst>(Val: U);
2292 uint64_t Idx =
2293 cast<ConstantInt>(Val: Extract->getIndexOperand())->getZExtValue();
2294 uint64_t ShiftAmt =
2295 DL->isBigEndian()
2296 ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
2297 : (Idx * SrcEltSizeInBits);
2298 Value *LShr = Builder.CreateLShr(LHS: ScalarV, RHS: ShiftAmt);
2299 Value *And = Builder.CreateAnd(LHS: LShr, RHS: Mask);
2300 U->replaceAllUsesWith(V: And);
2301 }
2302 return true;
2303}
2304
2305/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
2306/// to "(bitcast (concat X, Y))"
2307/// where X/Y are bitcasted from i1 mask vectors.
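///
/// For example (an illustrative little-endian sketch with <8 x i1> masks):
///   %xi = bitcast <8 x i1> %x to i8
///   %yi = bitcast <8 x i1> %y to i8
///   %xz = zext i8 %xi to i16
///   %yz = zext i8 %yi to i16
///   %ys = shl i16 %yz, 8
///   %r  = or disjoint i16 %xz, %ys
/// -->
///   %cc = shufflevector <8 x i1> %x, <8 x i1> %y, <16 x i32> <i32 0, i32 1, ..., i32 15>
///   %r  = bitcast <16 x i1> %cc to i16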
2308bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
2309 Type *Ty = I.getType();
2310 if (!Ty->isIntegerTy())
2311 return false;
2312
2313 // TODO: Add big endian test coverage
2314 if (DL->isBigEndian())
2315 return false;
2316
2317 // Restrict to disjoint cases so the mask vectors aren't overlapping.
2318 Instruction *X, *Y;
2319 if (!match(V: &I, P: m_DisjointOr(L: m_Instruction(I&: X), R: m_Instruction(I&: Y))))
2320 return false;
2321
2322 // Allow both sources to contain shl, to handle more generic pattern:
2323 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
2324 Value *SrcX;
2325 uint64_t ShAmtX = 0;
2326 if (!match(V: X, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX)))))) &&
2327 !match(V: X, P: m_OneUse(
2328 SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX))))),
2329 R: m_ConstantInt(V&: ShAmtX)))))
2330 return false;
2331
2332 Value *SrcY;
2333 uint64_t ShAmtY = 0;
2334 if (!match(V: Y, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY)))))) &&
2335 !match(V: Y, P: m_OneUse(
2336 SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY))))),
2337 R: m_ConstantInt(V&: ShAmtY)))))
2338 return false;
2339
2340 // Canonicalize larger shift to the RHS.
2341 if (ShAmtX > ShAmtY) {
2342 std::swap(a&: X, b&: Y);
2343 std::swap(a&: SrcX, b&: SrcY);
2344 std::swap(a&: ShAmtX, b&: ShAmtY);
2345 }
2346
2347 // Ensure both sources are matching vXi1 bool mask types, and that the shift
2348 // difference is the mask width so they can be easily concatenated together.
2349 uint64_t ShAmtDiff = ShAmtY - ShAmtX;
2350 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
2351 unsigned BitWidth = Ty->getPrimitiveSizeInBits();
2352 auto *MaskTy = dyn_cast<FixedVectorType>(Val: SrcX->getType());
2353 if (!MaskTy || SrcX->getType() != SrcY->getType() ||
2354 !MaskTy->getElementType()->isIntegerTy(Bitwidth: 1) ||
2355 MaskTy->getNumElements() != ShAmtDiff ||
2356 MaskTy->getNumElements() > (BitWidth / 2))
2357 return false;
2358
2359 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(VTy: MaskTy);
2360 auto *ConcatIntTy =
2361 Type::getIntNTy(C&: Ty->getContext(), N: ConcatTy->getNumElements());
2362 auto *MaskIntTy = Type::getIntNTy(C&: Ty->getContext(), N: ShAmtDiff);
2363
2364 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
2365 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
2366
2367 // TODO: Is it worth supporting multi use cases?
2368 InstructionCost OldCost = 0;
2369 OldCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Or, Ty, CostKind);
2370 OldCost +=
2371 NumSHL * TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
2372 OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: MaskIntTy,
2373 CCH: TTI::CastContextHint::None, CostKind);
2374 OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: MaskIntTy, Src: MaskTy,
2375 CCH: TTI::CastContextHint::None, CostKind);
2376
2377 InstructionCost NewCost = 0;
2378 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ConcatTy,
2379 SrcTy: MaskTy, Mask: ConcatMask, CostKind);
2380 NewCost += TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: ConcatIntTy, Src: ConcatTy,
2381 CCH: TTI::CastContextHint::None, CostKind);
2382 if (Ty != ConcatIntTy)
2383 NewCost += TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: ConcatIntTy,
2384 CCH: TTI::CastContextHint::None, CostKind);
2385 if (ShAmtX > 0)
2386 NewCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
2387
2388 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
2389 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2390 << "\n");
2391
2392 if (NewCost > OldCost)
2393 return false;
2394
2395 // Build bool mask concatenation, bitcast back to scalar integer, and perform
2396 // any residual zero-extension or shifting.
2397 Value *Concat = Builder.CreateShuffleVector(V1: SrcX, V2: SrcY, Mask: ConcatMask);
2398 Worklist.pushValue(V: Concat);
2399
2400 Value *Result = Builder.CreateBitCast(V: Concat, DestTy: ConcatIntTy);
2401
2402 if (Ty != ConcatIntTy) {
2403 Worklist.pushValue(V: Result);
2404 Result = Builder.CreateZExt(V: Result, DestTy: Ty);
2405 }
2406
2407 if (ShAmtX > 0) {
2408 Worklist.pushValue(V: Result);
2409 Result = Builder.CreateShl(LHS: Result, RHS: ShAmtX);
2410 }
2411
2412 replaceValue(Old&: I, New&: *Result);
2413 return true;
2414}
2415
2416/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
2417/// --> "binop (shuffle), (shuffle)".
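///
/// For example (an illustrative sketch, subject to the cost model):
///   %s = shufflevector <4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %o = add <4 x i32> %s, %c
///   %r = shufflevector <4 x i32> %o, <4 x i32> poison, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
/// -->
///   %lhs = shufflevector <4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 5, i32 1, i32 4, i32 0>
///   %rhs = shufflevector <4 x i32> %c, <4 x i32> %c, <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r   = add <4 x i32> %lhs, %rhs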
2418bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
2419 BinaryOperator *BinOp;
2420 ArrayRef<int> OuterMask;
2421 if (!match(V: &I, P: m_Shuffle(v1: m_BinOp(I&: BinOp), v2: m_Undef(), mask: m_Mask(OuterMask))))
2422 return false;
2423
2424 // Don't introduce poison into div/rem.
2425 if (BinOp->isIntDivRem() && llvm::is_contained(Range&: OuterMask, Element: PoisonMaskElem))
2426 return false;
2427
2428 Value *Op00, *Op01, *Op10, *Op11;
2429 ArrayRef<int> Mask0, Mask1;
2430 bool Match0 = match(V: BinOp->getOperand(i_nocapture: 0),
2431 P: m_Shuffle(v1: m_Value(V&: Op00), v2: m_Value(V&: Op01), mask: m_Mask(Mask0)));
2432 bool Match1 = match(V: BinOp->getOperand(i_nocapture: 1),
2433 P: m_Shuffle(v1: m_Value(V&: Op10), v2: m_Value(V&: Op11), mask: m_Mask(Mask1)));
2434 if (!Match0 && !Match1)
2435 return false;
2436
2437 Op00 = Match0 ? Op00 : BinOp->getOperand(i_nocapture: 0);
2438 Op01 = Match0 ? Op01 : BinOp->getOperand(i_nocapture: 0);
2439 Op10 = Match1 ? Op10 : BinOp->getOperand(i_nocapture: 1);
2440 Op11 = Match1 ? Op11 : BinOp->getOperand(i_nocapture: 1);
2441
2442 Instruction::BinaryOps Opcode = BinOp->getOpcode();
2443 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2444 auto *BinOpTy = dyn_cast<FixedVectorType>(Val: BinOp->getType());
2445 auto *Op0Ty = dyn_cast<FixedVectorType>(Val: Op00->getType());
2446 auto *Op1Ty = dyn_cast<FixedVectorType>(Val: Op10->getType());
2447 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
2448 return false;
2449
2450 unsigned NumSrcElts = BinOpTy->getNumElements();
2451
2452  // Don't accept outer mask elements that reference the second shuffle operand
2453  // when the binop is a div/rem or when that operand is not poison.
2454 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(Val: I.getOperand(i: 1))) &&
2455 any_of(Range&: OuterMask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
2456 return false;
2457
2458 // Merge outer / inner (or identity if no match) shuffles.
2459 SmallVector<int> NewMask0, NewMask1;
2460 for (int M : OuterMask) {
2461 if (M < 0 || M >= (int)NumSrcElts) {
2462 NewMask0.push_back(Elt: PoisonMaskElem);
2463 NewMask1.push_back(Elt: PoisonMaskElem);
2464 } else {
2465 NewMask0.push_back(Elt: Match0 ? Mask0[M] : M);
2466 NewMask1.push_back(Elt: Match1 ? Mask1[M] : M);
2467 }
2468 }
2469
2470 unsigned NumOpElts = Op0Ty->getNumElements();
2471 bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
2472 all_of(Range&: NewMask0, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2473 ShuffleVectorInst::isIdentityMask(Mask: NewMask0, NumSrcElts: NumOpElts);
2474 bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
2475 all_of(Range&: NewMask1, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2476 ShuffleVectorInst::isIdentityMask(Mask: NewMask1, NumSrcElts: NumOpElts);
2477
2478 InstructionCost NewCost = 0;
2479 // Try to merge shuffles across the binop if the new shuffles are not costly.
2480 InstructionCost BinOpCost =
2481 TTI.getArithmeticInstrCost(Opcode, Ty: BinOpTy, CostKind);
2482 InstructionCost OldCost =
2483 BinOpCost + TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
2484 DstTy: ShuffleDstTy, SrcTy: BinOpTy, Mask: OuterMask, CostKind,
2485 Index: 0, SubTp: nullptr, Args: {BinOp}, CxtI: &I);
2486 if (!BinOp->hasOneUse())
2487 NewCost += BinOpCost;
2488
2489 if (Match0) {
2490 InstructionCost Shuf0Cost = TTI.getShuffleCost(
2491 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op0Ty, Mask: Mask0, CostKind,
2492 Index: 0, SubTp: nullptr, Args: {Op00, Op01}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 0)));
2493 OldCost += Shuf0Cost;
2494 if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 0)->hasOneUse())
2495 NewCost += Shuf0Cost;
2496 }
2497 if (Match1) {
2498 InstructionCost Shuf1Cost = TTI.getShuffleCost(
2499 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op1Ty, Mask: Mask1, CostKind,
2500 Index: 0, SubTp: nullptr, Args: {Op10, Op11}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 1)));
2501 OldCost += Shuf1Cost;
2502 if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 1)->hasOneUse())
2503 NewCost += Shuf1Cost;
2504 }
2505
2506 NewCost += TTI.getArithmeticInstrCost(Opcode, Ty: ShuffleDstTy, CostKind);
2507
2508 if (!IsIdentity0)
2509 NewCost +=
2510 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2511 SrcTy: Op0Ty, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {Op00, Op01});
2512 if (!IsIdentity1)
2513 NewCost +=
2514 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2515 SrcTy: Op1Ty, Mask: NewMask1, CostKind, Index: 0, SubTp: nullptr, Args: {Op10, Op11});
2516
2517 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
2518 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2519 << "\n");
2520
2521 // If costs are equal, still fold as we reduce instruction count.
2522 if (NewCost > OldCost)
2523 return false;
2524
2525 Value *LHS =
2526 IsIdentity0 ? Op00 : Builder.CreateShuffleVector(V1: Op00, V2: Op01, Mask: NewMask0);
2527 Value *RHS =
2528 IsIdentity1 ? Op10 : Builder.CreateShuffleVector(V1: Op10, V2: Op11, Mask: NewMask1);
2529 Value *NewBO = Builder.CreateBinOp(Opc: Opcode, LHS, RHS);
2530
2531 // Intersect flags from the old binops.
2532 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO))
2533 NewInst->copyIRFlags(V: BinOp);
2534
2535 Worklist.pushValue(V: LHS);
2536 Worklist.pushValue(V: RHS);
2537 replaceValue(Old&: I, New&: *NewBO);
2538 return true;
2539}
2540
2541/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
2542/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
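///
/// For example (an illustrative sketch, subject to the cost model):
///   %l = add <4 x i32> %x, %y
///   %r = add <4 x i32> %z, %w
///   %s = shufflevector <4 x i32> %l, <4 x i32> %r, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
/// -->
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %z, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %w, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
///   %s  = add <4 x i32> %s0, %s1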
2543bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
2544 ArrayRef<int> OldMask;
2545 Instruction *LHS, *RHS;
2546 if (!match(V: &I, P: m_Shuffle(v1: m_OneUse(SubPattern: m_Instruction(I&: LHS)),
2547 v2: m_OneUse(SubPattern: m_Instruction(I&: RHS)), mask: m_Mask(OldMask))))
2548 return false;
2549
2550 // TODO: Add support for addlike etc.
2551 if (LHS->getOpcode() != RHS->getOpcode())
2552 return false;
2553
2554 Value *X, *Y, *Z, *W;
2555 bool IsCommutative = false;
2556 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
2557 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
2558 if (match(V: LHS, P: m_BinOp(L: m_Value(V&: X), R: m_Value(V&: Y))) &&
2559 match(V: RHS, P: m_BinOp(L: m_Value(V&: Z), R: m_Value(V&: W)))) {
2560 auto *BO = cast<BinaryOperator>(Val: LHS);
2561 // Don't introduce poison into div/rem.
2562 if (llvm::is_contained(Range&: OldMask, Element: PoisonMaskElem) && BO->isIntDivRem())
2563 return false;
2564 IsCommutative = BinaryOperator::isCommutative(Opcode: BO->getOpcode());
2565 } else if (match(V: LHS, P: m_Cmp(Pred&: PredLHS, L: m_Value(V&: X), R: m_Value(V&: Y))) &&
2566 match(V: RHS, P: m_Cmp(Pred&: PredRHS, L: m_Value(V&: Z), R: m_Value(V&: W))) &&
2567 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
2568 IsCommutative = cast<CmpInst>(Val: LHS)->isCommutative();
2569 } else
2570 return false;
2571
2572 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2573 auto *BinResTy = dyn_cast<FixedVectorType>(Val: LHS->getType());
2574 auto *BinOpTy = dyn_cast<FixedVectorType>(Val: X->getType());
2575 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
2576 return false;
2577
2578 unsigned NumSrcElts = BinOpTy->getNumElements();
2579
2580 // If we have something like "add X, Y" and "add Z, X", swap ops to match.
2581 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
2582 std::swap(a&: X, b&: Y);
2583
2584 auto ConvertToUnary = [NumSrcElts](int &M) {
2585 if (M >= (int)NumSrcElts)
2586 M -= NumSrcElts;
2587 };
2588
2589 SmallVector<int> NewMask0(OldMask);
2590 TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
2591 TTI::OperandValueInfo Op0Info = TTI.commonOperandInfo(X, Y: Z);
2592 if (X == Z) {
2593 llvm::for_each(Range&: NewMask0, F: ConvertToUnary);
2594 SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
2595 Z = PoisonValue::get(T: BinOpTy);
2596 }
2597
2598 SmallVector<int> NewMask1(OldMask);
2599 TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
2600 TTI::OperandValueInfo Op1Info = TTI.commonOperandInfo(X: Y, Y: W);
2601 if (Y == W) {
2602 llvm::for_each(Range&: NewMask1, F: ConvertToUnary);
2603 SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
2604 W = PoisonValue::get(T: BinOpTy);
2605 }
2606
2607 // Try to replace a binop with a shuffle if the shuffle is not costly.
2608 InstructionCost OldCost =
2609 TTI.getInstructionCost(U: LHS, CostKind) +
2610 TTI.getInstructionCost(U: RHS, CostKind) +
2611 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2612 SrcTy: BinResTy, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {LHS, RHS},
2613 CxtI: &I);
2614
2615 // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
2616 // where one use shuffles have gotten split across the binop/cmp. These
2617 // often allow a major reduction in total cost that wouldn't happen as
2618 // individual folds.
2619 auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
2620 TTI::TargetCostKind CostKind) -> bool {
2621 Value *InnerOp;
2622 ArrayRef<int> InnerMask;
2623 if (match(V: Op, P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: InnerOp), v2: m_Undef(),
2624 mask: m_Mask(InnerMask)))) &&
2625 InnerOp->getType() == Op->getType() &&
2626 all_of(Range&: InnerMask,
2627 P: [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
2628 for (int &M : Mask)
2629 if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
2630 M = InnerMask[M - Offset];
2631 M = 0 <= M ? M + Offset : M;
2632 }
2633 OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: Op), CostKind);
2634 Op = InnerOp;
2635 return true;
2636 }
2637 return false;
2638 };
2639 bool ReducedInstCount = false;
2640 ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
2641 ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
2642 ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
2643 ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
2644 bool SingleSrcBinOp = (X == Y) && (Z == W) && (NewMask0 == NewMask1);
2645 ReducedInstCount |= SingleSrcBinOp;
2646
2647 auto *ShuffleCmpTy =
2648 FixedVectorType::get(ElementType: BinOpTy->getElementType(), FVTy: ShuffleDstTy);
2649 InstructionCost NewCost = TTI.getShuffleCost(
2650 Kind: SK0, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {X, Z});
2651 if (!SingleSrcBinOp)
2652 NewCost += TTI.getShuffleCost(Kind: SK1, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask1,
2653 CostKind, Index: 0, SubTp: nullptr, Args: {Y, W});
2654
2655 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
2656 NewCost += TTI.getArithmeticInstrCost(Opcode: LHS->getOpcode(), Ty: ShuffleDstTy,
2657 CostKind, Opd1Info: Op0Info, Opd2Info: Op1Info);
2658 } else {
2659 NewCost +=
2660 TTI.getCmpSelInstrCost(Opcode: LHS->getOpcode(), ValTy: ShuffleCmpTy, CondTy: ShuffleDstTy,
2661 VecPred: PredLHS, CostKind, Op1Info: Op0Info, Op2Info: Op1Info);
2662 }
2663
2664 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
2665 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2666 << "\n");
2667
2668 // If either new shuffle will constant fold away, treat this as reducing the
2669 // instruction count and allow the fold even when the costs are equal.
2670 ReducedInstCount |= (isa<Constant>(Val: X) && isa<Constant>(Val: Z)) ||
2671 (isa<Constant>(Val: Y) && isa<Constant>(Val: W));
2672 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
2673 return false;
2674
2675 Value *Shuf0 = Builder.CreateShuffleVector(V1: X, V2: Z, Mask: NewMask0);
2676 Value *Shuf1 =
2677 SingleSrcBinOp ? Shuf0 : Builder.CreateShuffleVector(V1: Y, V2: W, Mask: NewMask1);
2678 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
2679 ? Builder.CreateBinOp(
2680 Opc: cast<BinaryOperator>(Val: LHS)->getOpcode(), LHS: Shuf0, RHS: Shuf1)
2681 : Builder.CreateCmp(Pred: PredLHS, LHS: Shuf0, RHS: Shuf1);
2682
2683 // Intersect flags from the old binops.
2684 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
2685 NewInst->copyIRFlags(V: LHS);
2686 NewInst->andIRFlags(V: RHS);
2687 }
2688
2689 Worklist.pushValue(V: Shuf0);
2690 Worklist.pushValue(V: Shuf1);
2691 replaceValue(Old&: I, New&: *NewBO);
2692 return true;
2693}
2694
2695/// Try to convert
2696/// "shuffle (select c1,t1,f1), (select c2,t2,f2), m" into
2697/// "select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m)".
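/// Illustrative example (types and mask are arbitrary):
///   %s1 = select <4 x i1> %c1, <4 x i32> %t1, <4 x i32> %f1
///   %s2 = select <4 x i1> %c2, <4 x i32> %t2, <4 x i32> %f2
///   %r  = shufflevector <4 x i32> %s1, <4 x i32> %s2, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
/// becomes
///   %c = shufflevector <4 x i1>  %c1, <4 x i1>  %c2, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %t = shufflevector <4 x i32> %t1, <4 x i32> %t2, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %f = shufflevector <4 x i32> %f1, <4 x i32> %f2, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %r = select <4 x i1> %c, <4 x i32> %t, <4 x i32> %f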
2698bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
2699 ArrayRef<int> Mask;
2700 Value *C1, *T1, *F1, *C2, *T2, *F2;
2701 if (!match(V: &I, P: m_Shuffle(v1: m_Select(C: m_Value(V&: C1), L: m_Value(V&: T1), R: m_Value(V&: F1)),
2702 v2: m_Select(C: m_Value(V&: C2), L: m_Value(V&: T2), R: m_Value(V&: F2)),
2703 mask: m_Mask(Mask))))
2704 return false;
2705
2706 auto *Sel1 = cast<Instruction>(Val: I.getOperand(i: 0));
2707 auto *Sel2 = cast<Instruction>(Val: I.getOperand(i: 1));
2708
2709 auto *C1VecTy = dyn_cast<FixedVectorType>(Val: C1->getType());
2710 auto *C2VecTy = dyn_cast<FixedVectorType>(Val: C2->getType());
2711 if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
2712 return false;
2713
2714 auto *SI0FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 0));
2715 auto *SI1FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 1));
2716 // SelectInsts must have the same FMF.
2717 if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
2718 ((SI0FOp != nullptr) &&
2719 (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
2720 return false;
2721
2722 auto *SrcVecTy = cast<FixedVectorType>(Val: T1->getType());
2723 auto *DstVecTy = cast<FixedVectorType>(Val: I.getType());
2724 auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
2725 auto SelOp = Instruction::Select;
2726
2727 InstructionCost CostSel1 = TTI.getCmpSelInstrCost(
2728 Opcode: SelOp, ValTy: SrcVecTy, CondTy: C1VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2729 InstructionCost CostSel2 = TTI.getCmpSelInstrCost(
2730 Opcode: SelOp, ValTy: SrcVecTy, CondTy: C2VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2731
2732 InstructionCost OldCost =
2733 CostSel1 + CostSel2 +
2734 TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
2735 Args: {I.getOperand(i: 0), I.getOperand(i: 1)}, CxtI: &I);
2736
2737 InstructionCost NewCost = TTI.getShuffleCost(
2738 Kind: SK, DstTy: FixedVectorType::get(ElementType: C1VecTy->getScalarType(), NumElts: Mask.size()), SrcTy: C1VecTy,
2739 Mask, CostKind, Index: 0, SubTp: nullptr, Args: {C1, C2});
2740 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2741 SubTp: nullptr, Args: {T1, T2});
2742 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2743 SubTp: nullptr, Args: {F1, F2});
2744 auto *C1C2ShuffledVecTy = cast<FixedVectorType>(
2745 Val: toVectorTy(Scalar: Type::getInt1Ty(C&: I.getContext()), VF: DstVecTy->getNumElements()));
2746 NewCost += TTI.getCmpSelInstrCost(Opcode: SelOp, ValTy: DstVecTy, CondTy: C1C2ShuffledVecTy,
2747 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2748
2749 if (!Sel1->hasOneUse())
2750 NewCost += CostSel1;
2751 if (!Sel2->hasOneUse())
2752 NewCost += CostSel2;
2753
2754 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
2755 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2756 << "\n");
2757 if (NewCost > OldCost)
2758 return false;
2759
2760 Value *ShuffleCmp = Builder.CreateShuffleVector(V1: C1, V2: C2, Mask);
2761 Value *ShuffleTrue = Builder.CreateShuffleVector(V1: T1, V2: T2, Mask);
2762 Value *ShuffleFalse = Builder.CreateShuffleVector(V1: F1, V2: F2, Mask);
2763 Value *NewSel;
2764 // The selects were already verified above to have the same FMF.
2765 if (SI0FOp)
2766 NewSel = Builder.CreateSelectFMF(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse,
2767 FMFSource: SI0FOp->getFastMathFlags());
2768 else
2769 NewSel = Builder.CreateSelect(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse);
2770
2771 Worklist.pushValue(V: ShuffleCmp);
2772 Worklist.pushValue(V: ShuffleTrue);
2773 Worklist.pushValue(V: ShuffleFalse);
2774 replaceValue(Old&: I, New&: *NewSel);
2775 return true;
2776}
2777
2778/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
2779/// into "castop (shuffle)".
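/// Illustrative example (types and mask are arbitrary):
///   %a = sext <4 x i16> %x to <4 x i32>
///   %b = sext <4 x i16> %y to <4 x i32>
///   %r = shufflevector <4 x i32> %a, <4 x i32> %b, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
/// becomes
///   %s = shufflevector <4 x i16> %x, <4 x i16> %y, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
///   %r = sext <8 x i16> %s to <8 x i32>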
2780bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2781 Value *V0, *V1;
2782 ArrayRef<int> OldMask;
2783 if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
2784 return false;
2785
2786 // Check whether this is a binary shuffle.
2787 bool IsBinaryShuffle = !isa<UndefValue>(Val: V1);
2788
2789 auto *C0 = dyn_cast<CastInst>(Val: V0);
2790 auto *C1 = dyn_cast<CastInst>(Val: V1);
2791 if (!C0 || (IsBinaryShuffle && !C1))
2792 return false;
2793
2794 Instruction::CastOps Opcode = C0->getOpcode();
2795
2796 // If this is allowed, foldShuffleOfCastops can get stuck in a loop
2797 // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
2798 if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
2799 return false;
2800
2801 if (IsBinaryShuffle) {
2802 if (C0->getSrcTy() != C1->getSrcTy())
2803 return false;
2804 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2805 if (Opcode != C1->getOpcode()) {
2806 if (match(V: C0, P: m_SExtLike(Op: m_Value())) && match(V: C1, P: m_SExtLike(Op: m_Value())))
2807 Opcode = Instruction::SExt;
2808 else
2809 return false;
2810 }
2811 }
2812
2813 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2814 auto *CastDstTy = dyn_cast<FixedVectorType>(Val: C0->getDestTy());
2815 auto *CastSrcTy = dyn_cast<FixedVectorType>(Val: C0->getSrcTy());
2816 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
2817 return false;
2818
2819 unsigned NumSrcElts = CastSrcTy->getNumElements();
2820 unsigned NumDstElts = CastDstTy->getNumElements();
2821 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
2822 "Only bitcasts expected to alter src/dst element counts");
2823
2824 // Reject bitcasts between vector types whose element counts are not exact
2825 // multiples of each other, e.g. <32 x i40> -> <40 x i32>.
2826 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
2827 (NumDstElts % NumSrcElts) != 0)
2828 return false;
2829
2830 SmallVector<int, 16> NewMask;
2831 if (NumSrcElts >= NumDstElts) {
2832 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
2833 // always be expanded to the equivalent form choosing narrower elements.
2834 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
2835 unsigned ScaleFactor = NumSrcElts / NumDstElts;
2836 narrowShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask);
2837 } else {
2838 // The bitcast is from narrow elements to wide elements. The shuffle mask
2839 // must choose consecutive elements to allow casting first.
2840 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
2841 unsigned ScaleFactor = NumDstElts / NumSrcElts;
2842 if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask))
2843 return false;
2844 }
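 // For example (illustrative): for a bitcast from <4 x i32> to <2 x i64>, a
 // shuffle mask <1,0> over the i64 elements narrows to <2,3,0,1> over the i32
 // elements; for the reverse bitcast, an i32-element mask <2,3,0,1> widens to
 // <1,0>, while a mask such as <1,2,0,3> selects non-consecutive narrow
 // elements and the fold is rejected.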
2845
2846 auto *NewShuffleDstTy =
2847 FixedVectorType::get(ElementType: CastSrcTy->getScalarType(), NumElts: NewMask.size());
2848
2849 // Try to replace a castop with a shuffle if the shuffle is not costly.
2850 InstructionCost CostC0 =
2851 TTI.getCastInstrCost(Opcode: C0->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
2852 CCH: TTI::CastContextHint::None, CostKind);
2853
2854 TargetTransformInfo::ShuffleKind ShuffleKind;
2855 if (IsBinaryShuffle)
2856 ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
2857 else
2858 ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
2859
2860 InstructionCost OldCost = CostC0;
2861 OldCost += TTI.getShuffleCost(Kind: ShuffleKind, DstTy: ShuffleDstTy, SrcTy: CastDstTy, Mask: OldMask,
2862 CostKind, Index: 0, SubTp: nullptr, Args: {}, CxtI: &I);
2863
2864 InstructionCost NewCost = TTI.getShuffleCost(Kind: ShuffleKind, DstTy: NewShuffleDstTy,
2865 SrcTy: CastSrcTy, Mask: NewMask, CostKind);
2866 NewCost += TTI.getCastInstrCost(Opcode, Dst: ShuffleDstTy, Src: NewShuffleDstTy,
2867 CCH: TTI::CastContextHint::None, CostKind);
2868 if (!C0->hasOneUse())
2869 NewCost += CostC0;
2870 if (IsBinaryShuffle) {
2871 InstructionCost CostC1 =
2872 TTI.getCastInstrCost(Opcode: C1->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
2873 CCH: TTI::CastContextHint::None, CostKind);
2874 OldCost += CostC1;
2875 if (!C1->hasOneUse())
2876 NewCost += CostC1;
2877 }
2878
2879 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
2880 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2881 << "\n");
2882 if (NewCost > OldCost)
2883 return false;
2884
2885 Value *Shuf;
2886 if (IsBinaryShuffle)
2887 Shuf = Builder.CreateShuffleVector(V1: C0->getOperand(i_nocapture: 0), V2: C1->getOperand(i_nocapture: 0),
2888 Mask: NewMask);
2889 else
2890 Shuf = Builder.CreateShuffleVector(V: C0->getOperand(i_nocapture: 0), Mask: NewMask);
2891
2892 Value *Cast = Builder.CreateCast(Op: Opcode, V: Shuf, DestTy: ShuffleDstTy);
2893
2894 // Intersect flags from the old casts.
2895 if (auto *NewInst = dyn_cast<Instruction>(Val: Cast)) {
2896 NewInst->copyIRFlags(V: C0);
2897 if (IsBinaryShuffle)
2898 NewInst->andIRFlags(V: C1);
2899 }
2900
2901 Worklist.pushValue(V: Shuf);
2902 replaceValue(Old&: I, New&: *Cast);
2903 return true;
2904}
2905
2906/// Try to convert any of:
2907/// "shuffle (shuffle x, y), (shuffle y, x)"
2908/// "shuffle (shuffle x, undef), (shuffle y, undef)"
2909/// "shuffle (shuffle x, undef), y"
2910/// "shuffle x, (shuffle y, undef)"
2911/// into "shuffle x, y".
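/// Illustrative example (types and masks are arbitrary):
///   %s0 = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 1, i32 4, i32 5>
///   %s1 = shufflevector <4 x i32> %y, <4 x i32> %x, <4 x i32> <i32 2, i32 3, i32 6, i32 7>
///   %r  = shufflevector <4 x i32> %s0, <4 x i32> %s1, <4 x i32> <i32 0, i32 2, i32 4, i32 6>
/// folds to
///   %r  = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 4, i32 6, i32 2>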
2912bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2913 ArrayRef<int> OuterMask;
2914 Value *OuterV0, *OuterV1;
2915 if (!match(V: &I,
2916 P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1), mask: m_Mask(OuterMask))))
2917 return false;
2918
2919 ArrayRef<int> InnerMask0, InnerMask1;
2920 Value *X0, *X1, *Y0, *Y1;
2921 bool Match0 =
2922 match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: X0), v2: m_Value(V&: Y0), mask: m_Mask(InnerMask0)));
2923 bool Match1 =
2924 match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: X1), v2: m_Value(V&: Y1), mask: m_Mask(InnerMask1)));
2925 if (!Match0 && !Match1)
2926 return false;
2927
2928 // If the outer shuffle is a permute, then create a fake inner all-poison
2929 // shuffle. This is easier than accounting for length-changing shuffles below.
2930 SmallVector<int, 16> PoisonMask1;
2931 if (!Match1 && isa<PoisonValue>(Val: OuterV1)) {
2932 X1 = X0;
2933 Y1 = Y0;
2934 PoisonMask1.append(NumInputs: InnerMask0.size(), Elt: PoisonMaskElem);
2935 InnerMask1 = PoisonMask1;
2936 Match1 = true; // fake match
2937 }
2938
2939 X0 = Match0 ? X0 : OuterV0;
2940 Y0 = Match0 ? Y0 : OuterV0;
2941 X1 = Match1 ? X1 : OuterV1;
2942 Y1 = Match1 ? Y1 : OuterV1;
2943 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2944 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(Val: X0->getType());
2945 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(Val: OuterV0->getType());
2946 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2947 X0->getType() != X1->getType())
2948 return false;
2949
2950 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2951 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2952
2953 // Attempt to merge the shuffles, matching up to 2 source operands.
2954 // Replace indices into a poison arg with PoisonMaskElem.
2955 // Bail if either inner mask references an undef (non-poison) arg.
2956 SmallVector<int, 16> NewMask(OuterMask);
2957 Value *NewX = nullptr, *NewY = nullptr;
2958 for (int &M : NewMask) {
2959 Value *Src = nullptr;
2960 if (0 <= M && M < (int)NumImmElts) {
2961 Src = OuterV0;
2962 if (Match0) {
2963 M = InnerMask0[M];
2964 Src = M >= (int)NumSrcElts ? Y0 : X0;
2965 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2966 }
2967 } else if (M >= (int)NumImmElts) {
2968 Src = OuterV1;
2969 M -= NumImmElts;
2970 if (Match1) {
2971 M = InnerMask1[M];
2972 Src = M >= (int)NumSrcElts ? Y1 : X1;
2973 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2974 }
2975 }
2976 if (Src && M != PoisonMaskElem) {
2977 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
2978 if (isa<UndefValue>(Val: Src)) {
2979 // We've referenced an undef element - if it's poison, update the shuffle
2980 // mask, else bail.
2981 if (!isa<PoisonValue>(Val: Src))
2982 return false;
2983 M = PoisonMaskElem;
2984 continue;
2985 }
2986 if (!NewX || NewX == Src) {
2987 NewX = Src;
2988 continue;
2989 }
2990 if (!NewY || NewY == Src) {
2991 M += NumSrcElts;
2992 NewY = Src;
2993 continue;
2994 }
2995 return false;
2996 }
2997 }
2998
2999 if (!NewX) {
 // No lane references a non-poison source - the whole result is poison.
 replaceValue(Old&: I, New&: *PoisonValue::get(T: ShuffleDstTy));
 return true;
 }
3001 if (!NewY)
3002 NewY = PoisonValue::get(T: ShuffleSrcTy);
3003
3004 // Have we folded to an Identity shuffle?
3005 if (ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts)) {
3006 replaceValue(Old&: I, New&: *NewX);
3007 return true;
3008 }
3009
3010 // Try to merge the shuffles if the new shuffle is not costly.
3011 InstructionCost InnerCost0 = 0;
3012 if (Match0)
3013 InnerCost0 = TTI.getInstructionCost(U: cast<User>(Val: OuterV0), CostKind);
3014
3015 InstructionCost InnerCost1 = 0;
3016 if (Match1)
3017 InnerCost1 = TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);
3018
3019 InstructionCost OuterCost = TTI.getInstructionCost(U: &I, CostKind);
3020
3021 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
3022
3023 bool IsUnary = all_of(Range&: NewMask, P: [&](int M) { return M < (int)NumSrcElts; });
3024 TargetTransformInfo::ShuffleKind SK =
3025 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
3026 : TargetTransformInfo::SK_PermuteTwoSrc;
3027 InstructionCost NewCost =
3028 TTI.getShuffleCost(Kind: SK, DstTy: ShuffleDstTy, SrcTy: ShuffleSrcTy, Mask: NewMask, CostKind, Index: 0,
3029 SubTp: nullptr, Args: {NewX, NewY});
3030 if (!OuterV0->hasOneUse())
3031 NewCost += InnerCost0;
3032 if (!OuterV1->hasOneUse())
3033 NewCost += InnerCost1;
3034
3035 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
3036 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3037 << "\n");
3038 if (NewCost > OldCost)
3039 return false;
3040
3041 Value *Shuf = Builder.CreateShuffleVector(V1: NewX, V2: NewY, Mask: NewMask);
3042 replaceValue(Old&: I, New&: *Shuf);
3043 return true;
3044}
3045
3046/// Try to convert a chain of length-preserving shuffles that are fed by
3047/// length-changing shuffles from the same source, e.g. a chain of length 3:
3048///
3049/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
3050/// (shuffle y, undef)),
3051///           (shuffle y, undef)"
3052///
3053/// into a single shuffle fed by a length-changing shuffle:
3054///
3055/// "shuffle x, (shuffle y, undef)"
3056///
3057/// Such chains arise e.g. from folding extract/insert sequences.
3058bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
3059 FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(Val: I.getType());
3060 if (!TrunkType)
3061 return false;
3062
3063 unsigned ChainLength = 0;
3064 SmallVector<int> Mask;
3065 SmallVector<int> YMask;
3066 InstructionCost OldCost = 0;
3067 InstructionCost NewCost = 0;
3068 Value *Trunk = &I;
3069 unsigned NumTrunkElts = TrunkType->getNumElements();
3070 Value *Y = nullptr;
3071
3072 for (;;) {
3073 // Match the current trunk against (commutations of) the pattern
3074 // "shuffle trunk', (shuffle y, undef)"
3075 ArrayRef<int> OuterMask;
3076 Value *OuterV0, *OuterV1;
3077 if (ChainLength != 0 && !Trunk->hasOneUse())
3078 break;
3079 if (!match(V: Trunk, P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1),
3080 mask: m_Mask(OuterMask))))
3081 break;
3082 if (OuterV0->getType() != TrunkType) {
3083 // This shuffle is not length-preserving, so it cannot be part of the
3084 // chain.
3085 break;
3086 }
3087
3088 ArrayRef<int> InnerMask0, InnerMask1;
3089 Value *A0, *A1, *B0, *B1;
3090 bool Match0 =
3091 match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: A0), v2: m_Value(V&: B0), mask: m_Mask(InnerMask0)));
3092 bool Match1 =
3093 match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: A1), v2: m_Value(V&: B1), mask: m_Mask(InnerMask1)));
3094 bool Match0Leaf = Match0 && A0->getType() != I.getType();
3095 bool Match1Leaf = Match1 && A1->getType() != I.getType();
3096 if (Match0Leaf == Match1Leaf) {
3097 // Only handle the case of exactly one leaf in each step. The "two leaves"
3098 // case is handled by foldShuffleOfShuffles.
3099 break;
3100 }
3101
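 // If the leaf ended up as operand 0, canonicalize so that OuterV0 is the
 // trunk and OuterV1 is the leaf (the length-changing shuffle of y),
 // adjusting the outer mask for the swapped operands.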
3102 SmallVector<int> CommutedOuterMask;
3103 if (Match0Leaf) {
3104 std::swap(a&: OuterV0, b&: OuterV1);
3105 std::swap(a&: InnerMask0, b&: InnerMask1);
3106 std::swap(a&: A0, b&: A1);
3107 std::swap(a&: B0, b&: B1);
3108 llvm::append_range(C&: CommutedOuterMask, R&: OuterMask);
3109 for (int &M : CommutedOuterMask) {
3110 if (M == PoisonMaskElem)
3111 continue;
3112 if (M < (int)NumTrunkElts)
3113 M += NumTrunkElts;
3114 else
3115 M -= NumTrunkElts;
3116 }
3117 OuterMask = CommutedOuterMask;
3118 }
3119 if (!OuterV1->hasOneUse())
3120 break;
3121
3122 if (!isa<UndefValue>(Val: A1)) {
3123 if (!Y)
3124 Y = A1;
3125 else if (Y != A1)
3126 break;
3127 }
3128 if (!isa<UndefValue>(Val: B1)) {
3129 if (!Y)
3130 Y = B1;
3131 else if (Y != B1)
3132 break;
3133 }
3134
3135 auto *YType = cast<FixedVectorType>(Val: A1->getType());
3136 int NumLeafElts = YType->getNumElements();
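 // Both operands of the leaf shuffle refer to the same source Y (or undef),
 // so fold its two-source mask down to a single-source mask over Y.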
3137 SmallVector<int> LocalYMask(InnerMask1);
3138 for (int &M : LocalYMask) {
3139 if (M >= NumLeafElts)
3140 M -= NumLeafElts;
3141 }
3142
3143 InstructionCost LocalOldCost =
3144 TTI.getInstructionCost(U: cast<User>(Val: Trunk), CostKind) +
3145 TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);
3146
3147 // Handle the initial (start of chain) case.
3148 if (!ChainLength) {
3149 Mask.assign(AR: OuterMask);
3150 YMask.assign(RHS: LocalYMask);
3151 OldCost = NewCost = LocalOldCost;
3152 Trunk = OuterV0;
3153 ChainLength++;
3154 continue;
3155 }
3156
3157 // For the non-root case, first attempt to combine masks.
3158 SmallVector<int> NewYMask(YMask);
3159 bool Valid = true;
3160 for (auto [CombinedM, LeafM] : llvm::zip(t&: NewYMask, u&: LocalYMask)) {
3161 if (LeafM == -1 || CombinedM == LeafM)
3162 continue;
3163 if (CombinedM == -1) {
3164 CombinedM = LeafM;
3165 } else {
3166 Valid = false;
3167 break;
3168 }
3169 }
3170 if (!Valid)
3171 break;
3172
3173 SmallVector<int> NewMask;
3174 NewMask.reserve(N: NumTrunkElts);
3175 for (int M : Mask) {
3176 if (M < 0 || M >= static_cast<int>(NumTrunkElts))
3177 NewMask.push_back(Elt: M);
3178 else
3179 NewMask.push_back(Elt: OuterMask[M]);
3180 }
3181
3182 // Break the chain if adding this new step complicates the shuffles such
3183 // that it would increase the new cost by more than the old cost of this
3184 // step.
3185 InstructionCost LocalNewCost =
3186 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: TrunkType,
3187 SrcTy: YType, Mask: NewYMask, CostKind) +
3188 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: TrunkType,
3189 SrcTy: TrunkType, Mask: NewMask, CostKind);
3190
3191 if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3192 break;
3193
3194 LLVM_DEBUG({
3195 if (ChainLength == 1) {
3196 dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3197 << I << '\n';
3198 }
3199 dbgs() << " next chain link: " << *Trunk << '\n'
3200 << " old cost: " << (OldCost + LocalOldCost)
3201 << " new cost: " << LocalNewCost << '\n';
3202 });
3203
3204 Mask = NewMask;
3205 YMask = NewYMask;
3206 OldCost += LocalOldCost;
3207 NewCost = LocalNewCost;
3208 Trunk = OuterV0;
3209 ChainLength++;
3210 }
3211 if (ChainLength <= 1)
3212 return false;
3213
3214 if (llvm::all_of(Range&: Mask, P: [&](int M) {
3215 return M < 0 || M >= static_cast<int>(NumTrunkElts);
3216 })) {
3217 // Produce a canonical simplified form if all elements are sourced from Y.
3218 for (int &M : Mask) {
3219 if (M >= static_cast<int>(NumTrunkElts))
3220 M = YMask[M - NumTrunkElts];
3221 }
3222 Value *Root =
3223 Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask);
3224 replaceValue(Old&: I, New&: *Root);
3225 return true;
3226 }
3227
3228 Value *Leaf =
3229 Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask: YMask);
3230 Value *Root = Builder.CreateShuffleVector(V1: Trunk, V2: Leaf, Mask);
3231 replaceValue(Old&: I, New&: *Root);
3232 return true;
3233}
3234
3235/// Try to convert
3236/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
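/// Illustrative example (intrinsic, types and mask are arbitrary):
///   %a = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %x0, <4 x i32> %y0)
///   %b = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %x1, <4 x i32> %y1)
///   %r = shufflevector <4 x i32> %a, <4 x i32> %b, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
/// becomes
///   %x = shufflevector <4 x i32> %x0, <4 x i32> %x1, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %y = shufflevector <4 x i32> %y0, <4 x i32> %y1, <4 x i32> <i32 0, i32 4, i32 1, i32 5>
///   %r = call <4 x i32> @llvm.smax.v4i32(<4 x i32> %x, <4 x i32> %y)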
3237bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
3238 Value *V0, *V1;
3239 ArrayRef<int> OldMask;
3240 if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
3241 return false;
3242
3243 auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
3244 auto *II1 = dyn_cast<IntrinsicInst>(Val: V1);
3245 if (!II0 || !II1)
3246 return false;
3247
3248 Intrinsic::ID IID = II0->getIntrinsicID();
3249 if (IID != II1->getIntrinsicID())
3250 return false;
3251 InstructionCost CostII0 =
3252 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
3253 InstructionCost CostII1 =
3254 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II1), CostKind);
3255
3256 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
3257 auto *II0Ty = dyn_cast<FixedVectorType>(Val: II0->getType());
3258 if (!ShuffleDstTy || !II0Ty)
3259 return false;
3260
3261 if (!isTriviallyVectorizable(ID: IID))
3262 return false;
3263
3264 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
3265 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI) &&
3266 II0->getArgOperand(i: I) != II1->getArgOperand(i: I))
3267 return false;
3268
3269 InstructionCost OldCost =
3270 CostII0 + CostII1 +
3271 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
3272 SrcTy: II0Ty, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {II0, II1}, CxtI: &I);
3273
3274 SmallVector<Type *> NewArgsTy;
3275 InstructionCost NewCost = 0;
3276 SmallDenseSet<std::pair<Value *, Value *>> SeenOperandPairs;
3277 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3278 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3279 NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
3280 } else {
3281 auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
3282 auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
3283 NumElts: ShuffleDstTy->getNumElements());
3284 NewArgsTy.push_back(Elt: ArgTy);
3285 std::pair<Value *, Value *> OperandPair =
3286 std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
3287 if (!SeenOperandPairs.insert(V: OperandPair).second) {
3288 // We've already computed the cost for this operand pair.
3289 continue;
3290 }
3291 NewCost += TTI.getShuffleCost(
3292 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ArgTy, SrcTy: VecTy, Mask: OldMask,
3293 CostKind, Index: 0, SubTp: nullptr, Args: {II0->getArgOperand(i: I), II1->getArgOperand(i: I)});
3294 }
3295 }
3296 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3297
3298 NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);
3299 if (!II0->hasOneUse())
3300 NewCost += CostII0;
3301 if (II1 != II0 && !II1->hasOneUse())
3302 NewCost += CostII1;
3303
3304 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
3305 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3306 << "\n");
3307
3308 if (NewCost > OldCost)
3309 return false;
3310
3311 SmallVector<Value *> NewArgs;
3312 SmallDenseMap<std::pair<Value *, Value *>, Value *> ShuffleCache;
3313 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
3314 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3315 NewArgs.push_back(Elt: II0->getArgOperand(i: I));
3316 } else {
3317 std::pair<Value *, Value *> OperandPair =
3318 std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
3319 auto It = ShuffleCache.find(Val: OperandPair);
3320 if (It != ShuffleCache.end()) {
3321 // Reuse previously created shuffle for this operand pair.
3322 NewArgs.push_back(Elt: It->second);
3323 continue;
3324 }
3325 Value *Shuf = Builder.CreateShuffleVector(V1: II0->getArgOperand(i: I),
3326 V2: II1->getArgOperand(i: I), Mask: OldMask);
3327 ShuffleCache[OperandPair] = Shuf;
3328 NewArgs.push_back(Elt: Shuf);
3329 Worklist.pushValue(V: Shuf);
3330 }
3331 Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);
3332
3333 // Intersect flags from the old intrinsics.
3334 if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic)) {
3335 NewInst->copyIRFlags(V: II0);
3336 NewInst->andIRFlags(V: II1);
3337 }
3338
3339 replaceValue(Old&: I, New&: *NewIntrinsic);
3340 return true;
3341}
3342
3343/// Try to convert
3344/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
3345bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
3346 Value *V0;
3347 ArrayRef<int> Mask;
3348 if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Undef(), mask: m_Mask(Mask))))
3349 return false;
3350
3351 auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
3352 if (!II0)
3353 return false;
3354
3355 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
3356 auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(Val: II0->getType());
3357 if (!ShuffleDstTy || !IntrinsicSrcTy)
3358 return false;
3359
3360 // Validate it's a pure permute; the mask must only reference the first vector.
3361 unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
3362 if (any_of(Range&: Mask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
3363 return false;
3364
3365 Intrinsic::ID IID = II0->getIntrinsicID();
3366 if (!isTriviallyVectorizable(ID: IID))
3367 return false;
3368
3369 // Cost analysis
3370 InstructionCost IntrinsicCost =
3371 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
3372 InstructionCost OldCost =
3373 IntrinsicCost +
3374 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: ShuffleDstTy,
3375 SrcTy: IntrinsicSrcTy, Mask, CostKind, Index: 0, SubTp: nullptr, Args: {V0}, CxtI: &I);
3376
3377 SmallVector<Type *> NewArgsTy;
3378 InstructionCost NewCost = 0;
3379 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3380 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3381 NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
3382 } else {
3383 auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
3384 auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
3385 NumElts: ShuffleDstTy->getNumElements());
3386 NewArgsTy.push_back(Elt: ArgTy);
3387 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
3388 DstTy: ArgTy, SrcTy: VecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
3389 Args: {II0->getArgOperand(i: I)});
3390 }
3391 }
3392 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3393 NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);
3394
3395 // If the intrinsic has multiple uses, we need to account for the cost of
3396 // keeping the original intrinsic around.
3397 if (!II0->hasOneUse())
3398 NewCost += IntrinsicCost;
3399
3400 LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
3401 << OldCost << " vs NewCost: " << NewCost << "\n");
3402
3403 if (NewCost > OldCost)
3404 return false;
3405
3406 // Transform
3407 SmallVector<Value *> NewArgs;
3408 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3409 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3410 NewArgs.push_back(Elt: II0->getArgOperand(i: I));
3411 } else {
3412 Value *Shuf = Builder.CreateShuffleVector(V: II0->getArgOperand(i: I), Mask);
3413 NewArgs.push_back(Elt: Shuf);
3414 Worklist.pushValue(V: Shuf);
3415 }
3416 }
3417
3418 Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);
3419
3420 if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic))
3421 NewInst->copyIRFlags(V: II0);
3422
3423 replaceValue(Old&: I, New&: *NewIntrinsic);
3424 return true;
3425}
3426
3427using InstLane = std::pair<Use *, int>;
3428
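/// Walk up through (possibly nested) shufflevectors from the given use/lane
/// pair to find the original source use and lane feeding that lane, returning
/// {nullptr, PoisonMaskElem} if the lane traces to an undef/poison mask
/// element.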
3429static InstLane lookThroughShuffles(Use *U, int Lane) {
3430 while (auto *SV = dyn_cast<ShuffleVectorInst>(Val: U->get())) {
3431 unsigned NumElts =
3432 cast<FixedVectorType>(Val: SV->getOperand(i_nocapture: 0)->getType())->getNumElements();
3433 int M = SV->getMaskValue(Elt: Lane);
3434 if (M < 0)
3435 return {nullptr, PoisonMaskElem};
3436 if (static_cast<unsigned>(M) < NumElts) {
3437 U = &SV->getOperandUse(i: 0);
3438 Lane = M;
3439 } else {
3440 U = &SV->getOperandUse(i: 1);
3441 Lane = M - NumElts;
3442 }
3443 }
3444 return InstLane{U, Lane};
3445}
3446
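/// For each lane in Item, step to operand Op of that lane's instruction and
/// look through any shuffles there, returning the per-lane (use, lane) pairs
/// that feed the operand.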
3447static SmallVector<InstLane>
3448generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
3449 SmallVector<InstLane> NItem;
3450 for (InstLane IL : Item) {
3451 auto [U, Lane] = IL;
3452 InstLane OpLane =
3453 U ? lookThroughShuffles(U: &cast<Instruction>(Val: U->get())->getOperandUse(i: Op),
3454 Lane)
3455 : InstLane{nullptr, PoisonMaskElem};
3456 NItem.emplace_back(Args&: OpLane);
3457 }
3458 return NItem;
3459}
3460
3461/// Detect concat of multiple values into a vector
3462static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
3463 const TargetTransformInfo &TTI) {
3464 auto *Ty = cast<FixedVectorType>(Val: Item.front().first->get()->getType());
3465 unsigned NumElts = Ty->getNumElements();
3466 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
3467 return false;
3468
3469 // Check that the concat is free, usually meaning that the type will be split
3470 // during legalization.
3471 SmallVector<int, 16> ConcatMask(NumElts * 2);
3472 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
3473 if (TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc,
3474 DstTy: FixedVectorType::get(ElementType: Ty->getScalarType(), NumElts: NumElts * 2),
3475 SrcTy: Ty, Mask: ConcatMask, CostKind) != 0)
3476 return false;
3477
3478 unsigned NumSlices = Item.size() / NumElts;
3479 // Currently we generate a tree of shuffles for the concats, which limits us
3480 // to a power-of-2 number of slices.
3481 if (!isPowerOf2_32(Value: NumSlices))
3482 return false;
3483 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
3484 Use *SliceV = Item[Slice * NumElts].first;
3485 if (!SliceV || SliceV->get()->getType() != Ty)
3486 return false;
3487 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
3488 auto [V, Lane] = Item[Slice * NumElts + Elt];
3489 if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
3490 return false;
3491 }
3492 }
3493 return true;
3494}
3495
3496static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
3497 const SmallPtrSet<Use *, 4> &IdentityLeafs,
3498 const SmallPtrSet<Use *, 4> &SplatLeafs,
3499 const SmallPtrSet<Use *, 4> &ConcatLeafs,
3500 IRBuilderBase &Builder,
3501 const TargetTransformInfo *TTI) {
3502 auto [FrontU, FrontLane] = Item.front();
3503
3504 if (IdentityLeafs.contains(Ptr: FrontU)) {
3505 return FrontU->get();
3506 }
3507 if (SplatLeafs.contains(Ptr: FrontU)) {
3508 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
3509 return Builder.CreateShuffleVector(V: FrontU->get(), Mask);
3510 }
3511 if (ConcatLeafs.contains(Ptr: FrontU)) {
3512 unsigned NumElts =
3513 cast<FixedVectorType>(Val: FrontU->get()->getType())->getNumElements();
3514 SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
3515 for (unsigned S = 0; S < Values.size(); ++S)
3516 Values[S] = Item[S * NumElts].first->get();
3517
3518 while (Values.size() > 1) {
3519 NumElts *= 2;
3520 SmallVector<int, 16> Mask(NumElts, 0);
3521 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
3522 SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
3523 for (unsigned S = 0; S < NewValues.size(); ++S)
3524 NewValues[S] =
3525 Builder.CreateShuffleVector(V1: Values[S * 2], V2: Values[S * 2 + 1], Mask);
3526 Values = NewValues;
3527 }
3528 return Values[0];
3529 }
3530
3531 auto *I = cast<Instruction>(Val: FrontU->get());
3532 auto *II = dyn_cast<IntrinsicInst>(Val: I);
3533 unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
3534 SmallVector<Value *> Ops(NumOps);
3535 for (unsigned Idx = 0; Idx < NumOps; Idx++) {
3536 if (II &&
3537 isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI)) {
3538 Ops[Idx] = II->getOperand(i_nocapture: Idx);
3539 continue;
3540 }
3541 Ops[Idx] = generateNewInstTree(Item: generateInstLaneVectorFromOperand(Item, Op: Idx),
3542 Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
3543 Builder, TTI);
3544 }
3545
3546 SmallVector<Value *, 8> ValueList;
3547 for (const auto &Lane : Item)
3548 if (Lane.first)
3549 ValueList.push_back(Elt: Lane.first->get());
3550
3551 Type *DstTy =
3552 FixedVectorType::get(ElementType: I->getType()->getScalarType(), NumElts: Ty->getNumElements());
3553 if (auto *BI = dyn_cast<BinaryOperator>(Val: I)) {
3554 auto *Value = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)BI->getOpcode(),
3555 LHS: Ops[0], RHS: Ops[1]);
3556 propagateIRFlags(I: Value, VL: ValueList);
3557 return Value;
3558 }
3559 if (auto *CI = dyn_cast<CmpInst>(Val: I)) {
3560 auto *Value = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: Ops[0], RHS: Ops[1]);
3561 propagateIRFlags(I: Value, VL: ValueList);
3562 return Value;
3563 }
3564 if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
3565 auto *Value = Builder.CreateSelect(C: Ops[0], True: Ops[1], False: Ops[2], Name: "", MDFrom: SI);
3566 propagateIRFlags(I: Value, VL: ValueList);
3567 return Value;
3568 }
3569 if (auto *CI = dyn_cast<CastInst>(Val: I)) {
3570 auto *Value = Builder.CreateCast(Op: CI->getOpcode(), V: Ops[0], DestTy: DstTy);
3571 propagateIRFlags(I: Value, VL: ValueList);
3572 return Value;
3573 }
3574 if (II) {
3575 auto *Value = Builder.CreateIntrinsic(RetTy: DstTy, ID: II->getIntrinsicID(), Args: Ops);
3576 propagateIRFlags(I: Value, VL: ValueList);
3577 return Value;
3578 }
3579 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
3580 auto *Value =
3581 Builder.CreateUnOp(Opc: (Instruction::UnaryOps)I->getOpcode(), V: Ops[0]);
3582 propagateIRFlags(I: Value, VL: ValueList);
3583 return Value;
3584}
3585
3586// Starting from a shuffle, look up through operands tracking the shuffled index
3587// of each lane. If we can simplify away the shuffles to identities then
3588// do so.
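// A typical case (illustrative): splitting %x and %y into their even and odd
// lanes with one-use shuffles, adding the corresponding halves, and
// re-interleaving the two sums with the final shuffle is rewritten as a
// single full-width "add %x, %y" with all of the shuffles removed.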
3589bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
3590 auto *Ty = dyn_cast<FixedVectorType>(Val: I.getType());
3591 if (!Ty || I.use_empty())
3592 return false;
3593
3594 SmallVector<InstLane> Start(Ty->getNumElements());
3595 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
3596 Start[M] = lookThroughShuffles(U: &*I.use_begin(), Lane: M);
3597
3598 SmallVector<SmallVector<InstLane>> Worklist;
3599 Worklist.push_back(Elt: Start);
3600 SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
3601 unsigned NumVisited = 0;
3602
3603 while (!Worklist.empty()) {
3604 if (++NumVisited > MaxInstrsToScan)
3605 return false;
3606
3607 SmallVector<InstLane> Item = Worklist.pop_back_val();
3608 auto [FrontU, FrontLane] = Item.front();
3609
3610 // If we found an undef first lane then bail out to keep things simple.
3611 if (!FrontU)
3612 return false;
3613
3614 // Helper to peek through bitcasts to the same value.
3615 auto IsEquiv = [&](Value *X, Value *Y) {
3616 return X->getType() == Y->getType() &&
3617 peekThroughBitcasts(V: X) == peekThroughBitcasts(V: Y);
3618 };
3619
3620 // Look for an identity value.
3621 if (FrontLane == 0 &&
3622 cast<FixedVectorType>(Val: FrontU->get()->getType())->getNumElements() ==
3623 Ty->getNumElements() &&
3624 all_of(Range: drop_begin(RangeOrContainer: enumerate(First&: Item)), P: [IsEquiv, Item](const auto &E) {
3625 Value *FrontV = Item.front().first->get();
3626 return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
3627 E.value().second == (int)E.index());
3628 })) {
3629 IdentityLeafs.insert(Ptr: FrontU);
3630 continue;
3631 }
3632 // Look for constants, for the moment only supporting constant splats.
3633 if (auto *C = dyn_cast<Constant>(Val: FrontU);
3634 C && C->getSplatValue() &&
3635 all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
3636 Value *FrontV = Item.front().first->get();
3637 Use *U = IL.first;
3638 return !U || (isa<Constant>(Val: U->get()) &&
3639 cast<Constant>(Val: U->get())->getSplatValue() ==
3640 cast<Constant>(Val: FrontV)->getSplatValue());
3641 })) {
3642 SplatLeafs.insert(Ptr: FrontU);
3643 continue;
3644 }
3645 // Look for a splat value.
3646 if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
3647 auto [FrontU, FrontLane] = Item.front();
3648 auto [U, Lane] = IL;
3649 return !U || (U->get() == FrontU->get() && Lane == FrontLane);
3650 })) {
3651 SplatLeafs.insert(Ptr: FrontU);
3652 continue;
3653 }
3654
3655 // We need each element to be the same type of value, and check that each
3656 // element has a single use.
3657 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
3658 Value *FrontV = Item.front().first->get();
3659 if (!IL.first)
3660 return true;
3661 Value *V = IL.first->get();
3662 if (auto *I = dyn_cast<Instruction>(Val: V); I && !I->hasOneUser())
3663 return false;
3664 if (V->getValueID() != FrontV->getValueID())
3665 return false;
3666 if (auto *CI = dyn_cast<CmpInst>(Val: V))
3667 if (CI->getPredicate() != cast<CmpInst>(Val: FrontV)->getPredicate())
3668 return false;
3669 if (auto *CI = dyn_cast<CastInst>(Val: V))
3670 if (CI->getSrcTy()->getScalarType() !=
3671 cast<CastInst>(Val: FrontV)->getSrcTy()->getScalarType())
3672 return false;
3673 if (auto *SI = dyn_cast<SelectInst>(Val: V))
3674 if (!isa<VectorType>(Val: SI->getOperand(i_nocapture: 0)->getType()) ||
3675 SI->getOperand(i_nocapture: 0)->getType() !=
3676 cast<SelectInst>(Val: FrontV)->getOperand(i_nocapture: 0)->getType())
3677 return false;
3678 if (isa<CallInst>(Val: V) && !isa<IntrinsicInst>(Val: V))
3679 return false;
3680 auto *II = dyn_cast<IntrinsicInst>(Val: V);
3681 return !II || (isa<IntrinsicInst>(Val: FrontV) &&
3682 II->getIntrinsicID() ==
3683 cast<IntrinsicInst>(Val: FrontV)->getIntrinsicID() &&
3684 !II->hasOperandBundles());
3685 };
3686 if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: CheckLaneIsEquivalentToFirst)) {
3687 // Check the operator is one that we support.
3688 if (isa<BinaryOperator, CmpInst>(Val: FrontU)) {
3689 // We exclude div/rem in case they hit UB from poison lanes.
3690 if (auto *BO = dyn_cast<BinaryOperator>(Val: FrontU);
3691 BO && BO->isIntDivRem())
3692 return false;
3693 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
3694 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 1));
3695 continue;
3696 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
3697 FPToUIInst, SIToFPInst, UIToFPInst>(Val: FrontU)) {
3698 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
3699 continue;
3700 } else if (auto *BitCast = dyn_cast<BitCastInst>(Val: FrontU)) {
3701 // TODO: Handle vector widening/narrowing bitcasts.
3702 auto *DstTy = dyn_cast<FixedVectorType>(Val: BitCast->getDestTy());
3703 auto *SrcTy = dyn_cast<FixedVectorType>(Val: BitCast->getSrcTy());
3704 if (DstTy && SrcTy &&
3705 SrcTy->getNumElements() == DstTy->getNumElements()) {
3706 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
3707 continue;
3708 }
3709 } else if (isa<SelectInst>(Val: FrontU)) {
3710 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
3711 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 1));
3712 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 2));
3713 continue;
3714 } else if (auto *II = dyn_cast<IntrinsicInst>(Val: FrontU);
3715 II && isTriviallyVectorizable(ID: II->getIntrinsicID()) &&
3716 !II->hasOperandBundles()) {
3717 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
3718 if (isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Op,
3719 TTI: &TTI)) {
3720 if (!all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item, Op](InstLane &IL) {
3721 Value *FrontV = Item.front().first->get();
3722 Use *U = IL.first;
3723 return !U || (cast<Instruction>(Val: U->get())->getOperand(i: Op) ==
3724 cast<Instruction>(Val: FrontV)->getOperand(i: Op));
3725 }))
3726 return false;
3727 continue;
3728 }
3729 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op));
3730 }
3731 continue;
3732 }
3733 }
3734
3735 if (isFreeConcat(Item, CostKind, TTI)) {
3736 ConcatLeafs.insert(Ptr: FrontU);
3737 continue;
3738 }
3739
3740 return false;
3741 }
3742
3743 if (NumVisited <= 1)
3744 return false;
3745
3746 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
3747
3748 // If we got this far, we know the shuffles are superfluous and can be
3749 // removed. Scan through again and generate the new tree of instructions.
3750 Builder.SetInsertPoint(&I);
3751 Value *V = generateNewInstTree(Item: Start, Ty, IdentityLeafs, SplatLeafs,
3752 ConcatLeafs, Builder, TTI: &TTI);
3753 replaceValue(Old&: I, New&: *V);
3754 return true;
3755}
3756
3757/// Given a commutative reduction, the order of the input lanes does not alter
3758/// the results. We can use this to remove certain shuffles feeding the
3759/// reduction, removing the need to shuffle at all.
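/// For example, reduce.add(shufflevector(%x, poison, <3,2,1,0>)) adds up the
/// same lanes as reduce.add(%x), so the reversing shuffle can be replaced by
/// the (free) identity mask <0,1,2,3>.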
3760bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
3761 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
3762 if (!II)
3763 return false;
3764 switch (II->getIntrinsicID()) {
3765 case Intrinsic::vector_reduce_add:
3766 case Intrinsic::vector_reduce_mul:
3767 case Intrinsic::vector_reduce_and:
3768 case Intrinsic::vector_reduce_or:
3769 case Intrinsic::vector_reduce_xor:
3770 case Intrinsic::vector_reduce_smin:
3771 case Intrinsic::vector_reduce_smax:
3772 case Intrinsic::vector_reduce_umin:
3773 case Intrinsic::vector_reduce_umax:
3774 break;
3775 default:
3776 return false;
3777 }
3778
3779 // Find all the inputs when looking through operations that do not alter the
3780 // lane order (binops, for example). Currently we look for a single shuffle,
3781 // and can ignore splat values.
3782 std::queue<Value *> Worklist;
3783 SmallPtrSet<Value *, 4> Visited;
3784 ShuffleVectorInst *Shuffle = nullptr;
3785 if (auto *Op = dyn_cast<Instruction>(Val: I.getOperand(i: 0)))
3786 Worklist.push(x: Op);
3787
3788 while (!Worklist.empty()) {
3789 Value *CV = Worklist.front();
3790 Worklist.pop();
3791 if (Visited.contains(Ptr: CV))
3792 continue;
3793
3794 // Splats don't change the order, so can be safely ignored.
3795 if (isSplatValue(V: CV))
3796 continue;
3797
3798 Visited.insert(Ptr: CV);
3799
3800 if (auto *CI = dyn_cast<Instruction>(Val: CV)) {
3801 if (CI->isBinaryOp()) {
3802 for (auto *Op : CI->operand_values())
3803 Worklist.push(x: Op);
3804 continue;
3805 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(Val: CI)) {
3806 if (Shuffle && Shuffle != SV)
3807 return false;
3808 Shuffle = SV;
3809 continue;
3810 }
3811 }
3812
3813 // Anything else is currently an unknown node.
3814 return false;
3815 }
3816
3817 if (!Shuffle)
3818 return false;
3819
3820 // Check all uses of the binary ops and shuffles are also included in the
3821 // lane-invariant operations (Visited should be the list of lanewise
3822 // instructions, including the shuffle that we found).
3823 for (auto *V : Visited)
3824 for (auto *U : V->users())
3825 if (!Visited.contains(Ptr: U) && U != &I)
3826 return false;
3827
3828 FixedVectorType *VecType =
3829 dyn_cast<FixedVectorType>(Val: II->getOperand(i_nocapture: 0)->getType());
3830 if (!VecType)
3831 return false;
3832 FixedVectorType *ShuffleInputType =
3833 dyn_cast<FixedVectorType>(Val: Shuffle->getOperand(i_nocapture: 0)->getType());
3834 if (!ShuffleInputType)
3835 return false;
3836 unsigned NumInputElts = ShuffleInputType->getNumElements();
3837
3838 // Find the mask from sorting the lanes into order. This is most likely to
3839 // become an identity or concat mask. Undef elements are pushed to the end.
3840 SmallVector<int> ConcatMask;
3841 Shuffle->getShuffleMask(Result&: ConcatMask);
3842 sort(C&: ConcatMask, Comp: [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3843 bool UsesSecondVec =
3844 any_of(Range&: ConcatMask, P: [&](int M) { return M >= (int)NumInputElts; });
3845
3846 InstructionCost OldCost = TTI.getShuffleCost(
3847 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3848 SrcTy: ShuffleInputType, Mask: Shuffle->getShuffleMask(), CostKind);
3849 InstructionCost NewCost = TTI.getShuffleCost(
3850 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3851 SrcTy: ShuffleInputType, Mask: ConcatMask, CostKind);
3852
3853 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3854 << "\n");
3855 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3856 << "\n");
3857 bool MadeChanges = false;
3858 if (NewCost < OldCost) {
3859 Builder.SetInsertPoint(Shuffle);
3860 Value *NewShuffle = Builder.CreateShuffleVector(
3861 V1: Shuffle->getOperand(i_nocapture: 0), V2: Shuffle->getOperand(i_nocapture: 1), Mask: ConcatMask);
3862 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3863 replaceValue(Old&: *Shuffle, New&: *NewShuffle);
3864 return true;
3865 }
3866
3867 // See if we can reuse foldSelectShuffle to reduce the shuffle into a nicer
3868 // order, as it can ignore the order of the reduction's input lanes.
3869 MadeChanges |= foldSelectShuffle(I&: *Shuffle, FromReduction: true);
3870 return MadeChanges;
3871}
3872
3873/// For a given chain of patterns of the following form:
3874///
3875/// ```
3876/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
3877///
3878/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
3879/// ty1> %1)
3880/// OR
3881/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
3882///
3883/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3884/// ...
3885/// ...
3886/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
3887/// 3), <n x ty1> %(i - 2)
3888/// OR
3889/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3890///
3891/// %(i) = extractelement <n x ty1> %(i - 1), 0
3892/// ```
3893///
3894/// Where:
3895/// `mask` follows a partition pattern:
3896///
3897/// Ex:
3898/// [n = 8, p = poison]
3899///
3900/// 4 5 6 7 | p p p p
3901/// 2 3 | p p p p p p
3902/// 1 | p p p p p p p
3903///
3904/// For power-of-2 element counts the partition halves evenly at each step,
3905/// but for other sizes the parity of the current half at each step decides
3906/// the size of the next partition (see `ExpectedParityMask` below for how
3907/// this is generalised).
3908///
3909/// Ex:
3910/// [n = 6]
3911///
3912/// 3 4 5 | p p p
3913/// 1 2 | p p p p
3914/// 1 | p p p p p
3915bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3916 // Going bottom-up for the pattern.
3917 std::queue<Value *> InstWorklist;
3918 InstructionCost OrigCost = 0;
3919
3920 // Common instruction operation after each shuffle op.
3921 std::optional<unsigned int> CommonCallOp = std::nullopt;
3922 std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;
3923
3924 bool IsFirstCallOrBinInst = true;
3925 bool ShouldBeCallOrBinInst = true;
3926
3927 // This tracks the operands of the most recently matched call/binary op.
3928 //
3929 // PrevVecV[0] / PrevVecV[1] hold its two operands, with the shuffle
3930 // operand (if any) kept in PrevVecV[1].
3931 SmallVector<Value *, 2> PrevVecV(2, nullptr);
3932
3933 Value *VecOpEE;
3934 if (!match(V: &I, P: m_ExtractElt(Val: m_Value(V&: VecOpEE), Idx: m_Zero())))
3935 return false;
3936
3937 auto *FVT = dyn_cast<FixedVectorType>(Val: VecOpEE->getType());
3938 if (!FVT)
3939 return false;
3940
3941 int64_t VecSize = FVT->getNumElements();
3942 if (VecSize < 2)
3943 return false;
3944
3945 // The number of levels is ceil(log2(n)), since we always partition the
3946 // vector in half at each step of this fold pattern.
3947 unsigned int NumLevels = Log2_64_Ceil(Value: VecSize), VisitedCnt = 0;
3948 int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
3949
3950 // This is how we generalise for all element sizes.
3951 // At each step, if vector size is odd, we need non-poison
3952 // values to cover the dominant half so we don't miss out on any element.
3953 //
3954 // This mask will help us retrieve this as we go from bottom to top:
3955 //
3956 // Mask Set -> N = N * 2 - 1
3957 // Mask Unset -> N = N * 2
3958 for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
3959 Cur = (Cur + 1) / 2, --Mask) {
3960 if (Cur & 1)
3961 ExpectedParityMask |= (1ll << Mask);
3962 }
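 // Worked example: for VecSize = 6 we have NumLevels = 3 and the halving
 // sequence 6 -> 3 -> 2 -> 1; only the level of size 3 is odd, so
 // ExpectedParityMask = 0b010.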
3963
3964 InstWorklist.push(x: VecOpEE);
3965
3966 while (!InstWorklist.empty()) {
3967 Value *CI = InstWorklist.front();
3968 InstWorklist.pop();
3969
3970 if (auto *II = dyn_cast<IntrinsicInst>(Val: CI)) {
3971 if (!ShouldBeCallOrBinInst)
3972 return false;
3973
3974 if (!IsFirstCallOrBinInst && any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
3975 return false;
3976
3977 // For the first found call/bin op, the vector has to come from the
3978 // extract element op.
3979 if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
3980 return false;
3981 IsFirstCallOrBinInst = false;
3982
3983 if (!CommonCallOp)
3984 CommonCallOp = II->getIntrinsicID();
3985 if (II->getIntrinsicID() != *CommonCallOp)
3986 return false;
3987
3988 switch (II->getIntrinsicID()) {
3989 case Intrinsic::umin:
3990 case Intrinsic::umax:
3991 case Intrinsic::smin:
3992 case Intrinsic::smax: {
3993 auto *Op0 = II->getOperand(i_nocapture: 0);
3994 auto *Op1 = II->getOperand(i_nocapture: 1);
3995 PrevVecV[0] = Op0;
3996 PrevVecV[1] = Op1;
3997 break;
3998 }
3999 default:
4000 return false;
4001 }
4002 ShouldBeCallOrBinInst ^= 1;
4003
4004 IntrinsicCostAttributes ICA(
4005 *CommonCallOp, II->getType(),
4006 {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
4007 OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
4008
4009 // The operands can be (shuffle, other) or (other, shuffle), so swap if
4010 // needed to keep the shuffle (if any) in PrevVecV[1] as we walk up.
4011 if (!isa<ShuffleVectorInst>(Val: PrevVecV[1]))
4012 std::swap(a&: PrevVecV[0], b&: PrevVecV[1]);
4013 InstWorklist.push(x: PrevVecV[1]);
4014 InstWorklist.push(x: PrevVecV[0]);
4015 } else if (auto *BinOp = dyn_cast<BinaryOperator>(Val: CI)) {
4016 // Similar logic for bin ops.
4017
4018 if (!ShouldBeCallOrBinInst)
4019 return false;
4020
4021 if (!IsFirstCallOrBinInst && any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
4022 return false;
4023
4024 if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
4025 return false;
4026 IsFirstCallOrBinInst = false;
4027
4028 if (!CommonBinOp)
4029 CommonBinOp = BinOp->getOpcode();
4030
4031 if (BinOp->getOpcode() != *CommonBinOp)
4032 return false;
4033
4034 switch (*CommonBinOp) {
4035 case BinaryOperator::Add:
4036 case BinaryOperator::Mul:
4037 case BinaryOperator::Or:
4038 case BinaryOperator::And:
4039 case BinaryOperator::Xor: {
4040 auto *Op0 = BinOp->getOperand(i_nocapture: 0);
4041 auto *Op1 = BinOp->getOperand(i_nocapture: 1);
4042 PrevVecV[0] = Op0;
4043 PrevVecV[1] = Op1;
4044 break;
4045 }
4046 default:
4047 return false;
4048 }
4049 ShouldBeCallOrBinInst ^= 1;
4050
4051 OrigCost +=
4052 TTI.getArithmeticInstrCost(Opcode: *CommonBinOp, Ty: BinOp->getType(), CostKind);
4053
4054 if (!isa<ShuffleVectorInst>(Val: PrevVecV[1]))
4055 std::swap(a&: PrevVecV[0], b&: PrevVecV[1]);
4056 InstWorklist.push(x: PrevVecV[1]);
4057 InstWorklist.push(x: PrevVecV[0]);
4058 } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(Val: CI)) {
4059 // We shouldn't have any null values in the previous vectors;
4060 // if so, there was a mismatch in the pattern.
4061 if (ShouldBeCallOrBinInst || any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
4062 return false;
4063
4064 if (SVInst != PrevVecV[1])
4065 return false;
4066
4067 ArrayRef<int> CurMask;
4068 if (!match(V: SVInst, P: m_Shuffle(v1: m_Specific(V: PrevVecV[0]), v2: m_Poison(),
4069 mask: m_Mask(CurMask))))
4070 return false;
4071
4072 // Subtract the parity mask when checking the condition.
4073 for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
4074 if (Mask < ShuffleMaskHalf &&
4075 CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
4076 return false;
4077 if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
4078 return false;
4079 }
4080
4081 // Update mask values.
4082 ShuffleMaskHalf *= 2;
4083 ShuffleMaskHalf -= (ExpectedParityMask & 1);
4084 ExpectedParityMask >>= 1;
4085
4086 OrigCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
4087 DstTy: SVInst->getType(), SrcTy: SVInst->getType(),
4088 Mask: CurMask, CostKind);
4089
4090 VisitedCnt += 1;
4091 if (!ExpectedParityMask && VisitedCnt == NumLevels)
4092 break;
4093
4094 ShouldBeCallOrBinInst ^= 1;
4095 } else {
4096 return false;
4097 }
4098 }
4099
4100 // Pattern should end with a shuffle op.
4101 if (ShouldBeCallOrBinInst)
4102 return false;
4103
4104 assert(VecSize != -1 && "Expected Match for Vector Size");
4105
4106 Value *FinalVecV = PrevVecV[0];
4107 if (!FinalVecV)
4108 return false;
4109
4110 auto *FinalVecVTy = cast<FixedVectorType>(Val: FinalVecV->getType());
4111
4112 Intrinsic::ID ReducedOp =
4113 (CommonCallOp ? getMinMaxReductionIntrinsicID(IID: *CommonCallOp)
4114 : getReductionForBinop(Opc: *CommonBinOp));
4115 if (!ReducedOp)
4116 return false;
4117
4118 IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
4119 InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);
4120
4121 if (NewCost >= OrigCost)
4122 return false;
4123
4124 auto *ReducedResult =
4125 Builder.CreateIntrinsic(ID: ReducedOp, Types: {FinalVecV->getType()}, Args: {FinalVecV});
4126 replaceValue(Old&: I, New&: *ReducedResult);
4127
4128 return true;
4129}
4130
4131/// Determine if it's more efficient to fold:
4132/// reduce(trunc(x)) -> trunc(reduce(x)).
4133/// reduce(sext(x)) -> sext(reduce(x)).
4134/// reduce(zext(x)) -> zext(reduce(x)).
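/// Note: for add/mul only the trunc case is safe, since truncating the wide
/// reduction wraps the same way as reducing the truncated elements, whereas
/// zext/sext do not commute with add/mul reductions in general. The bitwise
/// reductions (and/or/xor) operate per bit, so trunc, zext and sext all
/// commute with them.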
4135bool VectorCombine::foldCastFromReductions(Instruction &I) {
4136 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
4137 if (!II)
4138 return false;
4139
4140 bool TruncOnly = false;
4141 Intrinsic::ID IID = II->getIntrinsicID();
4142 switch (IID) {
4143 case Intrinsic::vector_reduce_add:
4144 case Intrinsic::vector_reduce_mul:
4145 TruncOnly = true;
4146 break;
4147 case Intrinsic::vector_reduce_and:
4148 case Intrinsic::vector_reduce_or:
4149 case Intrinsic::vector_reduce_xor:
4150 break;
4151 default:
4152 return false;
4153 }
4154
4155 unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
4156 Value *ReductionSrc = I.getOperand(i: 0);
4157
4158 Value *Src;
4159 if (!match(V: ReductionSrc, P: m_OneUse(SubPattern: m_Trunc(Op: m_Value(V&: Src)))) &&
4160 (TruncOnly || !match(V: ReductionSrc, P: m_OneUse(SubPattern: m_ZExtOrSExt(Op: m_Value(V&: Src))))))
4161 return false;
4162
4163 auto CastOpc =
4164 (Instruction::CastOps)cast<Instruction>(Val: ReductionSrc)->getOpcode();
4165
4166 auto *SrcTy = cast<VectorType>(Val: Src->getType());
4167 auto *ReductionSrcTy = cast<VectorType>(Val: ReductionSrc->getType());
4168 Type *ResultTy = I.getType();
4169
4170 InstructionCost OldCost = TTI.getArithmeticReductionCost(
4171 Opcode: ReductionOpc, Ty: ReductionSrcTy, FMF: std::nullopt, CostKind);
4172 OldCost += TTI.getCastInstrCost(Opcode: CastOpc, Dst: ReductionSrcTy, Src: SrcTy,
4173 CCH: TTI::CastContextHint::None, CostKind,
4174 I: cast<CastInst>(Val: ReductionSrc));
4175 InstructionCost NewCost =
4176 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: SrcTy, FMF: std::nullopt,
4177 CostKind) +
4178 TTI.getCastInstrCost(Opcode: CastOpc, Dst: ResultTy, Src: ReductionSrcTy->getScalarType(),
4179 CCH: TTI::CastContextHint::None, CostKind);
4180
4181 if (OldCost <= NewCost || !NewCost.isValid())
4182 return false;
4183
4184 Value *NewReduction = Builder.CreateIntrinsic(RetTy: SrcTy->getScalarType(),
4185 ID: II->getIntrinsicID(), Args: {Src});
4186 Value *NewCast = Builder.CreateCast(Op: CastOpc, V: NewReduction, DestTy: ResultTy);
4187 replaceValue(Old&: I, New&: *NewCast);
4188 return true;
4189}
4190
4191/// Fold:
4192/// icmp pred (reduce.{add,or,and,umax,umin}(signbit_extract(x))), C
4193/// into:
4194/// icmp sgt/slt (reduce.{or,umax,and,umin}(x)), -1/0
4195///
4196/// Sign-bit reductions produce values with known semantics:
4197/// - reduce.{or,umax}: 0 if no element is negative, 1 if any is
4198/// - reduce.{and,umin}: 1 if all elements are negative, 0 if any isn't
4199/// - reduce.add: count of negative elements (0 to NumElts)
4200///
4201/// We transform to a direct sign check on reduce.{or,umax} or
4202/// reduce.{and,umin} without explicit sign-bit extraction.
4203///
4204/// In spirit, it's similar to foldSignBitCheck in InstCombine.
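///
/// For example (an illustrative <4 x i32> case), the reduce.add form:
///   %s = lshr <4 x i32> %x, <i32 31, i32 31, i32 31, i32 31>
///   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %s)
///   %c = icmp eq i32 %r, 0              ; "no lane is negative"
/// becomes a direct sign check (using or or umax, whichever is cheaper):
///   %r = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %x)
///   %c = icmp sgt i32 %r, -1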
4205bool VectorCombine::foldSignBitReductionCmp(Instruction &I) {
4206 CmpPredicate Pred;
4207 Value *ReduceOp;
4208 const APInt *CmpVal;
4209 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: ReduceOp), R: m_APInt(Res&: CmpVal))))
4210 return false;
4211
4212 auto *II = dyn_cast<IntrinsicInst>(Val: ReduceOp);
4213 if (!II || !II->hasOneUse())
4214 return false;
4215
4216 Intrinsic::ID OrigIID = II->getIntrinsicID();
4217 switch (OrigIID) {
4218 case Intrinsic::vector_reduce_or:
4219 case Intrinsic::vector_reduce_umax:
4220 case Intrinsic::vector_reduce_and:
4221 case Intrinsic::vector_reduce_umin:
4222 case Intrinsic::vector_reduce_add:
4223 break;
4224 default:
4225 return false;
4226 }
4227
4228 Value *ReductionSrc = II->getArgOperand(i: 0);
4229 if (!ReductionSrc->hasOneUse())
4230 return false;
4231
4232 auto *VecTy = dyn_cast<FixedVectorType>(Val: ReductionSrc->getType());
4233 if (!VecTy)
4234 return false;
4235
4236 unsigned BitWidth = VecTy->getScalarSizeInBits();
4237 unsigned NumElts = VecTy->getNumElements();
4238
4239 // For reduce.add, the result is the count of negative elements
4240 // (0 to NumElts). This must fit in the scalar type without overflow.
4241  // Otherwise, reduce.add can wrap and the identity doesn't hold.
4242 if (OrigIID == Intrinsic::vector_reduce_add && !isUIntN(N: BitWidth, x: NumElts))
4243 return false;
4244
4245 // Match sign-bit extraction: shr X, (bitwidth-1)
4246 Value *X;
4247 if (!match(V: ReductionSrc, P: m_Shr(L: m_Value(V&: X), R: m_SpecificInt(V: BitWidth - 1))))
4248 return false;
4249
4250 // MaxVal: 1 for or/and/umax/umin, NumElts for add
4251 APInt MaxVal(CmpVal->getBitWidth(),
4252 OrigIID == Intrinsic::vector_reduce_add ? NumElts : 1);
4253
4254  // In addition to direct comparisons EQ 0, NE 0, EQ 1, NE 1, etc., we support
4255  // inequalities that can be interpreted as either EQ or NE given the
4256  // rather narrow range of possible values of sign-bit reductions.
4257 bool IsEq;
4258 bool TestsHigh;
4259 if (ICmpInst::isEquality(P: Pred)) {
4260 // EQ/NE: comparison must be against 0 or MaxVal
4261 if (!CmpVal->isZero() && *CmpVal != MaxVal)
4262 return false;
4263 IsEq = Pred == ICmpInst::ICMP_EQ;
4264 TestsHigh = *CmpVal == MaxVal;
4265 } else if (ICmpInst::isLT(P: Pred) && *CmpVal == MaxVal) {
4266 // s/ult MaxVal -> ne MaxVal
4267 IsEq = false;
4268 TestsHigh = true;
4269 } else if (ICmpInst::isGT(P: Pred) && *CmpVal == MaxVal - 1) {
4270 // s/ugt MaxVal-1 -> eq MaxVal
4271 IsEq = true;
4272 TestsHigh = true;
4273 } else {
4274 return false;
4275 }
4276
4277 // For this fold we support four types of checks:
4278 //
4279 // 1. All lanes are negative - AllNeg
4280 // 2. All lanes are non-negative - AllNonNeg
4281 // 3. At least one negative lane - AnyNeg
4282 // 4. At least one non-negative lane - AnyNonNeg
4283 //
4284 // For each case, we can generate the following code:
4285 //
4286 // 1. AllNeg - reduce.and/umin(X) < 0
4287 // 2. AllNonNeg - reduce.or/umax(X) > -1
4288 // 3. AnyNeg - reduce.or/umax(X) < 0
4289 // 4. AnyNonNeg - reduce.and/umin(X) > -1
4290 //
4291 // The table below shows the aggregation of all supported cases
4292 // using these four cases.
4293 //
4294 // Reduction | == 0 | != 0 | == MAX | != MAX
4295 // ------------+-----------+-----------+-----------+-----------
4296 // or/umax | AllNonNeg | AnyNeg | AnyNeg | AllNonNeg
4297 // and/umin | AnyNonNeg | AllNeg | AllNeg | AnyNonNeg
4298 // add | AllNonNeg | AnyNeg | AllNeg | AnyNonNeg
4299 //
4300 // NOTE: MAX = 1 for or/and/umax/umin, and the vector size N for add
4301 //
4302 // For easier codegen and check inversion, we use the following encoding:
4303 //
4304 // 1. Bit-3 === requires or/umax (1) or and/umin (0) check
4305 // 2. Bit-2 === checks < 0 (1) or > -1 (0)
4306 // 3. Bit-1 === universal (1) or existential (0) check
4307 //
4308 // AnyNeg = 0b110: uses or/umax, checks negative, any-check
4309 // AllNonNeg = 0b101: uses or/umax, checks non-neg, all-check
4310 // AnyNonNeg = 0b000: uses and/umin, checks non-neg, any-check
4311 // AllNeg = 0b011: uses and/umin, checks negative, all-check
4312 //
4313 // XOR with 0b011 inverts the check (swaps all/any and neg/non-neg).
4314 //
4315 enum CheckKind : unsigned {
4316 AnyNonNeg = 0b000,
4317 AllNeg = 0b011,
4318 AllNonNeg = 0b101,
4319 AnyNeg = 0b110,
4320 };
4321 // Return true if we fold this check into or/umax and false for and/umin
4322 auto RequiresOr = [](CheckKind C) -> bool { return C & 0b100; };
4323 // Return true if we should check if result is negative and false otherwise
4324 auto IsNegativeCheck = [](CheckKind C) -> bool { return C & 0b010; };
4325 // Logically invert the check
4326 auto Invert = [](CheckKind C) { return CheckKind(C ^ 0b011); };
4327
4328 CheckKind Base;
4329 switch (OrigIID) {
4330 case Intrinsic::vector_reduce_or:
4331 case Intrinsic::vector_reduce_umax:
4332 Base = TestsHigh ? AnyNeg : AllNonNeg;
4333 break;
4334 case Intrinsic::vector_reduce_and:
4335 case Intrinsic::vector_reduce_umin:
4336 Base = TestsHigh ? AllNeg : AnyNonNeg;
4337 break;
4338 case Intrinsic::vector_reduce_add:
4339 Base = TestsHigh ? AllNeg : AllNonNeg;
4340 break;
4341 default:
4342 llvm_unreachable("Unexpected intrinsic");
4343 }
4344
4345 CheckKind Check = IsEq ? Base : Invert(Base);
4346
4347 // Calculate old cost: shift + reduction
4348 InstructionCost OldCost =
4349 TTI.getInstructionCost(U: cast<Instruction>(Val: ReductionSrc), CostKind);
4350 OldCost += TTI.getInstructionCost(U: II, CostKind);
4351
4352 auto PickCheaper = [&](Intrinsic::ID Arith, Intrinsic::ID MinMax) {
4353 InstructionCost ArithCost =
4354 TTI.getArithmeticReductionCost(Opcode: getArithmeticReductionInstruction(RdxID: Arith),
4355 Ty: VecTy, FMF: std::nullopt, CostKind);
4356 InstructionCost MinMaxCost =
4357 TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: MinMax), Ty: VecTy,
4358 FMF: FastMathFlags(), CostKind);
4359 return ArithCost <= MinMaxCost ? std::make_pair(x&: Arith, y&: ArithCost)
4360 : std::make_pair(x&: MinMax, y&: MinMaxCost);
4361 };
4362
4363 // Choose output reduction based on encoding's MSB
4364 auto [NewIID, NewCost] = RequiresOr(Check)
4365 ? PickCheaper(Intrinsic::vector_reduce_or,
4366 Intrinsic::vector_reduce_umax)
4367 : PickCheaper(Intrinsic::vector_reduce_and,
4368 Intrinsic::vector_reduce_umin);
4369
4370 LLVM_DEBUG(dbgs() << "Found sign-bit reduction cmp: " << I << "\n OldCost: "
4371 << OldCost << " vs NewCost: " << NewCost << "\n");
4372
4373 if (NewCost > OldCost)
4374 return false;
4375
4376 // Generate comparison based on encoding's neg bit: slt 0 for neg, sgt -1 for
4377 // non-neg
4378 Builder.SetInsertPoint(&I);
4379 Type *ScalarTy = VecTy->getScalarType();
4380 Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: NewIID, Args: {X});
4381 Value *NewCmp = IsNegativeCheck(Check) ? Builder.CreateIsNeg(Arg: NewReduce)
4382 : Builder.CreateIsNotNeg(Arg: NewReduce);
4383 replaceValue(Old&: I, New&: *NewCmp);
4384 return true;
4385}
4386
4387/// vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
4388///
4389/// We can prove it for cases when:
4390///
4391/// 1. OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
4392/// 1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
4393/// 2. f(x) == 0 <=> x == 0
4394///
4395/// From 1 and 2 (or 1' and 2), we can infer that
4396///
4397/// OP f(X_i) == 0 <=> OP X_i == 0.
4398///
4399/// (1)
4400/// OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0
4401/// (2)
4402/// <=> \forall i \in [1, N] X_i == 0
4403/// (1)
4404/// <=> OP(X_i) == 0
4405///
4406/// For some of the OP's and f's, we need to have domain constraints on X
4407/// to ensure properties 1 (or 1') and 2.
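///
/// For instance (illustrative widths), with f = zext and OP = or:
///   %e = zext <8 x i8> %x to <8 x i32>
///   %r = call i32 @llvm.vector.reduce.or.v8i32(<8 x i32> %e)
///   %c = icmp eq i32 %r, 0
/// can be rewritten, subject to the cost checks below, as:
///   %r = call i8 @llvm.vector.reduce.or.v8i8(<8 x i8> %x)
///   %c = icmp eq i8 %r, 0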
4408bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
4409 CmpPredicate Pred;
4410 Value *Op;
4411 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: Op), R: m_Zero())) ||
4412 !ICmpInst::isEquality(P: Pred))
4413 return false;
4414
4415 auto *II = dyn_cast<IntrinsicInst>(Val: Op);
4416 if (!II)
4417 return false;
4418
4419 switch (II->getIntrinsicID()) {
4420 case Intrinsic::vector_reduce_add:
4421 case Intrinsic::vector_reduce_or:
4422 case Intrinsic::vector_reduce_umin:
4423 case Intrinsic::vector_reduce_umax:
4424 case Intrinsic::vector_reduce_smin:
4425 case Intrinsic::vector_reduce_smax:
4426 break;
4427 default:
4428 return false;
4429 }
4430
4431 Value *InnerOp = II->getArgOperand(i: 0);
4432
4433 // TODO: fixed vector type might be too restrictive
4434 if (!II->hasOneUse() || !isa<FixedVectorType>(Val: InnerOp->getType()))
4435 return false;
4436
4437 Value *X = nullptr;
4438
4439 // Check for zero-preserving operations where f(x) = 0 <=> x = 0
4440 //
4441 // 1. f(x) = shl nuw x, y for arbitrary y
4442 // 2. f(x) = mul nuw x, c for defined c != 0
4443 // 3. f(x) = zext x
4444 // 4. f(x) = sext x
4445 // 5. f(x) = neg x
4446 //
4447 if (!(match(V: InnerOp, P: m_NUWShl(L: m_Value(V&: X), R: m_Value())) || // Case 1
4448 match(V: InnerOp, P: m_NUWMul(L: m_Value(V&: X), R: m_NonZeroInt())) || // Case 2
4449 match(V: InnerOp, P: m_ZExt(Op: m_Value(V&: X))) || // Case 3
4450 match(V: InnerOp, P: m_SExt(Op: m_Value(V&: X))) || // Case 4
4451 match(V: InnerOp, P: m_Neg(V: m_Value(V&: X))) // Case 5
4452 ))
4453 return false;
4454
4455 SimplifyQuery S = SQ.getWithInstruction(I: &I);
4456 auto *XTy = cast<FixedVectorType>(Val: X->getType());
4457
4458 // Check for domain constraints for all supported reductions.
4459 //
4460 // a. OR X_i - has property 1 for every X
4461 // b. UMAX X_i - has property 1 for every X
4462 // c. UMIN X_i - has property 1' for every X
4463 // d. SMAX X_i - has property 1 for X >= 0
4464 // e. SMIN X_i - has property 1' for X >= 0
4465 // f. ADD X_i - has property 1 for X >= 0 && ADD X_i doesn't sign wrap
4466 //
4467 // In order for the proof to work, we need 1 (or 1') to be true for both
4468 // OP f(X_i) and OP X_i and that's why below we check constraints twice.
4469 //
4470 // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
4471 // X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
4472 // of known bits, we can't reasonably hold knowledge of "either 0
4473 // or negative".
4474 switch (II->getIntrinsicID()) {
4475 case Intrinsic::vector_reduce_add: {
4476 // We need to check that both X_i and f(X_i) have enough leading
4477 // zeros to not overflow.
4478 KnownBits KnownX = computeKnownBits(V: X, Q: S);
4479 KnownBits KnownFX = computeKnownBits(V: InnerOp, Q: S);
4480 unsigned NumElems = XTy->getNumElements();
4481 // Adding N elements loses at most ceil(log2(N)) leading bits.
4482 unsigned LostBits = Log2_32_Ceil(Value: NumElems);
4483 unsigned LeadingZerosX = KnownX.countMinLeadingZeros();
4484 unsigned LeadingZerosFX = KnownFX.countMinLeadingZeros();
4485 // Need at least one leading zero left after summation to ensure no overflow
4486 if (LeadingZerosX <= LostBits || LeadingZerosFX <= LostBits)
4487 return false;
4488
4489    // We do not explicitly check whether X or f(X) is non-negative because
4490    // that is implied by requiring enough leading zeros in both cases to
4491    // keep the addition from wrapping.
4492 break;
4493 }
4494 case Intrinsic::vector_reduce_smin:
4495 case Intrinsic::vector_reduce_smax:
4496 // Check whether X >= 0 and f(X) >= 0
4497 if (!isKnownNonNegative(V: InnerOp, SQ: S) || !isKnownNonNegative(V: X, SQ: S))
4498 return false;
4499
4500 break;
4501 default:
4502 break;
4503 };
4504
4505 LLVM_DEBUG(dbgs() << "Found a reduction to 0 comparison with removable op: "
4506 << *II << "\n");
4507
4508 // For zext/sext, check if the transform is profitable using cost model.
4509 // For other operations (shl, mul, neg), we're removing an instruction
4510 // while keeping the same reduction type, so it's always profitable.
4511 if (isa<ZExtInst>(Val: InnerOp) || isa<SExtInst>(Val: InnerOp)) {
4512 auto *FXTy = cast<FixedVectorType>(Val: InnerOp->getType());
4513 Intrinsic::ID IID = II->getIntrinsicID();
4514
4515 InstructionCost ExtCost = TTI.getCastInstrCost(
4516 Opcode: cast<CastInst>(Val: InnerOp)->getOpcode(), Dst: FXTy, Src: XTy,
4517 CCH: TTI::CastContextHint::None, CostKind, I: cast<CastInst>(Val: InnerOp));
4518
4519 InstructionCost OldReduceCost, NewReduceCost;
4520 switch (IID) {
4521 case Intrinsic::vector_reduce_add:
4522 case Intrinsic::vector_reduce_or:
4523 OldReduceCost = TTI.getArithmeticReductionCost(
4524 Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: FXTy, FMF: std::nullopt, CostKind);
4525 NewReduceCost = TTI.getArithmeticReductionCost(
4526 Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: XTy, FMF: std::nullopt, CostKind);
4527 break;
4528 case Intrinsic::vector_reduce_umin:
4529 case Intrinsic::vector_reduce_umax:
4530 case Intrinsic::vector_reduce_smin:
4531 case Intrinsic::vector_reduce_smax:
4532 OldReduceCost = TTI.getMinMaxReductionCost(
4533 IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: FXTy, FMF: FastMathFlags(), CostKind);
4534 NewReduceCost = TTI.getMinMaxReductionCost(
4535 IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: XTy, FMF: FastMathFlags(), CostKind);
4536 break;
4537 default:
4538 llvm_unreachable("Unexpected reduction");
4539 }
4540
4541 InstructionCost OldCost = OldReduceCost + ExtCost;
4542 InstructionCost NewCost =
4543 NewReduceCost + (InnerOp->hasOneUse() ? 0 : ExtCost);
4544
4545 LLVM_DEBUG(dbgs() << "Found a removable extension before reduction: "
4546 << *InnerOp << "\n OldCost: " << OldCost
4547 << " vs NewCost: " << NewCost << "\n");
4548
4549    // We consider the transformation to still be potentially beneficial even
4550 // when the costs are the same because we might remove a use from f(X)
4551 // and unlock other optimizations. Equal costs would just mean that we
4552 // didn't make it worse in the worst case.
4553 if (NewCost > OldCost)
4554 return false;
4555 }
4556
4557 // Since we support zext and sext as f, we might change the scalar type
4558 // of the intrinsic.
4559 Type *Ty = XTy->getScalarType();
4560 Value *NewReduce = Builder.CreateIntrinsic(RetTy: Ty, ID: II->getIntrinsicID(), Args: {X});
4561 Value *NewCmp =
4562 Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::getNullValue(Ty));
4563 replaceValue(Old&: I, New&: *NewCmp);
4564 return true;
4565}
4566
4567/// Fold comparisons of reduce.or/reduce.and with reduce.umax/reduce.umin
4568/// based on cost, preserving the comparison semantics.
4569///
4570/// We use three fundamental properties for each pair:
4571///
4572/// 1. or(X) == 0 <=> umax(X) == 0
4573/// 2. or(X) == 1 <=> umax(X) == 1
4574/// 3. sign(or(X)) == sign(umax(X))
4575///
4576/// 1. and(X) == -1 <=> umin(X) == -1
4577/// 2. and(X) == -2 <=> umin(X) == -2
4578/// 3. sign(and(X)) == sign(umin(X))
4579///
4580/// From these we can infer the following transformations:
4581/// a. or(X) ==/!= 0 <-> umax(X) ==/!= 0
4582/// b. or(X) s< 0 <-> umax(X) s< 0
4583/// c. or(X) s> -1 <-> umax(X) s> -1
4584/// d. or(X) s< 1 <-> umax(X) s< 1
4585/// e. or(X) ==/!= 1 <-> umax(X) ==/!= 1
4586/// f. or(X) s< 2 <-> umax(X) s< 2
4587/// g. and(X) ==/!= -1 <-> umin(X) ==/!= -1
4588/// h. and(X) s< 0 <-> umin(X) s< 0
4589/// i. and(X) s> -1 <-> umin(X) s> -1
4590/// j. and(X) s> -2 <-> umin(X) s> -2
4591/// k. and(X) ==/!= -2 <-> umin(X) ==/!= -2
4592/// l. and(X) s> -3 <-> umin(X) s> -3
4593///
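/// For example (illustrative, and only taken if umax is cheaper on the
/// target):
///   %r = call i32 @llvm.vector.reduce.or.v4i32(<4 x i32> %x)
///   %c = icmp eq i32 %r, 0
/// may be rewritten as:
///   %r = call i32 @llvm.vector.reduce.umax.v4i32(<4 x i32> %x)
///   %c = icmp eq i32 %r, 0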
4594bool VectorCombine::foldEquivalentReductionCmp(Instruction &I) {
4595 CmpPredicate Pred;
4596 Value *ReduceOp;
4597 const APInt *CmpVal;
4598 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: ReduceOp), R: m_APInt(Res&: CmpVal))))
4599 return false;
4600
4601 auto *II = dyn_cast<IntrinsicInst>(Val: ReduceOp);
4602 if (!II || !II->hasOneUse())
4603 return false;
4604
4605 const auto IsValidOrUmaxCmp = [&]() {
4606 // Cases a and e
4607 bool IsEquality =
4608 (CmpVal->isZero() || CmpVal->isOne()) && ICmpInst::isEquality(P: Pred);
4609 // Case c
4610 bool IsPositive = CmpVal->isAllOnes() && Pred == ICmpInst::ICMP_SGT;
4611 // Cases b, d, and f
4612 bool IsNegative = (CmpVal->isZero() || CmpVal->isOne() || *CmpVal == 2) &&
4613 Pred == ICmpInst::ICMP_SLT;
4614 return IsEquality || IsPositive || IsNegative;
4615 };
4616
4617 const auto IsValidAndUminCmp = [&]() {
4618 const auto LeadingOnes = CmpVal->countl_one();
4619
4620 // Cases g and k
4621 bool IsEquality =
4622 (CmpVal->isAllOnes() || LeadingOnes + 1 == CmpVal->getBitWidth()) &&
4623 ICmpInst::isEquality(P: Pred);
4624 // Case h
4625 bool IsNegative = CmpVal->isZero() && Pred == ICmpInst::ICMP_SLT;
4626 // Cases i, j, and l
4627 bool IsPositive =
4628        // if the number has at least N - 2 leading ones
4629        // and the two LSBs are:
4630        //   - 11 -> -1
4631        //   - 10 -> -2
4632        //   - 01 -> -3
4633 LeadingOnes + 2 >= CmpVal->getBitWidth() &&
4634 ((*CmpVal)[0] || (*CmpVal)[1]) && Pred == ICmpInst::ICMP_SGT;
4635 return IsEquality || IsNegative || IsPositive;
4636 };
4637
4638 Intrinsic::ID OriginalIID = II->getIntrinsicID();
4639 Intrinsic::ID AlternativeIID;
4640
4641 // Check if this is a valid comparison pattern and determine the alternate
4642 // reduction intrinsic.
4643 switch (OriginalIID) {
4644 case Intrinsic::vector_reduce_or:
4645 if (!IsValidOrUmaxCmp())
4646 return false;
4647 AlternativeIID = Intrinsic::vector_reduce_umax;
4648 break;
4649 case Intrinsic::vector_reduce_umax:
4650 if (!IsValidOrUmaxCmp())
4651 return false;
4652 AlternativeIID = Intrinsic::vector_reduce_or;
4653 break;
4654 case Intrinsic::vector_reduce_and:
4655 if (!IsValidAndUminCmp())
4656 return false;
4657 AlternativeIID = Intrinsic::vector_reduce_umin;
4658 break;
4659 case Intrinsic::vector_reduce_umin:
4660 if (!IsValidAndUminCmp())
4661 return false;
4662 AlternativeIID = Intrinsic::vector_reduce_and;
4663 break;
4664 default:
4665 return false;
4666 }
4667
4668 Value *X = II->getArgOperand(i: 0);
4669 auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
4670 if (!VecTy)
4671 return false;
4672
4673 const auto GetReductionCost = [&](Intrinsic::ID IID) -> InstructionCost {
4674 unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
4675 if (ReductionOpc != Instruction::ICmp)
4676 return TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecTy, FMF: std::nullopt,
4677 CostKind);
4678 return TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: VecTy,
4679 FMF: FastMathFlags(), CostKind);
4680 };
4681
4682 InstructionCost OrigCost = GetReductionCost(OriginalIID);
4683 InstructionCost AltCost = GetReductionCost(AlternativeIID);
4684
4685 LLVM_DEBUG(dbgs() << "Found equivalent reduction cmp: " << I
4686 << "\n OrigCost: " << OrigCost
4687 << " vs AltCost: " << AltCost << "\n");
4688
4689 if (AltCost >= OrigCost)
4690 return false;
4691
4692 Builder.SetInsertPoint(&I);
4693 Type *ScalarTy = VecTy->getScalarType();
4694 Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: AlternativeIID, Args: {X});
4695 Value *NewCmp =
4696 Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::get(Ty: ScalarTy, V: *CmpVal));
4697
4698 replaceValue(Old&: I, New&: *NewCmp);
4699 return true;
4700}
4701
4702/// Returns true if this ShuffleVectorInst eventually feeds into a
4703/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
4704/// chains of shuffles and binary operators (in any combination/order).
4705/// The search is bounded by a fixed limit on the number of visited users.
4706static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI) {
4707 constexpr unsigned MaxVisited = 32;
4708 SmallPtrSet<Instruction *, 8> Visited;
4709 SmallVector<Instruction *, 4> WorkList;
4710 bool FoundReduction = false;
4711
4712 WorkList.push_back(Elt: SVI);
4713 while (!WorkList.empty()) {
4714 Instruction *I = WorkList.pop_back_val();
4715 for (User *U : I->users()) {
4716 auto *UI = cast<Instruction>(Val: U);
4717 if (!UI || !Visited.insert(Ptr: UI).second)
4718 continue;
4719 if (Visited.size() > MaxVisited)
4720 return false;
4721 if (auto *II = dyn_cast<IntrinsicInst>(Val: UI)) {
4722 // More than one reduction reached
4723 if (FoundReduction)
4724 return false;
4725 switch (II->getIntrinsicID()) {
4726 case Intrinsic::vector_reduce_add:
4727 case Intrinsic::vector_reduce_mul:
4728 case Intrinsic::vector_reduce_and:
4729 case Intrinsic::vector_reduce_or:
4730 case Intrinsic::vector_reduce_xor:
4731 case Intrinsic::vector_reduce_smin:
4732 case Intrinsic::vector_reduce_smax:
4733 case Intrinsic::vector_reduce_umin:
4734 case Intrinsic::vector_reduce_umax:
4735 FoundReduction = true;
4736 continue;
4737 default:
4738 return false;
4739 }
4740 }
4741
4742 if (!isa<BinaryOperator>(Val: UI) && !isa<ShuffleVectorInst>(Val: UI))
4743 return false;
4744
4745 WorkList.emplace_back(Args&: UI);
4746 }
4747 }
4748 return FoundReduction;
4749}
4750
4751/// This method looks for groups of shuffles acting on binops, of the form:
4752/// %x = shuffle ...
4753/// %y = shuffle ...
4754/// %a = binop %x, %y
4755/// %b = binop %x, %y
4756/// shuffle %a, %b, selectmask
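/// (Here selectmask picks each output lane from the matching lane of either
/// %a or %b; for 4-element operands a mask such as <i32 0, i32 5, i32 2,
/// i32 7> would be one purely illustrative example.)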
4757/// We may, especially if the shuffle is wider than legal, be able to convert
4758/// the shuffle to a form where only parts of a and b need to be computed. On
4759/// architectures with no obvious "select" shuffle, this can reduce the total
4760/// number of operations if the target reports them as cheaper.
4761bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
4762 auto *SVI = cast<ShuffleVectorInst>(Val: &I);
4763 auto *VT = cast<FixedVectorType>(Val: I.getType());
4764 auto *Op0 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 0));
4765 auto *Op1 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 1));
4766 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
4767 VT != Op0->getType())
4768 return false;
4769
4770 auto *SVI0A = dyn_cast<Instruction>(Val: Op0->getOperand(i: 0));
4771 auto *SVI0B = dyn_cast<Instruction>(Val: Op0->getOperand(i: 1));
4772 auto *SVI1A = dyn_cast<Instruction>(Val: Op1->getOperand(i: 0));
4773 auto *SVI1B = dyn_cast<Instruction>(Val: Op1->getOperand(i: 1));
4774 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
4775 auto checkSVNonOpUses = [&](Instruction *I) {
4776 if (!I || I->getOperand(i: 0)->getType() != VT)
4777 return true;
4778 return any_of(Range: I->users(), P: [&](User *U) {
4779 return U != Op0 && U != Op1 &&
4780 !(isa<ShuffleVectorInst>(Val: U) &&
4781 (InputShuffles.contains(Ptr: cast<Instruction>(Val: U)) ||
4782 isInstructionTriviallyDead(I: cast<Instruction>(Val: U))));
4783 });
4784 };
4785 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
4786 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
4787 return false;
4788
4789 // Collect all the uses that are shuffles that we can transform together. We
4790 // may not have a single shuffle, but a group that can all be transformed
4791 // together profitably.
4792 SmallVector<ShuffleVectorInst *> Shuffles;
4793 auto collectShuffles = [&](Instruction *I) {
4794 for (auto *U : I->users()) {
4795 auto *SV = dyn_cast<ShuffleVectorInst>(Val: U);
4796 if (!SV || SV->getType() != VT)
4797 return false;
4798 if ((SV->getOperand(i_nocapture: 0) != Op0 && SV->getOperand(i_nocapture: 0) != Op1) ||
4799 (SV->getOperand(i_nocapture: 1) != Op0 && SV->getOperand(i_nocapture: 1) != Op1))
4800 return false;
4801 if (!llvm::is_contained(Range&: Shuffles, Element: SV))
4802 Shuffles.push_back(Elt: SV);
4803 }
4804 return true;
4805 };
4806 if (!collectShuffles(Op0) || !collectShuffles(Op1))
4807 return false;
4808 // From a reduction, we need to be processing a single shuffle, otherwise the
4809 // other uses will not be lane-invariant.
4810 if (FromReduction && Shuffles.size() > 1)
4811 return false;
4812
4813  // Add any shuffle uses of the shuffles we have found, to include them in our
4814 // cost calculations.
4815 if (!FromReduction) {
4816 for (ShuffleVectorInst *SV : Shuffles) {
4817 for (auto *U : SV->users()) {
4818 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(Val: U);
4819 if (SSV && isa<UndefValue>(Val: SSV->getOperand(i_nocapture: 1)) && SSV->getType() == VT)
4820 Shuffles.push_back(Elt: SSV);
4821 }
4822 }
4823 }
4824
4825  // For each of the output shuffles, we try to sort all the first vector's
4826  // elements to the beginning, followed by the second vector's elements at the
4827  // end. If the binops are legalized to smaller vectors, this may reduce the
4828  // total number of binops. We compute the ReconstructMask needed to convert
4829  // back to the original lane order.
4830 SmallVector<std::pair<int, int>> V1, V2;
4831 SmallVector<SmallVector<int>> OrigReconstructMasks;
4832 int MaxV1Elt = 0, MaxV2Elt = 0;
4833 unsigned NumElts = VT->getNumElements();
4834 for (ShuffleVectorInst *SVN : Shuffles) {
4835 SmallVector<int> Mask;
4836 SVN->getShuffleMask(Result&: Mask);
4837
4838 // Check the operands are the same as the original, or reversed (in which
4839 // case we need to commute the mask).
4840 Value *SVOp0 = SVN->getOperand(i_nocapture: 0);
4841 Value *SVOp1 = SVN->getOperand(i_nocapture: 1);
4842 if (isa<UndefValue>(Val: SVOp1)) {
4843 auto *SSV = cast<ShuffleVectorInst>(Val: SVOp0);
4844 SVOp0 = SSV->getOperand(i_nocapture: 0);
4845 SVOp1 = SSV->getOperand(i_nocapture: 1);
4846 for (int &Elem : Mask) {
4847 if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
4848 return false;
4849 Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elt: Elem);
4850 }
4851 }
4852 if (SVOp0 == Op1 && SVOp1 == Op0) {
4853 std::swap(a&: SVOp0, b&: SVOp1);
4854 ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumElts);
4855 }
4856 if (SVOp0 != Op0 || SVOp1 != Op1)
4857 return false;
4858
4859 // Calculate the reconstruction mask for this shuffle, as the mask needed to
4860    // take the packed values from Op0/Op1 and reconstruct the original
4861 // order.
4862 SmallVector<int> ReconstructMask;
4863 for (unsigned I = 0; I < Mask.size(); I++) {
4864 if (Mask[I] < 0) {
4865 ReconstructMask.push_back(Elt: -1);
4866 } else if (Mask[I] < static_cast<int>(NumElts)) {
4867 MaxV1Elt = std::max(a: MaxV1Elt, b: Mask[I]);
4868 auto It = find_if(Range&: V1, P: [&](const std::pair<int, int> &A) {
4869 return Mask[I] == A.first;
4870 });
4871 if (It != V1.end())
4872 ReconstructMask.push_back(Elt: It - V1.begin());
4873 else {
4874 ReconstructMask.push_back(Elt: V1.size());
4875 V1.emplace_back(Args&: Mask[I], Args: V1.size());
4876 }
4877 } else {
4878 MaxV2Elt = std::max<int>(a: MaxV2Elt, b: Mask[I] - NumElts);
4879 auto It = find_if(Range&: V2, P: [&](const std::pair<int, int> &A) {
4880 return Mask[I] - static_cast<int>(NumElts) == A.first;
4881 });
4882 if (It != V2.end())
4883 ReconstructMask.push_back(Elt: NumElts + It - V2.begin());
4884 else {
4885 ReconstructMask.push_back(Elt: NumElts + V2.size());
4886 V2.emplace_back(Args: Mask[I] - NumElts, Args: NumElts + V2.size());
4887 }
4888 }
4889 }
4890
4891    // For reductions, we know that the output lane ordering doesn't alter the
4892 // result. In-order can help simplify the shuffle away.
4893 if (FromReduction)
4894 sort(C&: ReconstructMask);
4895 OrigReconstructMasks.push_back(Elt: std::move(ReconstructMask));
4896 }
4897
4898  // If the maximum elements used from V1 and V2 are not larger than the new
4899  // vectors, the vectors are already packed and performing the optimization
4900 // again will likely not help any further. This also prevents us from getting
4901 // stuck in a cycle in case the costs do not also rule it out.
4902 if (V1.empty() || V2.empty() ||
4903 (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
4904 MaxV2Elt == static_cast<int>(V2.size()) - 1))
4905 return false;
4906
4907 // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
4908  // shuffle of another shuffle, or not a shuffle (that is treated like an
4909 // identity shuffle).
4910 auto GetBaseMaskValue = [&](Instruction *I, int M) {
4911 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
4912 if (!SV)
4913 return M;
4914 if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
4915 if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
4916 if (InputShuffles.contains(Ptr: SSV))
4917 return SSV->getMaskValue(Elt: SV->getMaskValue(Elt: M));
4918 return SV->getMaskValue(Elt: M);
4919 };
4920
4921  // Attempt to sort the inputs by ascending mask values to make simpler input
4922 // shuffles and push complex shuffles down to the uses. We sort on the first
4923 // of the two input shuffle orders, to try and get at least one input into a
4924 // nice order.
4925 auto SortBase = [&](Instruction *A, std::pair<int, int> X,
4926 std::pair<int, int> Y) {
4927 int MXA = GetBaseMaskValue(A, X.first);
4928 int MYA = GetBaseMaskValue(A, Y.first);
4929 return MXA < MYA;
4930 };
4931 stable_sort(Range&: V1, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
4932 return SortBase(SVI0A, A, B);
4933 });
4934 stable_sort(Range&: V2, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
4935 return SortBase(SVI1A, A, B);
4936 });
4937 // Calculate our ReconstructMasks from the OrigReconstructMasks and the
4938 // modified order of the input shuffles.
4939 SmallVector<SmallVector<int>> ReconstructMasks;
4940 for (const auto &Mask : OrigReconstructMasks) {
4941 SmallVector<int> ReconstructMask;
4942 for (int M : Mask) {
4943 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
4944 auto It = find_if(Range: V, P: [M](auto A) { return A.second == M; });
4945 assert(It != V.end() && "Expected all entries in Mask");
4946 return std::distance(first: V.begin(), last: It);
4947 };
4948 if (M < 0)
4949 ReconstructMask.push_back(Elt: -1);
4950 else if (M < static_cast<int>(NumElts)) {
4951 ReconstructMask.push_back(Elt: FindIndex(V1, M));
4952 } else {
4953 ReconstructMask.push_back(Elt: NumElts + FindIndex(V2, M));
4954 }
4955 }
4956 ReconstructMasks.push_back(Elt: std::move(ReconstructMask));
4957 }
4958
4959 // Calculate the masks needed for the new input shuffles, which get padded
4960  // with poison.
4961 SmallVector<int> V1A, V1B, V2A, V2B;
4962 for (unsigned I = 0; I < V1.size(); I++) {
4963 V1A.push_back(Elt: GetBaseMaskValue(SVI0A, V1[I].first));
4964 V1B.push_back(Elt: GetBaseMaskValue(SVI0B, V1[I].first));
4965 }
4966 for (unsigned I = 0; I < V2.size(); I++) {
4967 V2A.push_back(Elt: GetBaseMaskValue(SVI1A, V2[I].first));
4968 V2B.push_back(Elt: GetBaseMaskValue(SVI1B, V2[I].first));
4969 }
4970 while (V1A.size() < NumElts) {
4971 V1A.push_back(Elt: PoisonMaskElem);
4972 V1B.push_back(Elt: PoisonMaskElem);
4973 }
4974 while (V2A.size() < NumElts) {
4975 V2A.push_back(Elt: PoisonMaskElem);
4976 V2B.push_back(Elt: PoisonMaskElem);
4977 }
4978
4979 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
4980 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
4981 if (!SV)
4982 return C;
4983 return C + TTI.getShuffleCost(Kind: isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1))
4984 ? TTI::SK_PermuteSingleSrc
4985 : TTI::SK_PermuteTwoSrc,
4986 DstTy: VT, SrcTy: VT, Mask: SV->getShuffleMask(), CostKind);
4987 };
4988 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
4989 return C +
4990 TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
4991 };
4992
4993 unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
4994 unsigned MaxVectorSize =
4995 TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector);
4996 unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
4997 if (MaxElementsInVector == 0)
4998 return false;
4999 // When there are multiple shufflevector operations on the same input,
5000 // especially when the vector length is larger than the register size,
5001 // identical shuffle patterns may occur across different groups of elements.
5002 // To avoid overestimating the cost by counting these repeated shuffles more
5003 // than once, we only account for unique shuffle patterns. This adjustment
5004 // prevents inflated costs in the cost model for wide vectors split into
5005 // several register-sized groups.
5006 std::set<SmallVector<int, 4>> UniqueShuffles;
5007 auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5008 // Compute the cost for performing the shuffle over the full vector.
5009 auto ShuffleCost =
5010 TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
5011 unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
5012 if (NumFullVectors < 2)
5013 return C + ShuffleCost;
5014 SmallVector<int, 4> SubShuffle(MaxElementsInVector);
5015 unsigned NumUniqueGroups = 0;
5016 unsigned NumGroups = Mask.size() / MaxElementsInVector;
5017 // For each group of MaxElementsInVector contiguous elements,
5018 // collect their shuffle pattern and insert into the set of unique patterns.
5019 for (unsigned I = 0; I < NumFullVectors; ++I) {
5020 for (unsigned J = 0; J < MaxElementsInVector; ++J)
5021 SubShuffle[J] = Mask[MaxElementsInVector * I + J];
5022 if (UniqueShuffles.insert(x: SubShuffle).second)
5023 NumUniqueGroups += 1;
5024 }
5025 return C + ShuffleCost * NumUniqueGroups / NumGroups;
5026 };
5027 auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
5028 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5029 if (!SV)
5030 return C;
5031 SmallVector<int, 16> Mask;
5032 SV->getShuffleMask(Result&: Mask);
5033 return AddShuffleMaskAdjustedCost(C, Mask);
5034 };
5035 // Check that input consists of ShuffleVectors applied to the same input
5036 auto AllShufflesHaveSameOperands =
5037 [](SmallPtrSetImpl<Instruction *> &InputShuffles) {
5038 if (InputShuffles.size() < 2)
5039 return false;
5040 ShuffleVectorInst *FirstSV =
5041 dyn_cast<ShuffleVectorInst>(Val: *InputShuffles.begin());
5042 if (!FirstSV)
5043 return false;
5044
5045 Value *In0 = FirstSV->getOperand(i_nocapture: 0), *In1 = FirstSV->getOperand(i_nocapture: 1);
5046 return std::all_of(
5047 first: std::next(x: InputShuffles.begin()), last: InputShuffles.end(),
5048 pred: [&](Instruction *I) {
5049 ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5050 return SV && SV->getOperand(i_nocapture: 0) == In0 && SV->getOperand(i_nocapture: 1) == In1;
5051 });
5052 };
5053
5054 // Get the costs of the shuffles + binops before and after with the new
5055 // shuffle masks.
5056 InstructionCost CostBefore =
5057 TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: VT, CostKind) +
5058 TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: VT, CostKind);
5059 CostBefore += std::accumulate(first: Shuffles.begin(), last: Shuffles.end(),
5060 init: InstructionCost(0), binary_op: AddShuffleCost);
5061 if (AllShufflesHaveSameOperands(InputShuffles)) {
5062 UniqueShuffles.clear();
5063 CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
5064 init: InstructionCost(0), binary_op: AddShuffleAdjustedCost);
5065 } else {
5066 CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
5067 init: InstructionCost(0), binary_op: AddShuffleCost);
5068 }
5069
5070 // The new binops will be unused for lanes past the used shuffle lengths.
5071 // These types attempt to get the correct cost for that from the target.
5072 FixedVectorType *Op0SmallVT =
5073 FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V1.size());
5074 FixedVectorType *Op1SmallVT =
5075 FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V2.size());
5076 InstructionCost CostAfter =
5077 TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: Op0SmallVT, CostKind) +
5078 TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: Op1SmallVT, CostKind);
5079 UniqueShuffles.clear();
5080 CostAfter += std::accumulate(first: ReconstructMasks.begin(), last: ReconstructMasks.end(),
5081 init: InstructionCost(0), binary_op: AddShuffleMaskAdjustedCost);
5082 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
5083 CostAfter +=
5084 std::accumulate(first: OutputShuffleMasks.begin(), last: OutputShuffleMasks.end(),
5085 init: InstructionCost(0), binary_op: AddShuffleMaskCost);
5086
5087 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
5088 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
5089 << " vs CostAfter: " << CostAfter << "\n");
5090 if (CostBefore < CostAfter ||
5091 (CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
5092 return false;
5093
5094 // The cost model has passed, create the new instructions.
5095 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
5096 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5097 if (!SV)
5098 return I;
5099 if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
5100 if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
5101 if (InputShuffles.contains(Ptr: SSV))
5102 return SSV->getOperand(i_nocapture: Op);
5103 return SV->getOperand(i_nocapture: Op);
5104 };
5105 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
5106 Value *NSV0A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0A, 0),
5107 V2: GetShuffleOperand(SVI0A, 1), Mask: V1A);
5108 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
5109 Value *NSV0B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0B, 0),
5110 V2: GetShuffleOperand(SVI0B, 1), Mask: V1B);
5111 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
5112 Value *NSV1A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1A, 0),
5113 V2: GetShuffleOperand(SVI1A, 1), Mask: V2A);
5114 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
5115 Value *NSV1B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1B, 0),
5116 V2: GetShuffleOperand(SVI1B, 1), Mask: V2B);
5117 Builder.SetInsertPoint(Op0);
5118 Value *NOp0 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op0->getOpcode(),
5119 LHS: NSV0A, RHS: NSV0B);
5120 if (auto *I = dyn_cast<Instruction>(Val: NOp0))
5121 I->copyIRFlags(V: Op0, IncludeWrapFlags: true);
5122 Builder.SetInsertPoint(Op1);
5123 Value *NOp1 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op1->getOpcode(),
5124 LHS: NSV1A, RHS: NSV1B);
5125 if (auto *I = dyn_cast<Instruction>(Val: NOp1))
5126 I->copyIRFlags(V: Op1, IncludeWrapFlags: true);
5127
5128 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
5129 Builder.SetInsertPoint(Shuffles[S]);
5130 Value *NSV = Builder.CreateShuffleVector(V1: NOp0, V2: NOp1, Mask: ReconstructMasks[S]);
5131 replaceValue(Old&: *Shuffles[S], New&: *NSV, Erase: false);
5132 }
5133
5134 Worklist.pushValue(V: NSV0A);
5135 Worklist.pushValue(V: NSV0B);
5136 Worklist.pushValue(V: NSV1A);
5137 Worklist.pushValue(V: NSV1B);
5138 return true;
5139}
5140
5141/// Check if the instruction depends on a ZExt that can be moved after the
5142/// instruction. Move the ZExt if it is profitable. For example:
5143/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
5144/// lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
5145/// The cost model takes into account whether zext(x) has other users and
5146/// whether the zext can be propagated through them too.
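///
/// A minimal sketch (vector types chosen for illustration only):
///   %zx = zext <8 x i16> %x to <8 x i32>
///   %r  = and <8 x i32> %zx, %y
/// becomes, when the cost model agrees:
///   %ty = trunc <8 x i32> %y to <8 x i16>
///   %n  = and <8 x i16> %x, %ty
///   %r  = zext <8 x i16> %n to <8 x i32>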
5147bool VectorCombine::shrinkType(Instruction &I) {
5148 Value *ZExted, *OtherOperand;
5149 if (!match(V: &I, P: m_c_BitwiseLogic(L: m_ZExt(Op: m_Value(V&: ZExted)),
5150 R: m_Value(V&: OtherOperand))) &&
5151 !match(V: &I, P: m_LShr(L: m_ZExt(Op: m_Value(V&: ZExted)), R: m_Value(V&: OtherOperand))))
5152 return false;
5153
5154 Value *ZExtOperand = I.getOperand(i: I.getOperand(i: 0) == OtherOperand ? 1 : 0);
5155
5156 auto *BigTy = cast<FixedVectorType>(Val: I.getType());
5157 auto *SmallTy = cast<FixedVectorType>(Val: ZExted->getType());
5158 unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
5159
5160 if (I.getOpcode() == Instruction::LShr) {
5161 // Check that the shift amount is less than the number of bits in the
5162 // smaller type. Otherwise, the smaller lshr will return a poison value.
5163 KnownBits ShAmtKB = computeKnownBits(V: I.getOperand(i: 1), DL: *DL);
5164 if (ShAmtKB.getMaxValue().uge(RHS: BW))
5165 return false;
5166 } else {
5167 // Check that the expression overall uses at most the same number of bits as
5168 // ZExted
5169 KnownBits KB = computeKnownBits(V: &I, DL: *DL);
5170 if (KB.countMaxActiveBits() > BW)
5171 return false;
5172 }
5173
5174  // Calculate the cost of leaving the current IR as it is versus moving the
5175  // ZExt operation later, along with adding truncates if needed.
5176 InstructionCost ZExtCost = TTI.getCastInstrCost(
5177 Opcode: Instruction::ZExt, Dst: BigTy, Src: SmallTy,
5178 CCH: TargetTransformInfo::CastContextHint::None, CostKind);
5179 InstructionCost CurrentCost = ZExtCost;
5180 InstructionCost ShrinkCost = 0;
5181
5182 // Calculate total cost and check that we can propagate through all ZExt users
5183 for (User *U : ZExtOperand->users()) {
5184 auto *UI = cast<Instruction>(Val: U);
5185 if (UI == &I) {
5186 CurrentCost +=
5187 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
5188 ShrinkCost +=
5189 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
5190 ShrinkCost += ZExtCost;
5191 continue;
5192 }
5193
5194 if (!Instruction::isBinaryOp(Opcode: UI->getOpcode()))
5195 return false;
5196
5197 // Check if we can propagate ZExt through its other users
5198 KnownBits KB = computeKnownBits(V: UI, DL: *DL);
5199 if (KB.countMaxActiveBits() > BW)
5200 return false;
5201
5202 CurrentCost += TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
5203 ShrinkCost +=
5204 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
5205 ShrinkCost += ZExtCost;
5206 }
5207
5208  // If the other instruction operand is not a constant, we'll need to
5209  // generate a truncate instruction, so we have to adjust the cost.
5210 if (!isa<Constant>(Val: OtherOperand))
5211 ShrinkCost += TTI.getCastInstrCost(
5212 Opcode: Instruction::Trunc, Dst: SmallTy, Src: BigTy,
5213 CCH: TargetTransformInfo::CastContextHint::None, CostKind);
5214
5215 // If the cost of shrinking types and leaving the IR is the same, we'll lean
5216 // towards modifying the IR because shrinking opens opportunities for other
5217 // shrinking optimisations.
5218 if (ShrinkCost > CurrentCost)
5219 return false;
5220
5221 Builder.SetInsertPoint(&I);
5222 Value *Op0 = ZExted;
5223 Value *Op1 = Builder.CreateTrunc(V: OtherOperand, DestTy: SmallTy);
5224 // Keep the order of operands the same
5225 if (I.getOperand(i: 0) == OtherOperand)
5226 std::swap(a&: Op0, b&: Op1);
5227 Value *NewBinOp =
5228 Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(), LHS: Op0, RHS: Op1);
5229 cast<Instruction>(Val: NewBinOp)->copyIRFlags(V: &I);
5230 cast<Instruction>(Val: NewBinOp)->copyMetadata(SrcInst: I);
5231 Value *NewZExtr = Builder.CreateZExt(V: NewBinOp, DestTy: BigTy);
5232 replaceValue(Old&: I, New&: *NewZExtr);
5233 return true;
5234}
5235
5236/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
5237/// shuffle (DstVec, SrcVec, Mask)
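///
/// For example (an illustrative 4-element case with ExtIdx = 1, InsIdx = 2):
///   %e = extractelement <4 x i32> %src, i64 1
///   %i = insertelement <4 x i32> %dst, i32 %e, i64 2
/// becomes:
///   %i = shufflevector <4 x i32> %dst, <4 x i32> %src,
///                      <4 x i32> <i32 0, i32 1, i32 5, i32 3>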
5238bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
5239 Value *DstVec, *SrcVec;
5240 uint64_t ExtIdx, InsIdx;
5241 if (!match(V: &I,
5242 P: m_InsertElt(Val: m_Value(V&: DstVec),
5243 Elt: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx)),
5244 Idx: m_ConstantInt(V&: InsIdx))))
5245 return false;
5246
5247 auto *DstVecTy = dyn_cast<FixedVectorType>(Val: I.getType());
5248 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
5249  // We can try combining vectors with different numbers of elements.
5250 if (!DstVecTy || !SrcVecTy ||
5251 SrcVecTy->getElementType() != DstVecTy->getElementType())
5252 return false;
5253
5254 unsigned NumDstElts = DstVecTy->getNumElements();
5255 unsigned NumSrcElts = SrcVecTy->getNumElements();
5256 if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
5257 return false;
5258
5259 // Insertion into poison is a cheaper single operand shuffle.
5260 TargetTransformInfo::ShuffleKind SK;
5261 SmallVector<int> Mask(NumDstElts, PoisonMaskElem);
5262
5263 bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
5264 bool NeedDstSrcSwap = isa<PoisonValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec);
5265 if (NeedDstSrcSwap) {
5266 SK = TargetTransformInfo::SK_PermuteSingleSrc;
5267 Mask[InsIdx] = ExtIdx % NumDstElts;
5268 std::swap(a&: DstVec, b&: SrcVec);
5269 } else {
5270 SK = TargetTransformInfo::SK_PermuteTwoSrc;
5271 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
5272 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
5273 }
5274
5275 // Cost
5276 auto *Ins = cast<InsertElementInst>(Val: &I);
5277 auto *Ext = cast<ExtractElementInst>(Val: I.getOperand(i: 1));
5278 InstructionCost InsCost =
5279 TTI.getVectorInstrCost(I: *Ins, Val: DstVecTy, CostKind, Index: InsIdx);
5280 InstructionCost ExtCost =
5281 TTI.getVectorInstrCost(I: *Ext, Val: DstVecTy, CostKind, Index: ExtIdx);
5282 InstructionCost OldCost = ExtCost + InsCost;
5283
5284 InstructionCost NewCost = 0;
5285 SmallVector<int> ExtToVecMask;
5286 if (!NeedExpOrNarrow) {
5287 // Ignore 'free' identity insertion shuffle.
5288 // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
5289 if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
5290 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind, Index: 0,
5291 SubTp: nullptr, Args: {DstVec, SrcVec});
5292 } else {
5293    // When creating a length-changing vector, always try to keep the relevant
5294 // element in an equivalent position, so that bulk shuffles are more likely
5295 // to be useful.
5296 ExtToVecMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
5297 ExtToVecMask[ExtIdx % NumDstElts] = ExtIdx;
5298 // Add cost for expanding or narrowing
5299 NewCost = TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
5300 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: ExtToVecMask, CostKind);
5301 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind);
5302 }
5303
5304 if (!Ext->hasOneUse())
5305 NewCost += ExtCost;
5306
5307  LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
5308 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
5309 << "\n");
5310
5311 if (OldCost < NewCost)
5312 return false;
5313
5314 if (NeedExpOrNarrow) {
5315 if (!NeedDstSrcSwap)
5316 SrcVec = Builder.CreateShuffleVector(V: SrcVec, Mask: ExtToVecMask);
5317 else
5318 DstVec = Builder.CreateShuffleVector(V: DstVec, Mask: ExtToVecMask);
5319 }
5320
5321 // Canonicalize undef param to RHS to help further folds.
5322 if (isa<UndefValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec)) {
5323 ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumDstElts);
5324 std::swap(a&: DstVec, b&: SrcVec);
5325 }
5326
5327 Value *Shuf = Builder.CreateShuffleVector(V1: DstVec, V2: SrcVec, Mask);
5328 replaceValue(Old&: I, New&: *Shuf);
5329
5330 return true;
5331}
5332
5333/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
5334/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
5335/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
5336/// before casting it back into `<vscale x 16 x i32>`.
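///
/// A rough sketch of the rewrite, shown with fixed-width types for brevity
/// (the fold itself handles scalable vectors as well):
///   %v = call <8 x i32> @llvm.vector.interleave2.v8i32(<4 x i32> %a,
///                                                      <4 x i32> %b)
/// where %a and %b are constant splats of 666 and 777, becomes a bitcast of a
/// <4 x i64> constant splat whose elements are (777 << 32) | 666.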
5337bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
5338 const APInt *SplatVal0, *SplatVal1;
5339 if (!match(V: &I, P: m_Intrinsic<Intrinsic::vector_interleave2>(
5340 Op0: m_APInt(Res&: SplatVal0), Op1: m_APInt(Res&: SplatVal1))))
5341 return false;
5342
5343 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
5344 << "\n");
5345
5346 auto *VTy =
5347 cast<VectorType>(Val: cast<IntrinsicInst>(Val&: I).getArgOperand(i: 0)->getType());
5348 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
5349 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
5350
5351  // We use <= rather than < here so that we also bail out when the costs of
5352  // the interleave2 intrinsic and the bitcast are both invalid. Even if they
5353  // both have valid and equal costs, it's probably not a good idea to emit a
5354  // high-cost constant splat.
5355 if (TTI.getInstructionCost(U: &I, CostKind) <=
5356 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: I.getType(), Src: ExtVTy,
5357 CCH: TTI::CastContextHint::None, CostKind)) {
5358 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
5359 << *I.getType() << " is too high.\n");
5360 return false;
5361 }
5362
5363 APInt NewSplatVal = SplatVal1->zext(width: Width * 2);
5364 NewSplatVal <<= Width;
5365 NewSplatVal |= SplatVal0->zext(width: Width * 2);
5366 auto *NewSplat = ConstantVector::getSplat(
5367 EC: ExtVTy->getElementCount(), Elt: ConstantInt::get(Context&: F.getContext(), V: NewSplatVal));
5368
5369 IRBuilder<> Builder(&I);
5370 replaceValue(Old&: I, New&: *Builder.CreateBitCast(V: NewSplat, DestTy: I.getType()));
5371 return true;
5372}
5373
5374// Attempt to shrink loads that are only used by shufflevector instructions.
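// A minimal sketch (illustrative types): if a wide load feeds only shuffles
// that read the low lanes, e.g.
//   %ld = load <8 x i32>, ptr %p
//   %s  = shufflevector <8 x i32> %ld, <8 x i32> poison, <2 x i32> <i32 1, i32 0>
// then, when the cost model agrees, the load can be shrunk to
//   %ld = load <2 x i32>, ptr %p
//   %s  = shufflevector <2 x i32> %ld, <2 x i32> poison, <2 x i32> <i32 1, i32 0>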
5375bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
5376 auto *OldLoad = dyn_cast<LoadInst>(Val: &I);
5377 if (!OldLoad || !OldLoad->isSimple())
5378 return false;
5379
5380 auto *OldLoadTy = dyn_cast<FixedVectorType>(Val: OldLoad->getType());
5381 if (!OldLoadTy)
5382 return false;
5383
5384 unsigned const OldNumElements = OldLoadTy->getNumElements();
5385
5386  // Search all uses of the load. If all uses are shufflevector instructions, and
5387 // the second operands are all poison values, find the minimum and maximum
5388 // indices of the vector elements referenced by all shuffle masks.
5389 // Otherwise return `std::nullopt`.
5390 using IndexRange = std::pair<int, int>;
5391 auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
5392 IndexRange OutputRange = IndexRange(OldNumElements, -1);
5393 for (llvm::Use &Use : I.uses()) {
5394 // Ensure all uses match the required pattern.
5395 User *Shuffle = Use.getUser();
5396 ArrayRef<int> Mask;
5397
5398 if (!match(V: Shuffle,
5399 P: m_Shuffle(v1: m_Specific(V: OldLoad), v2: m_Undef(), mask: m_Mask(Mask))))
5400 return std::nullopt;
5401
5402 // Ignore shufflevector instructions that have no uses.
5403 if (Shuffle->use_empty())
5404 continue;
5405
5406 // Find the min and max indices used by the shufflevector instruction.
5407 for (int Index : Mask) {
5408 if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
5409 OutputRange.first = std::min(a: Index, b: OutputRange.first);
5410 OutputRange.second = std::max(a: Index, b: OutputRange.second);
5411 }
5412 }
5413 }
5414
5415 if (OutputRange.second < OutputRange.first)
5416 return std::nullopt;
5417
5418 return OutputRange;
5419 };
5420
5421 // Get the range of vector elements used by shufflevector instructions.
5422 if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
5423 unsigned const NewNumElements = Indices->second + 1u;
5424
5425 // If the range of vector elements is smaller than the full load, attempt
5426 // to create a smaller load.
5427 if (NewNumElements < OldNumElements) {
5428 IRBuilder Builder(&I);
5429 Builder.SetCurrentDebugLocation(I.getDebugLoc());
5430
5431 // Calculate costs of old and new ops.
5432 Type *ElemTy = OldLoadTy->getElementType();
5433 FixedVectorType *NewLoadTy = FixedVectorType::get(ElementType: ElemTy, NumElts: NewNumElements);
5434 Value *PtrOp = OldLoad->getPointerOperand();
5435
5436 InstructionCost OldCost = TTI.getMemoryOpCost(
5437 Opcode: Instruction::Load, Src: OldLoad->getType(), Alignment: OldLoad->getAlign(),
5438 AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);
5439 InstructionCost NewCost =
5440 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: NewLoadTy, Alignment: OldLoad->getAlign(),
5441 AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);
5442
5443 using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
5444 SmallVector<UseEntry, 4u> NewUses;
5445 unsigned const MaxIndex = NewNumElements * 2u;
5446
5447 for (llvm::Use &Use : I.uses()) {
5448 auto *Shuffle = cast<ShuffleVectorInst>(Val: Use.getUser());
5449 ArrayRef<int> OldMask = Shuffle->getShuffleMask();
5450
5451 // Create entry for new use.
5452 NewUses.push_back(Elt: {Shuffle, OldMask});
5453
5454 // Validate mask indices.
5455 for (int Index : OldMask) {
5456 if (Index >= static_cast<int>(MaxIndex))
5457 return false;
5458 }
5459
5460 // Update costs.
5461 OldCost +=
5462 TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
5463 SrcTy: OldLoadTy, Mask: OldMask, CostKind);
5464 NewCost +=
5465 TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
5466 SrcTy: NewLoadTy, Mask: OldMask, CostKind);
5467 }
5468
5469 LLVM_DEBUG(
5470 dbgs() << "Found a load used only by shufflevector instructions: "
5471 << I << "\n OldCost: " << OldCost
5472 << " vs NewCost: " << NewCost << "\n");
5473
5474 if (OldCost < NewCost || !NewCost.isValid())
5475 return false;
5476
5477 // Create new load of smaller vector.
5478 auto *NewLoad = cast<LoadInst>(
5479 Val: Builder.CreateAlignedLoad(Ty: NewLoadTy, Ptr: PtrOp, Align: OldLoad->getAlign()));
5480 NewLoad->copyMetadata(SrcInst: I);
5481
5482 // Replace all uses.
5483 for (UseEntry &Use : NewUses) {
5484 ShuffleVectorInst *Shuffle = Use.first;
5485 std::vector<int> &NewMask = Use.second;
5486
5487 Builder.SetInsertPoint(Shuffle);
5488 Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
5489 Value *NewShuffle = Builder.CreateShuffleVector(
5490 V1: NewLoad, V2: PoisonValue::get(T: NewLoadTy), Mask: NewMask);
5491
5492 replaceValue(Old&: *Shuffle, New&: *NewShuffle, Erase: false);
5493 }
5494
5495 return true;
5496 }
5497 }
5498 return false;
5499}
5500
5501// Attempt to narrow a phi of shufflevector instructions where the two incoming
5502// values have the same operands but different masks. If the two shuffle masks
5503// are offsets of one another we can use one branch to rotate the incoming
5504// vector and perform one larger shuffle after the phi.
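// A rough sketch (illustrative element counts), with %op : <4 x i32> and
// incoming masks Mask0 = <1,2,3,1,2,3,1,2> and Mask1 = <0,1,2,0,1,2,0,1>
// (element-wise difference is the constant 1): rotate %op by one lane in the
// first predecessor, phi the narrow <4 x i32> values, and apply Mask1 once
// after the phi:
//   bb0: %r = shufflevector <4 x i32> %op, <4 x i32> poison,
//                           <4 x i32> <i32 1, i32 2, i32 3, i32 0>
//   %phi = phi <4 x i32> [ %r, %bb0 ], [ %op, %bb1 ]
//   %res = shufflevector <4 x i32> %phi, <4 x i32> poison,
//                        <8 x i32> <i32 0, i32 1, i32 2, i32 0, i32 1, i32 2, i32 0, i32 1>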
bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
  auto *Phi = dyn_cast<PHINode>(&I);
  if (!Phi || Phi->getNumIncomingValues() != 2u)
    return false;

  Value *Op = nullptr;
  ArrayRef<int> Mask0;
  ArrayRef<int> Mask1;

  if (!match(Phi->getOperand(0u),
             m_OneUse(m_Shuffle(m_Value(Op), m_Poison(), m_Mask(Mask0)))) ||
      !match(Phi->getOperand(1u),
             m_OneUse(m_Shuffle(m_Specific(Op), m_Poison(), m_Mask(Mask1)))))
    return false;

  auto *Shuf = cast<ShuffleVectorInst>(Phi->getOperand(0u));

  // Ensure result vectors are wider than the argument vector.
  auto *InputVT = cast<FixedVectorType>(Op->getType());
  auto *ResultVT = cast<FixedVectorType>(Shuf->getType());
  auto const InputNumElements = InputVT->getNumElements();

  if (InputNumElements >= ResultVT->getNumElements())
    return false;

  // Take the difference of the two shuffle masks at each index. Ignore poison
  // values at the same index in both masks.
  SmallVector<int, 16> NewMask;
  NewMask.reserve(Mask0.size());

  for (auto [M0, M1] : zip(Mask0, Mask1)) {
    if (M0 >= 0 && M1 >= 0)
      NewMask.push_back(M0 - M1);
    else if (M0 == -1 && M1 == -1)
      continue;
    else
      return false;
  }

  // Ensure all elements of the new mask are equal. If the difference between
  // the incoming mask elements is the same, the two must be constant offsets
  // of one another.
  if (NewMask.empty() || !all_equal(NewMask))
    return false;

  // Create new mask using difference of the two incoming masks.
  int MaskOffset = NewMask[0u];
  unsigned Index = (InputNumElements + MaskOffset) % InputNumElements;
  NewMask.clear();

  for (unsigned I = 0u; I < InputNumElements; ++I) {
    NewMask.push_back(Index);
    Index = (Index + 1u) % InputNumElements;
  }
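  // For example (illustrative only): with 4 input elements and a mask offset
  // of 2, Index starts at (4 + 2) % 4 == 2 and the loop above produces the
  // rotation mask <2, 3, 0, 1>.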

  // Calculate costs for worst cases and compare.
  auto const Kind = TTI::SK_PermuteSingleSrc;
  auto OldCost =
      std::max(TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask0, CostKind),
               TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind));
  auto NewCost = TTI.getShuffleCost(Kind, InputVT, InputVT, NewMask, CostKind) +
                 TTI.getShuffleCost(Kind, ResultVT, InputVT, Mask1, CostKind);
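  // OldCost conservatively takes the more expensive of the two original
  // shuffles; NewCost models the rotation emitted on one incoming edge plus
  // the single shuffle applied after the narrowed phi.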

  LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  // Create new shuffles and narrowed phi.
  auto Builder = IRBuilder(Shuf);
  Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
  auto *PoisonVal = PoisonValue::get(InputVT);
  auto *NewShuf0 = Builder.CreateShuffleVector(Op, PoisonVal, NewMask);
  Worklist.push(cast<Instruction>(NewShuf0));

  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
  auto *NewPhi = Builder.CreatePHI(NewShuf0->getType(), 2u);
  NewPhi->addIncoming(NewShuf0, Phi->getIncomingBlock(0u));
  NewPhi->addIncoming(Op, Phi->getIncomingBlock(1u));

  Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
  PoisonVal = PoisonValue::get(NewPhi->getType());
  auto *NewShuf1 = Builder.CreateShuffleVector(NewPhi, PoisonVal, Mask1);

  replaceValue(*Phi, *NewShuf1);
  return true;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");

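  // FoldInst dispatches the visited instruction to the individual folds,
  // keyed on its opcode and result type, and returns true if any fold
  // rewrote it.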
  auto FoldInst = [this](Instruction &I) {
    Builder.SetInsertPoint(&I);
    bool IsVectorType = isa<VectorType>(I.getType());
    bool IsFixedVectorType = isa<FixedVectorType>(I.getType());
    auto Opcode = I.getOpcode();

    LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');

    // These folds should be beneficial regardless of when this pass is run
    // in the optimization pipeline.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        if (vectorizeLoadInsert(I))
          return true;
        break;
      case Instruction::ShuffleVector:
        if (widenSubvectorLoad(I))
          return true;
        break;
      default:
        break;
      }
    }

    // These transforms work with both scalable and fixed vectors.
    // TODO: Identify and allow other scalable transforms.
    if (IsVectorType) {
      if (scalarizeOpOrCmp(I))
        return true;
      if (scalarizeLoad(I))
        return true;
      if (scalarizeExtExtract(I))
        return true;
      if (scalarizeVPIntrinsic(I))
        return true;
      if (foldInterleaveIntrinsics(I))
        return true;
    }

    if (Opcode == Instruction::Store)
      if (foldSingleElementStore(I))
        return true;

    // If this is an early pipeline invocation of this pass, we are done.
    if (TryEarlyFoldsOnly)
      return false;

    // Otherwise, try folds that improve codegen but may interfere with
    // early IR canonicalizations.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        if (foldInsExtFNeg(I))
          return true;
        if (foldInsExtBinop(I))
          return true;
        if (foldInsExtVectorToShuffle(I))
          return true;
        break;
      case Instruction::ShuffleVector:
        if (foldPermuteOfBinops(I))
          return true;
        if (foldShuffleOfBinops(I))
          return true;
        if (foldShuffleOfSelects(I))
          return true;
        if (foldShuffleOfCastops(I))
          return true;
        if (foldShuffleOfShuffles(I))
          return true;
        if (foldPermuteOfIntrinsic(I))
          return true;
        if (foldShufflesOfLengthChangingShuffles(I))
          return true;
        if (foldShuffleOfIntrinsics(I))
          return true;
        if (foldSelectShuffle(I))
          return true;
        if (foldShuffleToIdentity(I))
          return true;
        break;
      case Instruction::Load:
        if (shrinkLoadForShuffles(I))
          return true;
        break;
      case Instruction::BitCast:
        if (foldBitcastShuffle(I))
          return true;
        if (foldSelectsFromBitcast(I))
          return true;
        break;
      case Instruction::And:
      case Instruction::Or:
      case Instruction::Xor:
        if (foldBitOpOfCastops(I))
          return true;
        if (foldBitOpOfCastConstant(I))
          return true;
        break;
      case Instruction::PHI:
        if (shrinkPhiOfShuffles(I))
          return true;
        break;
      default:
        if (shrinkType(I))
          return true;
        break;
      }
    } else {
      switch (Opcode) {
      case Instruction::Call:
        if (foldShuffleFromReductions(I))
          return true;
        if (foldCastFromReductions(I))
          return true;
        break;
      case Instruction::ExtractElement:
        if (foldShuffleChainsToReduce(I))
          return true;
        break;
      case Instruction::ICmp:
        if (foldSignBitReductionCmp(I))
          return true;
        if (foldICmpEqZeroVectorReduce(I))
          return true;
        if (foldEquivalentReductionCmp(I))
          return true;
        [[fallthrough]];
      case Instruction::FCmp:
        if (foldExtractExtract(I))
          return true;
        break;
      case Instruction::Or:
        if (foldConcatOfBoolMasks(I))
          return true;
        [[fallthrough]];
      default:
        if (Instruction::isBinaryOp(Opcode)) {
          if (foldExtractExtract(I))
            return true;
          if (foldExtractedCmps(I))
            return true;
          if (foldBinopOfReductions(I))
            return true;
        }
        break;
      }
    }
    return false;
  };

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    // We need to erase instructions while iterating, but make_early_inc_range
    // is not applicable here: the pre-fetched next iterator may be invalidated
    // by RecursivelyDeleteTriviallyDeadInstructions. Instead, manually track
    // the next instruction (NextInst) and update it when it is about to be
    // deleted.
    Instruction *I = &BB.front();
    while (I) {
      NextInst = I->getNextNode();
      if (!I->isDebugOrPseudoInst())
        MadeChange |= FoldInst(*I);
      I = NextInst;
    }
  }

  NextInst = nullptr;

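  // Revisit instructions that the folds above queued on the worklist, erasing
  // anything that has become trivially dead along the way.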
  while (!Worklist.isEmpty()) {
    Instruction *I = Worklist.removeOne();
    if (!I)
      continue;

    if (isInstructionTriviallyDead(I)) {
      eraseInstruction(*I);
      continue;
    }

    MadeChange |= FoldInst(*I);
  }

  return MadeChange;
}

PreservedAnalyses VectorCombinePass::run(Function &F,
                                         FunctionAnalysisManager &FAM) {
  auto &AC = FAM.getResult<AssumptionAnalysis>(F);
  TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F);
  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
  AAResults &AA = FAM.getResult<AAManager>(F);
  const DataLayout *DL = &F.getDataLayout();
  VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
                         TryEarlyFoldsOnly);
  if (!Combiner.run())
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}