1//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass optimizes scalar/vector interactions using target cost models. The
10// transforms implemented here may not fit in traditional loop-based or SLP
11// vectorization passes.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/Vectorize/VectorCombine.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/Statistic.h"
21#include "llvm/Analysis/AssumptionCache.h"
22#include "llvm/Analysis/BasicAliasAnalysis.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/InstSimplifyFolder.h"
25#include "llvm/Analysis/Loads.h"
26#include "llvm/Analysis/TargetFolder.h"
27#include "llvm/Analysis/TargetTransformInfo.h"
28#include "llvm/Analysis/ValueTracking.h"
29#include "llvm/Analysis/VectorUtils.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Instructions.h"
34#include "llvm/IR/PatternMatch.h"
35#include "llvm/Support/CommandLine.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Transforms/Utils/Local.h"
38#include "llvm/Transforms/Utils/LoopUtils.h"
39#include <numeric>
40#include <optional>
41#include <queue>
42#include <set>
43
44#define DEBUG_TYPE "vector-combine"
45#include "llvm/Transforms/Utils/InstructionWorklist.h"
46
47using namespace llvm;
48using namespace llvm::PatternMatch;
49
// Pass statistics, reported with -stats.
STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");
STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");

// Debugging/testing knobs: disable the whole pass or a single fold family.
static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

// Bounds the backward scan done while proving memory safety of scalarized
// loads/stores, to keep compile time in check.
static cl::opt<unsigned> MaxInstrsToScan(
    "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
    cl::desc("Max number of instructions to scan for vector combining."));

// Sentinel meaning "no (preferred) element index".
static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
72
73namespace {
/// Driver for all scalar/vector interaction folds. One instance is created
/// per function; \c run() iterates the function and applies the fold methods
/// below, using TTI cost comparisons to decide profitability.
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
                const DataLayout *DL, TTI::TargetCostKind CostKind,
                bool TryEarlyFoldsOnly)
      : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
        DT(DT), AA(AA), AC(AC), DL(DL), CostKind(CostKind), SQ(*DL),
        TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}

  /// Apply all folds over the function. Returns true if the IR changed.
  bool run();

private:
  Function &F;
  // Builder uses InstSimplifyFolder so newly created instructions are
  // simplified on the fly rather than inserted in redundant form.
  IRBuilder<InstSimplifyFolder> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;
  AAResults &AA;
  AssumptionCache &AC;
  const DataLayout *DL;
  // Cost kind used for every TTI query made by the folds.
  TTI::TargetCostKind CostKind;
  const SimplifyQuery SQ;

  /// If true, only perform beneficial early IR transforms. Do not introduce new
  /// vector operations.
  bool TryEarlyFoldsOnly;

  InstructionWorklist Worklist;

  /// Next instruction to iterate. It will be updated when it is erased by
  /// RecursivelyDeleteTriviallyDeadInstructions.
  Instruction *NextInst;

  // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
  // parameter. That should be updated to specific sub-classes because the
  // run loop was changed to dispatch on opcode.
  bool vectorizeLoadInsert(Instruction &I);
  bool widenSubvectorLoad(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             const Instruction &I,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
  Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldInsExtFNeg(Instruction &I);
  bool foldInsExtBinop(Instruction &I);
  bool foldInsExtVectorToShuffle(Instruction &I);
  bool foldBitOpOfCastops(Instruction &I);
  bool foldBitOpOfCastConstant(Instruction &I);
  bool foldBitcastShuffle(Instruction &I);
  bool scalarizeOpOrCmp(Instruction &I);
  bool scalarizeVPIntrinsic(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
  bool foldSelectsFromBitcast(Instruction &I);
  bool foldBinopOfReductions(Instruction &I);
  bool foldSingleElementStore(Instruction &I);
  bool scalarizeLoad(Instruction &I);
  bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
  bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
  bool scalarizeExtExtract(Instruction &I);
  bool foldConcatOfBoolMasks(Instruction &I);
  bool foldPermuteOfBinops(Instruction &I);
  bool foldShuffleOfBinops(Instruction &I);
  bool foldShuffleOfSelects(Instruction &I);
  bool foldShuffleOfCastops(Instruction &I);
  bool foldShuffleOfShuffles(Instruction &I);
  bool foldPermuteOfIntrinsic(Instruction &I);
  bool foldShufflesOfLengthChangingShuffles(Instruction &I);
  bool foldShuffleOfIntrinsics(Instruction &I);
  bool foldShuffleToIdentity(Instruction &I);
  bool foldShuffleFromReductions(Instruction &I);
  bool foldShuffleChainsToReduce(Instruction &I);
  bool foldCastFromReductions(Instruction &I);
  bool foldSignBitReductionCmp(Instruction &I);
  bool foldICmpEqZeroVectorReduce(Instruction &I);
  bool foldEquivalentReductionCmp(Instruction &I);
  bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
  bool foldInterleaveIntrinsics(Instruction &I);
  bool shrinkType(Instruction &I);
  bool shrinkLoadForShuffles(Instruction &I);
  bool shrinkPhiOfShuffles(Instruction &I);

  /// Replace all uses of \p Old with \p New, transfer the name, and queue the
  /// new instruction and its users for revisiting. If \p Erase is set and
  /// \p Old became trivially dead, delete it immediately; otherwise push it
  /// back onto the worklist.
  void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
    LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
    LLVM_DEBUG(dbgs() << " With: " << New << '\n');
    Old.replaceAllUsesWith(&New);
    if (auto *NewI = dyn_cast<Instruction>(&New)) {
      New.takeName(&Old);
      Worklist.pushUsersToWorkList(*NewI);
      Worklist.pushValue(NewI);
    }
    if (Erase && isInstructionTriviallyDead(&Old)) {
      eraseInstruction(Old);
    } else {
      Worklist.push(&Old);
    }
  }

  /// Remove \p I from the worklist and from the IR, then recursively delete
  /// any of its operands that became trivially dead as a result.
  void eraseInstruction(Instruction &I) {
    LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
    // Capture the operands before erasing I so they can be revisited below.
    SmallVector<Value *> Ops(I.operands());
    Worklist.remove(&I);
    I.eraseFromParent();

    // Push remaining users of the operands and then the operand itself - allows
    // further folds that were hindered by OneUse limits.
    SmallPtrSet<Value *, 4> Visited;
    for (Value *Op : Ops) {
      if (!Visited.contains(Op)) {
        if (auto *OpI = dyn_cast<Instruction>(Op)) {
          if (RecursivelyDeleteTriviallyDeadInstructions(
                  OpI, nullptr, nullptr, [&](Value *V) {
                    if (auto *I = dyn_cast<Instruction>(V)) {
                      LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
                      Worklist.remove(I);
                      // Keep the iteration cursor valid if the instruction it
                      // points at is the one being deleted.
                      if (I == NextInst)
                        NextInst = NextInst->getNextNode();
                      Visited.insert(I);
                    }
                  }))
            continue;
          Worklist.pushUsersToWorkList(*OpI);
          Worklist.pushValue(OpI);
        }
      }
    }
  }
};
206} // namespace
207
208/// Return the source operand of a potentially bitcasted value. If there is no
209/// bitcast, return the input value itself.
210static Value *peekThroughBitcasts(Value *V) {
211 while (auto *BitCast = dyn_cast<BitCastInst>(Val: V))
212 V = BitCast->getOperand(i_nocapture: 0);
213 return V;
214}
215
216static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
217 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
218 // The widened load may load data from dirty regions or create data races
219 // non-existent in the source.
220 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
221 Load->getFunction()->hasFnAttribute(Kind: Attribute::SanitizeMemTag) ||
222 mustSuppressSpeculation(LI: *Load))
223 return false;
224
225 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
226 // sure we have all of our type-based constraints in place for this target.
227 Type *ScalarTy = Load->getType()->getScalarType();
228 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
229 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
230 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
231 ScalarSize % 8 != 0)
232 return false;
233
234 return true;
235}
236
/// Try to replace "insertelement poison, (load scalar), 0" with a direct
/// vector load plus (optionally) a shuffle, when the wider load is proven
/// dereferenceable and the TTI cost model does not object.
bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  Value *Scalar;
  if (!match(&I,
             m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  auto *Load = dyn_cast<LoadInst>(X);
  if (!canWidenLoad(Load, TTI))
    return false;

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                   &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it safe to load from a base address with
    // updated alignment. If it is, we can shuffle the element(s) into place
    // after loading.
    unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element to shuffle cleanly
    // in the element's size.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                     &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be negated
    // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
    // negation does not change the result of the alignment calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
  // Only element 0 of the wide vector is demanded by the original pattern.
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost +=
      TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                   /* Insert */ true, HasExtract, CostKind);

  // New pattern: load VecPtr
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
  // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask set everything but element 0 to undef to prevent poison from
  // propagating from the extra loaded memory. This will also optionally
  // shrink/grow the vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen if there was no offset.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  auto *Ty = cast<FixedVectorType>(I.getType());
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
  assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
  Mask[0] = OffsetEltIndex;
  if (OffsetEltIndex)
    NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, MinVecTy, Mask,
                                  CostKind);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}
354
355/// If we are loading a vector and then inserting it into a larger vector with
356/// undefined elements, try to load the larger vector and eliminate the insert.
357/// This removes a shuffle in IR and may allow combining of other loaded values.
/// If we are loading a vector and then inserting it into a larger vector with
/// undefined elements, try to load the larger vector and eliminate the insert.
/// This removes a shuffle in IR and may allow combining of other loaded values.
bool VectorCombine::widenSubvectorLoad(Instruction &I) {
  // Match subvector insert of fixed vector.
  auto *Shuf = cast<ShuffleVectorInst>(&I);
  if (!Shuf->isIdentityWithPadding())
    return false;

  // Allow a non-canonical shuffle mask that is choosing elements from op1.
  unsigned NumOpElts =
      cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
  // any_of returns a bool that doubles as the operand index: 0 if all mask
  // elements pick from op0, 1 if any element picks from op1.
  unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
    return M >= (int)(NumOpElts);
  });

  auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
  if (!canWidenLoad(Load, TTI))
    return false;

  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  auto *Ty = cast<FixedVectorType>(I.getType());
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
    return false;

  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();

  // Original pattern: insert_subvector (load PtrOp)
  // This conservatively assumes that the cost of a subvector insert into an
  // undef value is 0. We could add that cost if the cost model accurately
  // reflects the real cost of that operation.
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);

  // New pattern: load PtrOp
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}
413
414/// Determine which, if any, of the inputs should be replaced by a shuffle
415/// followed by extract from a different index.
/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
/// Returns nullptr when no shuffle is needed (identical indexes, or both
/// extract costs are invalid); otherwise returns the extract to convert.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
  auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
  assert(Index0C && Index1C && "Expected constant extract indexes");

  unsigned Index0 = Index0C->getZExtValue();
  unsigned Index1 = Index1C->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  InstructionCost Cost0 =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
  InstructionCost Cost1 =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);

  // If both costs are invalid no shuffle is needed
  if (!Cost0.isValid() && !Cost1.isValid())
    return nullptr;

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}
459
460/// Compare the relative costs of 2 extracts followed by scalar operation vs.
461/// vector operation(s) followed by extract. Return true if the existing
462/// instructions are cheaper than a vector alternative. Otherwise, return false
463/// and if one of the extracts should be transformed to a shufflevector, set
464/// \p ConvertToShuffle to that extract instruction.
/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          const Instruction &I,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
  auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
  assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");

  unsigned Opcode = I.getOpcode();
  Value *Ext0Src = Ext0->getVectorOperand();
  Value *Ext1Src = Ext1->getVectorOperand();
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0Src->getType());
  InstructionCost ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  // Only binops and compares are handled here (checked by the assert below).
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = Ext0IndexC->getZExtValue();
  unsigned Ext1Index = Ext1IndexC->getZExtValue();

  InstructionCost Extract0Cost =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
  InstructionCost Extract1Cost =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  //       check the cost of creating a broadcast shuffle and shuffling both
  //       operands to element 0.
  unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
  unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
  InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  InstructionCost OldCost, NewCost;
  if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // poison except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { poison, poison, 0, poison }
    // TODO: The cost model has an option for a "broadcast" shuffle
    //       (splat-from-element-0), but no option for a more general splat.
    if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
      SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
                                   PoisonMaskElem);
      ShuffleMask[BestInsIndex] = BestExtIndex;
      NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                    VecTy, VecTy, ShuffleMask, CostKind, 0,
                                    nullptr, {ConvertToShuffle});
    } else {
      // Scalable vectors cannot express an explicit mask; query with an
      // empty mask instead.
      NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                    VecTy, VecTy, {}, CostKind, 0, nullptr,
                                    {ConvertToShuffle});
    }
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}
569
570/// Create a shuffle that translates (shifts) 1 element from the input vector
571/// to a new element location.
572static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
573 unsigned NewIndex, IRBuilderBase &Builder) {
574 // The shuffle mask is poison except for 1 lane that is being translated
575 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
576 // ShufMask = { 2, poison, poison, poison }
577 auto *VecTy = cast<FixedVectorType>(Val: Vec->getType());
578 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
579 ShufMask[NewIndex] = OldIndex;
580 return Builder.CreateShuffleVector(V: Vec, Mask: ShufMask, Name: "shift");
581}
582
583/// Given an extract element instruction with constant index operand, shuffle
584/// the source vector (shift the scalar element) to a NewIndex for extraction.
585/// Return null if the input can be constant folded, so that we are not creating
586/// unnecessary instructions.
587static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
588 IRBuilderBase &Builder) {
589 // Shufflevectors can only be created for fixed-width vectors.
590 Value *X = ExtElt->getVectorOperand();
591 if (!isa<FixedVectorType>(Val: X->getType()))
592 return nullptr;
593
594 // If the extract can be constant-folded, this code is unsimplified. Defer
595 // to other passes to handle that.
596 Value *C = ExtElt->getIndexOperand();
597 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
598 if (isa<Constant>(Val: X))
599 return nullptr;
600
601 Value *Shuf = createShiftShuffle(Vec: X, OldIndex: cast<ConstantInt>(Val: C)->getZExtValue(),
602 NewIndex, Builder);
603 return Shuf;
604}
605
606/// Try to reduce extract element costs by converting scalar compares to vector
607/// compares followed by extract.
608/// cmp (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
609Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
610 Instruction &I) {
611 assert(isa<CmpInst>(&I) && "Expected a compare");
612
613 // cmp Pred (extelt V0, ExtIndex), (extelt V1, ExtIndex)
614 // --> extelt (cmp Pred V0, V1), ExtIndex
615 ++NumVecCmp;
616 CmpInst::Predicate Pred = cast<CmpInst>(Val: &I)->getPredicate();
617 Value *VecCmp = Builder.CreateCmp(Pred, LHS: V0, RHS: V1);
618 return Builder.CreateExtractElement(Vec: VecCmp, Idx: ExtIndex, Name: "foldExtExtCmp");
619}
620
621/// Try to reduce extract element costs by converting scalar binops to vector
622/// binops followed by extract.
623/// bo (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
624Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
625 Instruction &I) {
626 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
627
628 // bo (extelt V0, ExtIndex), (extelt V1, ExtIndex)
629 // --> extelt (bo V0, V1), ExtIndex
630 ++NumVecBO;
631 Value *VecBO = Builder.CreateBinOp(Opc: cast<BinaryOperator>(Val: &I)->getOpcode(), LHS: V0,
632 RHS: V1, Name: "foldExtExtBinop");
633
634 // All IR flags are safe to back-propagate because any potential poison
635 // created in unused vector elements is discarded by the extract.
636 if (auto *VecBOInst = dyn_cast<Instruction>(Val: VecBO))
637 VecBOInst->copyIRFlags(V: &I);
638
639 return Builder.CreateExtractElement(Vec: VecBO, Idx: ExtIndex, Name: "foldExtExtBinop");
640}
641
642/// Match an instruction with extracted vector operands.
/// Match an instruction with extracted vector operands:
/// binop/cmp (extelt V0, C0), (extelt V1, C1)
/// and convert it to a vector op followed by a single extract when the cost
/// model says that is no more expensive.
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  // Pred stays BAD_ICMP_PREDICATE for the binop case; that is how the code
  // below distinguishes cmp from binop.
  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  //       probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
    return false;

  Value *ExtOp0 = Ext0->getVectorOperand();
  Value *ExtOp1 = Ext1->getVectorOperand();

  if (ExtractToChange) {
    // Shift the cheaper extract's lane into the other extract's lane so both
    // operands feed the vector op from the same element index.
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    Value *NewExtOp =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtOp)
      return false;
    if (ExtractToChange == Ext0)
      ExtOp0 = NewExtOp;
    else
      ExtOp1 = NewExtOp;
  }

  // After the optional shuffle above, both extracts use the same index.
  // NOTE(review): if ExtractToChange is null because both extract costs were
  // invalid (see getShuffleExtract), the indexes may differ here - confirm
  // that path is unreachable via isExtractExtractCheap's cost comparison.
  Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
                                            : Ext0->getIndexOperand();
  Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
                      ? foldExtExtCmp(ExtOp0, ExtOp1, ExtIndex, I)
                      : foldExtExtBinop(ExtOp0, ExtOp1, ExtIndex, I);
  Worklist.push(Ext0);
  Worklist.push(Ext1);
  replaceValue(I, *NewExt);
  return true;
}
703
704/// Try to replace an extract + scalar fneg + insert with a vector fneg +
705/// shuffle.
706bool VectorCombine::foldInsExtFNeg(Instruction &I) {
707 // Match an insert (op (extract)) pattern.
708 Value *DstVec;
709 uint64_t ExtIdx, InsIdx;
710 Instruction *FNeg;
711 if (!match(V: &I, P: m_InsertElt(Val: m_Value(V&: DstVec), Elt: m_OneUse(SubPattern: m_Instruction(I&: FNeg)),
712 Idx: m_ConstantInt(V&: InsIdx))))
713 return false;
714
715 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
716 Value *SrcVec;
717 Instruction *Extract;
718 if (!match(V: FNeg, P: m_FNeg(X: m_CombineAnd(
719 L: m_Instruction(I&: Extract),
720 R: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx))))))
721 return false;
722
723 auto *DstVecTy = cast<FixedVectorType>(Val: DstVec->getType());
724 auto *DstVecScalarTy = DstVecTy->getScalarType();
725 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
726 if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
727 return false;
728
729 // Ignore if insert/extract index is out of bounds or destination vector has
730 // one element
731 unsigned NumDstElts = DstVecTy->getNumElements();
732 unsigned NumSrcElts = SrcVecTy->getNumElements();
733 if (ExtIdx > NumSrcElts || InsIdx >= NumDstElts || NumDstElts == 1)
734 return false;
735
736 // We are inserting the negated element into the same lane that we extracted
737 // from. This is equivalent to a select-shuffle that chooses all but the
738 // negated element from the destination vector.
739 SmallVector<int> Mask(NumDstElts);
740 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
741 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
742 InstructionCost OldCost =
743 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: DstVecScalarTy, CostKind) +
744 TTI.getVectorInstrCost(I, Val: DstVecTy, CostKind, Index: InsIdx);
745
746 // If the extract has one use, it will be eliminated, so count it in the
747 // original cost. If it has more than one use, ignore the cost because it will
748 // be the same before/after.
749 if (Extract->hasOneUse())
750 OldCost += TTI.getVectorInstrCost(I: *Extract, Val: SrcVecTy, CostKind, Index: ExtIdx);
751
752 InstructionCost NewCost =
753 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: SrcVecTy, CostKind) +
754 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: DstVecTy,
755 SrcTy: DstVecTy, Mask, CostKind);
756
757 bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
758 // If the lengths of the two vectors are not equal,
759 // we need to add a length-change vector. Add this cost.
760 SmallVector<int> SrcMask;
761 if (NeedLenChg) {
762 SrcMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
763 SrcMask[ExtIdx % NumDstElts] = ExtIdx;
764 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
765 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: SrcMask, CostKind);
766 }
767
768 LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
769 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
770 << "\n");
771 if (NewCost > OldCost)
772 return false;
773
774 Value *NewShuf, *LenChgShuf = nullptr;
775 // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
776 Value *VecFNeg = Builder.CreateFNegFMF(V: SrcVec, FMFSource: FNeg);
777 if (NeedLenChg) {
778 // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
779 LenChgShuf = Builder.CreateShuffleVector(V: VecFNeg, Mask: SrcMask);
780 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: LenChgShuf, Mask);
781 Worklist.pushValue(V: LenChgShuf);
782 } else {
783 // shuffle DstVec, (fneg SrcVec), Mask
784 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: VecFNeg, Mask);
785 }
786
787 Worklist.pushValue(V: VecFNeg);
788 replaceValue(Old&: I, New&: *NewShuf);
789 return true;
790}
791
/// Try to fold insert(binop(x,y),binop(a,b),idx)
/// --> binop(insert(x,a,idx),insert(y,b,idx))
/// i.e. sink an insertelement of a scalar binop result into the operands of a
/// vector binop with the same opcode, leaving one vector binop instead of a
/// vector binop + scalar binop + insert. Returns true on success.
bool VectorCombine::foldInsExtBinop(Instruction &I) {
  BinaryOperator *VecBinOp, *SclBinOp;
  uint64_t Index;
  // Both binops must be single-use so they die after the transform.
  if (!match(V: &I,
             P: m_InsertElt(Val: m_OneUse(SubPattern: m_BinOp(I&: VecBinOp)),
                          Elt: m_OneUse(SubPattern: m_BinOp(I&: SclBinOp)), Idx: m_ConstantInt(V&: Index))))
    return false;

  // TODO: Add support for addlike etc.
  // The vector and scalar operations must share the same opcode so the result
  // can be expressed as a single vector binop.
  Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
  if (BinOpcode != SclBinOp->getOpcode())
    return false;

  auto *ResultTy = dyn_cast<FixedVectorType>(Val: I.getType());
  if (!ResultTy)
    return false;

  // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
  // shuffle?

  // Old: insert + vector binop + scalar binop.
  // New: one vector binop + two inserts (operand info passed for accuracy).
  InstructionCost OldCost = TTI.getInstructionCost(U: &I, CostKind) +
                            TTI.getInstructionCost(U: VecBinOp, CostKind) +
                            TTI.getInstructionCost(U: SclBinOp, CostKind);
  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Opcode: BinOpcode, Ty: ResultTy, CostKind) +
      TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
                             Index, Op0: VecBinOp->getOperand(i_nocapture: 0),
                             Op1: SclBinOp->getOperand(i_nocapture: 0)) +
      TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
                             Index, Op0: VecBinOp->getOperand(i_nocapture: 1),
                             Op1: SclBinOp->getOperand(i_nocapture: 1));

  LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  // Insert each scalar operand into the matching lane of the corresponding
  // vector operand, then perform the binop once on vectors.
  Value *NewIns0 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 0),
                                               NewElt: SclBinOp->getOperand(i_nocapture: 0), Idx: Index);
  Value *NewIns1 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 1),
                                               NewElt: SclBinOp->getOperand(i_nocapture: 1), Idx: Index);
  Value *NewBO = Builder.CreateBinOp(Opc: BinOpcode, LHS: NewIns0, RHS: NewIns1);

  // Intersect flags from the old binops.
  // Only flags valid for BOTH original ops may be kept on the merged op.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
    NewInst->copyIRFlags(V: VecBinOp);
    NewInst->andIRFlags(V: SclBinOp);
  }

  Worklist.pushValue(V: NewIns0);
  Worklist.pushValue(V: NewIns1);
  replaceValue(Old&: I, New&: *NewBO);
  return true;
}
849
/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
/// Supports: bitcast, trunc, sext, zext
/// Hoisting the bitwise logic above the casts performs the operation on the
/// (usually narrower) source type and needs only one cast instead of two.
/// Returns true if the instruction was replaced.
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
  // Check if this is a bitwise logic operation
  auto *BinOp = dyn_cast<BinaryOperator>(Val: &I);
  if (!BinOp || !BinOp->isBitwiseLogicOp())
    return false;

  // Get the cast instructions
  auto *LHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 0));
  auto *RHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 1));
  if (!LHSCast || !RHSCast) {
    LLVM_DEBUG(dbgs() << "  One or both operands are not cast instructions\n");
    return false;
  }

  // Both casts must be the same type
  Instruction::CastOps CastOpcode = LHSCast->getOpcode();
  if (CastOpcode != RHSCast->getOpcode())
    return false;

  // Only handle supported cast operations
  switch (CastOpcode) {
  case Instruction::BitCast:
  case Instruction::Trunc:
  case Instruction::SExt:
  case Instruction::ZExt:
    break;
  default:
    return false;
  }

  Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);
  Value *RHSSrc = RHSCast->getOperand(i_nocapture: 0);

  // Source types must match
  if (LHSSrc->getType() != RHSSrc->getType())
    return false;

  auto *SrcTy = LHSSrc->getType();
  auto *DstTy = I.getType();
  // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
  // Other casts only handle vector types with integer elements.
  if (CastOpcode != Instruction::BitCast &&
      (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
    return false;

  // Only integer scalar/vector values are legal for bitwise logic operations.
  if (!SrcTy->getScalarType()->isIntegerTy() ||
      !DstTy->getScalarType()->isIntegerTy())
    return false;

  // Cost Check :
  // OldCost = bitlogic + 2*casts
  // NewCost = bitlogic + cast

  // Calculate specific costs for each cast with instruction context
  InstructionCost LHSCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);
  InstructionCost RHSCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: RHSCast);

  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: DstTy, CostKind) +
      LHSCastCost + RHSCastCost;

  // For new cost, we can't provide an instruction (it doesn't exist yet)
  InstructionCost GenericCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);

  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: SrcTy, CostKind) +
      GenericCastCost;

  // Account for multi-use casts using specific costs
  // (a multi-use cast survives the transform, so the new sequence still pays
  // for it).
  if (!LHSCast->hasOneUse())
    NewCost += LHSCastCost;
  if (!RHSCast->hasOneUse())
    NewCost += RHSCastCost;

  LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
                    << " NewCost=" << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Create the operation on the source type
  Value *NewOp = Builder.CreateBinOp(Opc: BinOp->getOpcode(), LHS: LHSSrc, RHS: RHSSrc,
                                     Name: BinOp->getName() + ".inner");
  if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
    NewBinOp->copyIRFlags(V: BinOp);

  Worklist.pushValue(V: NewOp);

  // Create the cast operation directly to ensure we get a new instruction
  // (Builder.CreateCast may constant-fold or reuse an existing value).
  Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());

  // Preserve cast instruction flags
  // (intersect: keep only flags valid on both original casts).
  NewCast->copyIRFlags(V: LHSCast);
  NewCast->andIRFlags(V: RHSCast);

  // Insert the new instruction
  Value *Result = Builder.Insert(I: NewCast);

  replaceValue(Old&: I, New&: *Result);
  return true;
}
957
/// Match:
// bitop(castop(x), C) ->
// bitop(castop(x), castop(InvC)) ->
// castop(bitop(x, InvC))
// Supports: bitcast
// The constant operand is "un-cast" to InvC on the source type (when a
// lossless inverse exists) so the bitwise logic can be hoisted above the cast.
// Returns true if the instruction was replaced.
bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
  Instruction *LHS;
  Constant *C;

  // Check if this is a bitwise logic operation
  // (commutative match: the constant may be on either side).
  if (!match(V: &I, P: m_c_BitwiseLogic(L: m_Instruction(I&: LHS), R: m_Constant(C))))
    return false;

  // Get the cast instructions
  auto *LHSCast = dyn_cast<CastInst>(Val: LHS);
  if (!LHSCast)
    return false;

  Instruction::CastOps CastOpcode = LHSCast->getOpcode();

  // Only handle supported cast operations
  switch (CastOpcode) {
  case Instruction::BitCast:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::Trunc:
    break;
  default:
    return false;
  }

  Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);

  auto *SrcTy = LHSSrc->getType();
  auto *DstTy = I.getType();
  // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
  // Other casts only handle vector types with integer elements.
  if (CastOpcode != Instruction::BitCast &&
      (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
    return false;

  // Only integer scalar/vector values are legal for bitwise logic operations.
  if (!SrcTy->getScalarType()->isIntegerTy() ||
      !DstTy->getScalarType()->isIntegerTy())
    return false;

  // Find the constant InvC, such that castop(InvC) equals to C.
  // RHSFlags records which wrap/nneg flags that implied cast would carry.
  PreservedCastFlags RHSFlags;
  Constant *InvC = getLosslessInvCast(C, InvCastTo: SrcTy, CastOp: CastOpcode, DL: *DL, Flags: &RHSFlags);
  if (!InvC)
    return false;

  // Cost Check :
  // OldCost = bitlogic + cast
  // NewCost = bitlogic + cast

  // Calculate specific costs for each cast with instruction context
  InstructionCost LHSCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);

  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: DstTy, CostKind) + LHSCastCost;

  // For new cost, we can't provide an instruction (it doesn't exist yet)
  InstructionCost GenericCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);

  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: SrcTy, CostKind) +
      GenericCastCost;

  // Account for multi-use casts using specific costs
  // (a multi-use cast survives the transform).
  if (!LHSCast->hasOneUse())
    NewCost += LHSCastCost;

  LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
                    << " NewCost=" << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Create the operation on the source type
  Value *NewOp = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(),
                                     LHS: LHSSrc, RHS: InvC, Name: I.getName() + ".inner");
  if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
    NewBinOp->copyIRFlags(V: &I);

  Worklist.pushValue(V: NewOp);

  // Create the cast operation directly to ensure we get a new instruction
  Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());

  // Preserve cast instruction flags
  // First apply the flags implied by the constant's inverse cast, then
  // intersect with the LHS cast's flags so only commonly-valid flags remain.
  if (RHSFlags.NNeg)
    NewCast->setNonNeg();
  if (RHSFlags.NUW)
    NewCast->setHasNoUnsignedWrap();
  if (RHSFlags.NSW)
    NewCast->setHasNoSignedWrap();

  NewCast->andIRFlags(V: LHSCast);

  // Insert the new instruction
  Value *Result = Builder.Insert(I: NewCast);

  replaceValue(Old&: I, New&: *Result);
  return true;
}
1066
/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
/// The shuffle mask is rescaled (narrowed or widened) to match the new element
/// size. Returns true if the instruction was replaced.
bool VectorCombine::foldBitcastShuffle(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> Mask;
  // The shuffle must be single-use so it dies after the transform.
  if (!match(V: &I, P: m_BitCast(Op: m_OneUse(
                      SubPattern: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
  // mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(Val: V0->getType());
  if (!DestTy || !SrcTy)
    return false;

  // The source's total bit width must divide evenly into destination-sized
  // elements for the pre-shuffle bitcast to be expressible.
  unsigned DestEltSize = DestTy->getScalarSizeInBits();
  unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
  if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
    return false;

  bool IsUnary = isa<UndefValue>(Val: V1);

  // For binary shuffles, only fold bitcast(shuffle(X,Y))
  // if it won't increase the number of bitcasts.
  // (At least one input, looked at through existing bitcasts, must already
  // have the destination element type.)
  if (!IsUnary) {
    auto *BCTy0 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V0)->getType());
    auto *BCTy1 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V1)->getType());
    if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
        !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
      return false;
  }

  SmallVector<int, 16> NewMask;
  if (DestEltSize <= SrcEltSize) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    if (SrcEltSize % DestEltSize != 0)
      return false;
    unsigned ScaleFactor = SrcEltSize / DestEltSize;
    narrowShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    if (DestEltSize % SrcEltSize != 0)
      return false;
    unsigned ScaleFactor = DestEltSize / SrcEltSize;
    if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask))
      return false;
  }

  // Bitcast the shuffle src - keep its original width but using the destination
  // scalar type.
  unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
  auto *NewShuffleTy =
      FixedVectorType::get(ElementType: DestTy->getScalarType(), NumElts: NumSrcElts);
  auto *OldShuffleTy =
      FixedVectorType::get(ElementType: SrcTy->getScalarType(), NumElts: Mask.size());
  unsigned NumOps = IsUnary ? 1 : 2;

  // The new shuffle must not cost more than the old shuffle.
  TargetTransformInfo::ShuffleKind SK =
      IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
              : TargetTransformInfo::SK_PermuteTwoSrc;

  // New: rescaled shuffle + one bitcast per shuffle input.
  // Old: original shuffle + single result bitcast.
  InstructionCost NewCost =
      TTI.getShuffleCost(Kind: SK, DstTy: DestTy, SrcTy: NewShuffleTy, Mask: NewMask, CostKind) +
      (NumOps * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: NewShuffleTy, Src: SrcTy,
                                     CCH: TargetTransformInfo::CastContextHint::None,
                                     CostKind));
  InstructionCost OldCost =
      TTI.getShuffleCost(Kind: SK, DstTy: OldShuffleTy, SrcTy, Mask, CostKind) +
      TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DestTy, Src: OldShuffleTy,
                           CCH: TargetTransformInfo::CastContextHint::None,
                           CostKind);

  LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n  OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  if (NewCost > OldCost || !NewCost.isValid())
    return false;

  // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
  ++NumShufOfBitcast;
  Value *CastV0 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V0), DestTy: NewShuffleTy);
  Value *CastV1 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V1), DestTy: NewShuffleTy);
  Value *Shuf = Builder.CreateShuffleVector(V1: CastV0, V2: CastV1, Mask: NewMask);
  replaceValue(Old&: I, New&: *Shuf);
  return true;
}
1161
/// VP Intrinsics whose vector operands are both splat values may be simplified
/// into the scalar version of the operation and the result splatted. This
/// can lead to scalarization down the line.
/// Only binary VP ops with an all-true mask are handled; the transform is
/// gated on the cost model and on speculation safety w.r.t. the EVL operand.
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
  if (!isa<VPIntrinsic>(Val: I))
    return false;
  VPIntrinsic &VPI = cast<VPIntrinsic>(Val&: I);
  Value *Op0 = VPI.getArgOperand(i: 0);
  Value *Op1 = VPI.getArgOperand(i: 1);

  if (!isSplatValue(V: Op0) || !isSplatValue(V: Op1))
    return false;

  // Check getSplatValue early in this function, to avoid doing unnecessary
  // work.
  Value *ScalarOp0 = getSplatValue(V: Op0);
  Value *ScalarOp1 = getSplatValue(V: Op1);
  if (!ScalarOp0 || !ScalarOp1)
    return false;

  // For the binary VP intrinsics supported here, the result on disabled lanes
  // is a poison value. For now, only do this simplification if all lanes
  // are active.
  // TODO: Relax the condition that all lanes are active by using insertelement
  // on inactive lanes.
  auto IsAllTrueMask = [](Value *MaskVal) {
    if (Value *SplattedVal = getSplatValue(V: MaskVal))
      if (auto *ConstValue = dyn_cast<Constant>(Val: SplattedVal))
        return ConstValue->isAllOnesValue();
    return false;
  };
  if (!IsAllTrueMask(VPI.getArgOperand(i: 2)))
    return false;

  // Check to make sure we support scalarization of the intrinsic
  Intrinsic::ID IntrID = VPI.getIntrinsicID();
  if (!VPBinOpIntrinsic::isVPBinOp(ID: IntrID))
    return false;

  // Calculate cost of splatting both operands into vectors and the vector
  // intrinsic
  VectorType *VecTy = cast<VectorType>(Val: VPI.getType());
  SmallVector<int> Mask;
  if (auto *FVTy = dyn_cast<FixedVectorType>(Val: VecTy))
    Mask.resize(N: FVTy->getNumElements(), NV: 0);
  // A splat is modelled as insertelement into lane 0 + broadcast shuffle.
  InstructionCost SplatCost =
      TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: 0) +
      TTI.getShuffleCost(Kind: TargetTransformInfo::SK_Broadcast, DstTy: VecTy, SrcTy: VecTy, Mask,
                         CostKind);

  // Calculate the cost of the VP Intrinsic
  SmallVector<Type *, 4> Args;
  for (Value *V : VPI.args())
    Args.push_back(Elt: V->getType());
  IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

  // Determine scalar opcode
  // (either a plain IR opcode, or a scalar intrinsic when no opcode exists).
  std::optional<unsigned> FunctionalOpcode =
      VPI.getFunctionalOpcode();
  std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
  if (!FunctionalOpcode) {
    ScalarIntrID = VPI.getFunctionalIntrinsicID();
    if (!ScalarIntrID)
      return false;
  }

  // Calculate cost of scalarizing
  InstructionCost ScalarOpCost = 0;
  if (ScalarIntrID) {
    IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
    ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode: *FunctionalOpcode,
                                              Ty: VecTy->getScalarType(), CostKind);
  }

  // The existing splats may be kept around if other instructions use them.
  InstructionCost CostToKeepSplats =
      (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;

  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
                    << "\n");
  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
                    << ", Cost of scalarizing:" << NewCost << "\n");

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Scalarize the intrinsic
  ElementCount EC = cast<VectorType>(Val: Op0->getType())->getElementCount();
  Value *EVL = VPI.getArgOperand(i: 3);

  // If the VP op might introduce UB or poison, we can scalarize it provided
  // that we know the EVL > 0: If the EVL is zero, then the original VP op
  // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
  // scalarizing it.
  bool SafeToSpeculate;
  if (ScalarIntrID)
    SafeToSpeculate = Intrinsic::getFnAttributes(C&: I.getContext(), id: *ScalarIntrID)
                          .hasAttribute(Kind: Attribute::AttrKind::Speculatable);
  else
    SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
        Opcode: *FunctionalOpcode, Inst: &VPI, CtxI: nullptr, AC: &AC, DT: &DT);
  if (!SafeToSpeculate &&
      !isKnownNonZero(V: EVL, Q: SimplifyQuery(*DL, &DT, &AC, &VPI)))
    return false;

  // Emit the scalar op (intrinsic call or binop) and splat its result back
  // to the original vector type.
  Value *ScalarVal =
      ScalarIntrID
          ? Builder.CreateIntrinsic(RetTy: VecTy->getScalarType(), ID: *ScalarIntrID,
                                    Args: {ScalarOp0, ScalarOp1})
          : Builder.CreateBinOp(Opc: (Instruction::BinaryOps)(*FunctionalOpcode),
                                LHS: ScalarOp0, RHS: ScalarOp1);

  replaceValue(Old&: VPI, New&: *Builder.CreateVectorSplat(EC, V: ScalarVal));
  return true;
}
1283
/// Match a vector op/compare/intrinsic with at least one
/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
/// by insertelement.
/// All inserted operands must target the same lane; remaining operands must
/// be constants (or intrinsic scalar args), so the non-inserted lanes fold to
/// a constant base vector NewVecC.
bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
  auto *UO = dyn_cast<UnaryOperator>(Val: &I);
  auto *BO = dyn_cast<BinaryOperator>(Val: &I);
  auto *CI = dyn_cast<CmpInst>(Val: &I);
  auto *II = dyn_cast<IntrinsicInst>(Val: &I);
  if (!UO && !BO && !CI && !II)
    return false;

  // TODO: Allow intrinsics with different argument types
  if (II) {
    if (!isTriviallyVectorizable(ID: II->getIntrinsicID()))
      return false;
    for (auto [Idx, Arg] : enumerate(First: II->args()))
      if (Arg->getType() != II->getType() &&
          !isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI))
        return false;
  }

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  if (CI)
    for (User *U : I.users())
      if (match(V: U, P: m_Select(C: m_Specific(V: &I), L: m_Value(), R: m_Value())))
        return false;

  // Match constant vectors or scalars being inserted into constant vectors:
  // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
  // VecCs holds the vector/constant part per operand; ScalarOps holds the
  // inserted scalar (or nullptr for fully-constant operands).
  SmallVector<Value *> VecCs, ScalarOps;
  std::optional<uint64_t> Index;

  auto Ops = II ? II->args() : I.operands();
  for (auto [OpNum, Op] : enumerate(First&: Ops)) {
    Constant *VecC;
    Value *V;
    uint64_t InsIdx = 0;
    if (match(V: Op.get(), P: m_InsertElt(Val: m_Constant(C&: VecC), Elt: m_Value(V),
                                     Idx: m_ConstantInt(V&: InsIdx)))) {
      // Bail if any inserts are out of bounds.
      VectorType *OpTy = cast<VectorType>(Val: Op->getType());
      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
        return false;
      // All inserts must have the same index.
      // TODO: Deal with mismatched index constants and variable indexes?
      if (!Index)
        Index = InsIdx;
      else if (InsIdx != *Index)
        return false;
      VecCs.push_back(Elt: VecC);
      ScalarOps.push_back(Elt: V);
    } else if (II && isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(),
                                                        ScalarOpdIdx: OpNum, TTI: &TTI)) {
      // Intrinsic args that are inherently scalar pass through unchanged.
      VecCs.push_back(Elt: Op.get());
      ScalarOps.push_back(Elt: Op.get());
    } else if (match(V: Op.get(), P: m_Constant(C&: VecC))) {
      VecCs.push_back(Elt: VecC);
      ScalarOps.push_back(Elt: nullptr);
    } else {
      return false;
    }
  }

  // Bail if all operands are constant.
  if (!Index.has_value())
    return false;

  VectorType *VecTy = cast<VectorType>(Val: I.getType());
  Type *ScalarTy = VecTy->getScalarType();
  assert(VecTy->isVectorTy() &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  // Compute the cost of the operation in scalar vs. vector form.
  unsigned Opcode = I.getOpcode();
  InstructionCost ScalarOpCost, VectorOpCost;
  if (CI) {
    CmpInst::Predicate Pred = CI->getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
  } else if (UO || BO) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
  } else {
    IntrinsicCostAttributes ScalarICA(
        II->getIntrinsicID(), ScalarTy,
        SmallVector<Type *>(II->arg_size(), ScalarTy));
    ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: ScalarICA, CostKind);
    IntrinsicCostAttributes VectorICA(
        II->getIntrinsicID(), VecTy,
        SmallVector<Type *>(II->arg_size(), VecTy));
    VectorOpCost = TTI.getIntrinsicInstrCost(ICA: VectorICA, CostKind);
  }

  // Fold the vector constants in the original vectors into a new base vector to
  // get more accurate cost modelling.
  Value *NewVecC = nullptr;
  if (CI)
    NewVecC = simplifyCmpInst(Predicate: CI->getPredicate(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
  else if (UO)
    NewVecC =
        simplifyUnOp(Opcode: UO->getOpcode(), Op: VecCs[0], FMF: UO->getFastMathFlags(), Q: SQ);
  else if (BO)
    NewVecC = simplifyBinOp(Opcode: BO->getOpcode(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
  else if (II)
    NewVecC = simplifyCall(Call: II, Callee: II->getCalledOperand(), Args: VecCs, Q: SQ);

  if (!NewVecC)
    return false;

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  InstructionCost OldCost = VectorOpCost;
  InstructionCost NewCost =
      ScalarOpCost + TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy,
                                            CostKind, Index: *Index, Op0: NewVecC);

  // The original per-operand inserts die with the transform; if an insert has
  // other users it must be kept, so the new sequence pays for it too.
  for (auto [Idx, Op, VecC, Scalar] : enumerate(First&: Ops, Rest&: VecCs, Rest&: ScalarOps)) {
    if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
                              ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI)))
      continue;
    InstructionCost InsertCost = TTI.getVectorInstrCost(
        Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: *Index, Op0: VecC, Op1: Scalar);
    OldCost += InsertCost;
    NewCost += !Op->hasOneUse() * InsertCost;
  }

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (CI)
    ++NumScalarCmp;
  else if (UO || BO)
    ++NumScalarOps;
  else
    ++NumScalarIntrinsic;

  // For constant cases, extract the scalar element, this should constant fold.
  for (auto [OpIdx, Scalar, VecC] : enumerate(First&: ScalarOps, Rest&: VecCs))
    if (!Scalar)
      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
          Vec: cast<Constant>(Val: VecC), Idx: Builder.getInt64(C: *Index));

  Value *Scalar;
  if (CI)
    Scalar = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: ScalarOps[0], RHS: ScalarOps[1]);
  else if (UO || BO)
    Scalar = Builder.CreateNAryOp(Opc: Opcode, Ops: ScalarOps);
  else
    Scalar = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: II->getIntrinsicID(), Args: ScalarOps);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Val: Scalar))
    ScalarInst->copyIRFlags(V: &I);

  Value *Insert = Builder.CreateInsertElement(Vec: NewVecC, NewElt: Scalar, Idx: *Index);
  replaceValue(Old&: I, New&: *Insert);
  return true;
}
1454
1455/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1456/// a vector into vector operations followed by extract. Note: The SLP pass
1457/// may miss this pattern because of implementation problems.
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  auto *BI = dyn_cast<BinaryOperator>(Val: &I);

  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!BI || !I.getType()->isIntegerTy(Bitwidth: 1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  Value *B0 = I.getOperand(i: 0), *B1 = I.getOperand(i: 1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpPredicate P0, P1;
  if (!match(V: B0, P: m_Cmp(Pred&: P0, L: m_Instruction(I&: I0), R: m_Constant(C&: C0))) ||
      !match(V: B1, P: m_Cmp(Pred&: P1, L: m_Instruction(I&: I1), R: m_Constant(C&: C1))))
    return false;

  auto MatchingPred = CmpPredicate::getMatching(A: P0, B: P1);
  if (!MatchingPred)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ConstantInt(V&: Index0))) ||
      !match(V: I1, P: m_ExtractElt(Val: m_Specific(V: X), Idx: m_ConstantInt(V&: Index1))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(Val: I0);
  auto *Ext1 = cast<ExtractElementInst>(Val: I1);
  // Decide which of the two extracts would be replaced by a shuffle; the
  // other one keeps its original lane.
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex: CostKind);
  if (!ConvertToShuf)
    return false;
  assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
         "Unknown ExtractElementInst");

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = *MatchingPred;
  unsigned CmpOpcode =
      CmpInst::isFPPredicate(P: Pred) ? Instruction::FCmp : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
  if (!VecTy)
    return false;

  // Cost of the original scalar sequence: two extracts, two scalar compares,
  // and the scalar binop.
  InstructionCost Ext0Cost =
      TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
  InstructionCost Ext1Cost =
      TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
  InstructionCost CmpCost = TTI.getCmpSelInstrCost(
      Opcode: CmpOpcode, ValTy: I0->getType(), CondTy: CmpInst::makeCmpResultType(opnd_type: I0->getType()), VecPred: Pred,
      CostKind);

  InstructionCost OldCost =
      Ext0Cost + Ext1Cost + CmpCost * 2 +
      TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: I.getType(), CostKind);

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(Val: CmpInst::makeCmpResultType(opnd_type: VecTy));
  InstructionCost NewCost = TTI.getCmpSelInstrCost(
      Opcode: CmpOpcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
  // A single-source shuffle that moves the expensive lane onto the cheap
  // lane; all other lanes are don't-care.
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
  ShufMask[CheapIndex] = ExpensiveIndex;
  NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: CmpTy,
                                SrcTy: CmpTy, Mask: ShufMask, CostKind);
  NewCost += TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: CmpTy, CostKind);
  NewCost += TTI.getVectorInstrCost(I: *Ext0, Val: CmpTy, CostKind, Index: CheapIndex);
  // If an extract has other users it must be kept alive, so its cost is not
  // saved by the transform.
  NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
  NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   PoisonValue::get(T: VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, LHS: X, RHS: ConstantVector::get(V: CmpC));
  Value *Shuf = createShiftShuffle(Vec: VCmp, OldIndex: ExpensiveIndex, NewIndex: CheapIndex, Builder);
  // Keep the binop operand order of the original scalar binop.
  Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
  Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
  Value *VecLogic = Builder.CreateBinOp(Opc: BI->getOpcode(), LHS, RHS);
  Value *NewExt = Builder.CreateExtractElement(Vec: VecLogic, Idx: CheapIndex);
  replaceValue(Old&: I, New&: *NewExt);
  ++NumVecCmpBO;
  return true;
}
1555
1556/// Try to fold scalar selects that select between extracted elements and zero
1557/// into extracting from a vector select. This is rooted at the bitcast.
1558///
1559/// This pattern arises when a vector is bitcast to a smaller element type,
1560/// elements are extracted, and then conditionally selected with zero:
1561///
1562/// %bc = bitcast <4 x i32> %src to <16 x i8>
1563/// %e0 = extractelement <16 x i8> %bc, i32 0
1564/// %s0 = select i1 %cond, i8 %e0, i8 0
1565/// %e1 = extractelement <16 x i8> %bc, i32 1
1566/// %s1 = select i1 %cond, i8 %e1, i8 0
1567/// ...
1568///
1569/// Transforms to:
1570/// %sel = select i1 %cond, <4 x i32> %src, <4 x i32> zeroinitializer
1571/// %bc = bitcast <4 x i32> %sel to <16 x i8>
1572/// %e0 = extractelement <16 x i8> %bc, i32 0
1573/// %e1 = extractelement <16 x i8> %bc, i32 1
1574/// ...
1575///
1576/// This is profitable because vector select on wider types produces fewer
1577/// select/cndmask instructions than scalar selects on each element.
1578bool VectorCombine::foldSelectsFromBitcast(Instruction &I) {
1579 auto *BC = dyn_cast<BitCastInst>(Val: &I);
1580 if (!BC)
1581 return false;
1582
1583 FixedVectorType *SrcVecTy = dyn_cast<FixedVectorType>(Val: BC->getSrcTy());
1584 FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(Val: BC->getDestTy());
1585 if (!SrcVecTy || !DstVecTy)
1586 return false;
1587
1588 // Source must be 32-bit or 64-bit elements, destination must be smaller
1589 // integer elements. Zero in all these types is all-bits-zero.
1590 Type *SrcEltTy = SrcVecTy->getElementType();
1591 Type *DstEltTy = DstVecTy->getElementType();
1592 unsigned SrcEltBits = SrcEltTy->getPrimitiveSizeInBits();
1593 unsigned DstEltBits = DstEltTy->getPrimitiveSizeInBits();
1594
1595 if (SrcEltBits != 32 && SrcEltBits != 64)
1596 return false;
1597
1598 if (!DstEltTy->isIntegerTy() || DstEltBits >= SrcEltBits)
1599 return false;
1600
1601 // Check profitability using TTI before collecting users.
1602 Type *CondTy = CmpInst::makeCmpResultType(opnd_type: DstEltTy);
1603 Type *VecCondTy = CmpInst::makeCmpResultType(opnd_type: SrcVecTy);
1604
1605 InstructionCost ScalarSelCost =
1606 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: DstEltTy, CondTy,
1607 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1608 InstructionCost VecSelCost =
1609 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: SrcVecTy, CondTy: VecCondTy,
1610 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1611
1612 // We need at least this many selects for vectorization to be profitable.
1613 // VecSelCost < ScalarSelCost * NumSelects => NumSelects > VecSelCost /
1614 // ScalarSelCost
1615 if (!ScalarSelCost.isValid() || ScalarSelCost == 0)
1616 return false;
1617
1618 unsigned MinSelects = (VecSelCost.getValue() / ScalarSelCost.getValue()) + 1;
1619
1620 // Quick check: if bitcast doesn't have enough users, bail early.
1621 if (!BC->hasNUsesOrMore(N: MinSelects))
1622 return false;
1623
1624 // Collect all select users that match the pattern, grouped by condition.
1625 // Pattern: select i1 %cond, (extractelement %bc, idx), 0
1626 DenseMap<Value *, SmallVector<SelectInst *, 8>> CondToSelects;
1627
1628 for (User *U : BC->users()) {
1629 auto *Ext = dyn_cast<ExtractElementInst>(Val: U);
1630 if (!Ext)
1631 continue;
1632
1633 for (User *ExtUser : Ext->users()) {
1634 Value *Cond;
1635 // Match: select i1 %cond, %ext, 0
1636 if (match(V: ExtUser, P: m_Select(C: m_Value(V&: Cond), L: m_Specific(V: Ext), R: m_Zero())) &&
1637 Cond->getType()->isIntegerTy(Bitwidth: 1))
1638 CondToSelects[Cond].push_back(Elt: cast<SelectInst>(Val: ExtUser));
1639 }
1640 }
1641
1642 if (CondToSelects.empty())
1643 return false;
1644
1645 bool MadeChange = false;
1646 Value *SrcVec = BC->getOperand(i_nocapture: 0);
1647
1648 // Process each group of selects with the same condition.
1649 for (auto [Cond, Selects] : CondToSelects) {
1650 // Only profitable if vector select cost < total scalar select cost.
1651 if (Selects.size() < MinSelects) {
1652 LLVM_DEBUG(dbgs() << "VectorCombine: foldSelectsFromBitcast not "
1653 << "profitable (VecCost=" << VecSelCost
1654 << ", ScalarCost=" << ScalarSelCost
1655 << ", NumSelects=" << Selects.size() << ")\n");
1656 continue;
1657 }
1658
1659 // Create the vector select and bitcast once for this condition.
1660 auto InsertPt = std::next(x: BC->getIterator());
1661
1662 if (auto *CondInst = dyn_cast<Instruction>(Val: Cond))
1663 if (DT.dominates(Def: BC, User: CondInst))
1664 InsertPt = std::next(x: CondInst->getIterator());
1665
1666 Builder.SetInsertPoint(InsertPt);
1667 Value *VecSel =
1668 Builder.CreateSelect(C: Cond, True: SrcVec, False: Constant::getNullValue(Ty: SrcVecTy));
1669 Value *NewBC = Builder.CreateBitCast(V: VecSel, DestTy: DstVecTy);
1670
1671 // Replace each scalar select with an extract from the new bitcast.
1672 for (SelectInst *Sel : Selects) {
1673 auto *Ext = cast<ExtractElementInst>(Val: Sel->getTrueValue());
1674 Value *Idx = Ext->getIndexOperand();
1675
1676 Builder.SetInsertPoint(Sel);
1677 Value *NewExt = Builder.CreateExtractElement(Vec: NewBC, Idx);
1678 replaceValue(Old&: *Sel, New&: *NewExt);
1679 MadeChange = true;
1680 }
1681
1682 LLVM_DEBUG(dbgs() << "VectorCombine: folded " << Selects.size()
1683 << " selects into vector select\n");
1684 }
1685
1686 return MadeChange;
1687}
1688
/// Estimate the cost of the vector reduction intrinsic \p II, split into the
/// cost of the instructions feeding the reduction (\p CostBeforeReduction)
/// and the cost of the reduction itself (\p CostAfterReduction). Both are
/// out-parameters. Two fused forms are recognized so that targets which
/// support extended or multiply-accumulate reductions are costed accurately:
///   reduce(ext(X))                      --> extended reduction
///   reduce.add(ext(mul(ext(A), ext(B)))) --> multiply-accumulate reduction
static void analyzeCostOfVecReduction(const IntrinsicInst &II,
                                      TTI::TargetCostKind CostKind,
                                      const TargetTransformInfo &TTI,
                                      InstructionCost &CostBeforeReduction,
                                      InstructionCost &CostAfterReduction) {
  Instruction *Op0, *Op1;
  auto *RedOp = dyn_cast<Instruction>(Val: II.getOperand(i_nocapture: 0));
  auto *VecRedTy = cast<VectorType>(Val: II.getOperand(i_nocapture: 0)->getType());
  unsigned ReductionOpc =
      getArithmeticReductionInstruction(RdxID: II.getIntrinsicID());
  if (RedOp && match(V: RedOp, P: m_ZExtOrSExt(Op: m_Value()))) {
    // Matched reduce(ext(X)): cost the extend separately from the extended
    // reduction of the narrower source type.
    bool IsUnsigned = isa<ZExtInst>(Val: RedOp);
    auto *ExtType = cast<VectorType>(Val: RedOp->getOperand(i: 0)->getType());

    CostBeforeReduction =
        TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: ExtType,
                             CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
    CostAfterReduction =
        TTI.getExtendedReductionCost(Opcode: ReductionOpc, IsUnsigned, ResTy: II.getType(),
                                     Ty: ExtType, FMF: FastMathFlags(), CostKind);
    return;
  }
  if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
      match(V: RedOp,
            P: m_ZExtOrSExt(Op: m_Mul(L: m_Instruction(I&: Op0), R: m_Instruction(I&: Op1)))) &&
      match(V: Op0, P: m_ZExtOrSExt(Op: m_Value())) &&
      Op0->getOpcode() == Op1->getOpcode() &&
      Op0->getOperand(i: 0)->getType() == Op1->getOperand(i: 0)->getType() &&
      (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
    // Matched reduce.add(ext(mul(ext(A), ext(B)))
    bool IsUnsigned = isa<ZExtInst>(Val: Op0);
    auto *ExtType = cast<VectorType>(Val: Op0->getOperand(i: 0)->getType());
    VectorType *MulType = VectorType::get(ElementType: Op0->getType(), Other: VecRedTy);

    InstructionCost ExtCost =
        TTI.getCastInstrCost(Opcode: Op0->getOpcode(), Dst: MulType, Src: ExtType,
                             CCH: TTI::CastContextHint::None, CostKind, I: Op0);
    InstructionCost MulCost =
        TTI.getArithmeticInstrCost(Opcode: Instruction::Mul, Ty: MulType, CostKind);
    InstructionCost Ext2Cost =
        TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: MulType,
                             CCH: TTI::CastContextHint::None, CostKind, I: RedOp);

    CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
    CostAfterReduction = TTI.getMulAccReductionCost(
        IsUnsigned, RedOpcode: ReductionOpc, ResTy: II.getType(), Ty: ExtType, CostKind);
    return;
  }
  // Generic case: plain arithmetic reduction, nothing feeding it is costed.
  CostAfterReduction = TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecRedTy,
                                                      FMF: std::nullopt, CostKind);
}
1740
/// Try to fold a binary op of two reductions of the same kind into one
/// reduction over a binary op of the vector operands:
///   binop (reduce X), (reduce Y) --> reduce (binop X, Y)
/// 'sub' is handled by matching two reduce_add operands, because
///   sub (reduce_add X), (reduce_add Y) --> reduce_add (sub X, Y).
bool VectorCombine::foldBinopOfReductions(Instruction &I) {
  Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(Val: &I)->getOpcode();
  Intrinsic::ID ReductionIID = getReductionForBinop(Opc: BinOpOpc);
  if (BinOpOpc == Instruction::Sub)
    ReductionIID = Intrinsic::vector_reduce_add;
  if (ReductionIID == Intrinsic::not_intrinsic)
    return false;

  // Return the reduced vector operand if \p V is a one-use call to the
  // expected reduction intrinsic, otherwise null.
  auto checkIntrinsicAndGetItsArgument = [](Value *V,
                                            Intrinsic::ID IID) -> Value * {
    auto *II = dyn_cast<IntrinsicInst>(Val: V);
    if (!II)
      return nullptr;
    if (II->getIntrinsicID() == IID && II->hasOneUse())
      return II->getArgOperand(i: 0);
    return nullptr;
  };

  Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 0), ReductionIID);
  if (!V0)
    return false;
  Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 1), ReductionIID);
  if (!V1)
    return false;

  // Both reductions must operate on the same vector type.
  auto *VTy = cast<VectorType>(Val: V0->getType());
  if (V1->getType() != VTy)
    return false;
  const auto &II0 = *cast<IntrinsicInst>(Val: I.getOperand(i: 0));
  const auto &II1 = *cast<IntrinsicInst>(Val: I.getOperand(i: 1));
  unsigned ReductionOpc =
      getArithmeticReductionInstruction(RdxID: II0.getIntrinsicID());

  // Old: two full reductions plus the scalar binop.
  // New: the code feeding each reduction, one vector binop, one reduction.
  InstructionCost OldCost = 0;
  InstructionCost NewCost = 0;
  InstructionCost CostOfRedOperand0 = 0;
  InstructionCost CostOfRed0 = 0;
  InstructionCost CostOfRedOperand1 = 0;
  InstructionCost CostOfRed1 = 0;
  analyzeCostOfVecReduction(II: II0, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand0, CostAfterReduction&: CostOfRed0);
  analyzeCostOfVecReduction(II: II1, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand1, CostAfterReduction&: CostOfRed1);
  OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(U: &I, CostKind);
  NewCost =
      CostOfRedOperand0 + CostOfRedOperand1 +
      TTI.getArithmeticInstrCost(Opcode: BinOpOpc, Ty: VTy, CostKind) +
      TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VTy, FMF: std::nullopt, CostKind);
  if (NewCost >= OldCost || !NewCost.isValid())
    return false;

  LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  Value *VectorBO;
  // Preserve the 'disjoint' flag when merging two or-reductions.
  if (BinOpOpc == Instruction::Or)
    VectorBO = Builder.CreateOr(LHS: V0, RHS: V1, Name: "",
                                IsDisjoint: cast<PossiblyDisjointInst>(Val&: I).isDisjoint());
  else
    VectorBO = Builder.CreateBinOp(Opc: BinOpOpc, LHS: V0, RHS: V1);

  Instruction *Rdx = Builder.CreateIntrinsic(ID: ReductionIID, Types: {VTy}, Args: {VectorBO});
  replaceValue(Old&: I, New&: *Rdx);
  return true;
}
1804
1805// Check if memory loc modified between two instrs in the same BB
1806static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1807 BasicBlock::iterator End,
1808 const MemoryLocation &Loc, AAResults &AA) {
1809 unsigned NumScanned = 0;
1810 return std::any_of(first: Begin, last: End, pred: [&](const Instruction &Instr) {
1811 return isModSet(MRI: AA.getModRefInfo(I: &Instr, OptLoc: Loc)) ||
1812 ++NumScanned > MaxInstrsToScan;
1813 });
1814}
1815
namespace {
/// Helper class to indicate whether a vector index can be safely scalarized and
/// if a freeze needs to be inserted.
class ScalarizationResult {
  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };

  StatusTy Status;
  // Value whose uses must be frozen before the transform may proceed;
  // only non-null in the SafeWithFreeze state.
  Value *ToFreeze;

  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
      : Status(Status), ToFreeze(ToFreeze) {}

public:
  ScalarizationResult(const ScalarizationResult &Other) = default;
  ~ScalarizationResult() {
    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
  }

  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
  static ScalarizationResult safe() { return {StatusTy::Safe}; }
  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
    return {StatusTy::SafeWithFreeze, ToFreeze};
  }

  /// Returns true if the index can be scalarized without requiring a freeze.
  bool isSafe() const { return Status == StatusTy::Safe; }
  /// Returns true if the index cannot be scalarized.
  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
  /// freeze.
  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }

  /// Reset the state to Unsafe and clear ToFreeze if set.
  void discard() {
    ToFreeze = nullptr;
    Status = StatusTy::Unsafe;
  }

  /// Freeze the ToFreeze and update the use in \p UserI to use it.
  void freeze(IRBuilderBase &Builder, Instruction &UserI) {
    assert(isSafeWithFreeze() &&
           "should only be used when freezing is required");
    assert(is_contained(ToFreeze->users(), &UserI) &&
           "UserI must be a user of ToFreeze");
    IRBuilder<>::InsertPointGuard Guard(Builder);
    Builder.SetInsertPoint(cast<Instruction>(Val: &UserI));
    Value *Frozen =
        Builder.CreateFreeze(V: ToFreeze, Name: ToFreeze->getName() + ".frozen");
    // Replace every use of ToFreeze within UserI with the frozen value.
    for (Use &U : make_early_inc_range(Range: (UserI.operands())))
      if (U.get() == ToFreeze)
        U.set(Frozen);

    ToFreeze = nullptr;
  }
};
} // namespace
1872
/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
                                              Instruction *CtxI,
                                              AssumptionCache &AC,
                                              const DominatorTree &DT) {
  // We do checks for both fixed vector types and scalable vector types.
  // This is the number of elements of fixed vector types,
  // or the minimum number of elements of scalable vector types.
  uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();

  // A constant index is safe iff it is in bounds.
  if (auto *C = dyn_cast<ConstantInt>(Val: Idx)) {
    if (C->getValue().ult(RHS: NumElements))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  // Always unsafe if the index type can't handle all inbound values.
  if (!llvm::isUIntN(N: IntWidth, x: NumElements))
    return ScalarizationResult::unsafe();

  APInt Zero(IntWidth, 0);
  APInt MaxElts(IntWidth, NumElements);
  ConstantRange ValidIndices(Zero, MaxElts);
  ConstantRange IdxRange(IntWidth, true);

  // For a known-not-poison index, value-range analysis can be trusted
  // directly; a poison index would invalidate the computed range.
  if (isGuaranteedNotToBePoison(V: Idx, AC: &AC)) {
    if (ValidIndices.contains(CR: computeConstantRange(V: Idx, /* ForSigned */ false,
                                                    UseInstrInfo: true, AC: &AC, CtxI, DT: &DT)))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  // If the index may be poison, check if we can insert a freeze before the
  // range of the index is restricted.
  Value *IdxBase;
  ConstantInt *CI;
  if (match(V: Idx, P: m_And(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
    IdxRange = IdxRange.binaryAnd(Other: CI->getValue());
  } else if (match(V: Idx, P: m_URem(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
    IdxRange = IdxRange.urem(Other: CI->getValue());
  }

  // If the and/urem-restricted range is provably in bounds, the access is
  // safe once IdxBase is frozen. Otherwise (including when no pattern
  // matched and IdxRange is still the full set) it is unsafe.
  if (ValidIndices.contains(CR: IdxRange))
    return ScalarizationResult::safeWithFreeze(ToFreeze: IdxBase);
  return ScalarizationResult::unsafe();
}
1921
1922/// The memory operation on a vector of \p ScalarType had alignment of
1923/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1924/// alignment that will be valid for the memory operation on a single scalar
1925/// element of the same type with index \p Idx.
1926static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1927 Type *ScalarType, Value *Idx,
1928 const DataLayout &DL) {
1929 if (auto *C = dyn_cast<ConstantInt>(Val: Idx))
1930 return commonAlignment(A: VectorAlignment,
1931 Offset: C->getZExtValue() * DL.getTypeStoreSize(Ty: ScalarType));
1932 return commonAlignment(A: VectorAlignment, Offset: DL.getTypeStoreSize(Ty: ScalarType));
1933}
1934
// Combine patterns like:
// %0 = load <4 x i32>, <4 x i32>* %a
// %1 = insertelement <4 x i32> %0, i32 %b, i32 1
// store <4 x i32> %1, <4 x i32>* %a
// to:
// %0 = bitcast <4 x i32>* %a to i32*
// %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
// store i32 %b, i32* %1
bool VectorCombine::foldSingleElementStore(Instruction &I) {
  if (!TTI.allowVectorElementIndexingUsingGEP())
    return false;
  auto *SI = cast<StoreInst>(Val: &I);
  if (!SI->isSimple() || !isa<VectorType>(Val: SI->getValueOperand()->getType()))
    return false;

  // TODO: Combine more complicated patterns (multiple insert) by referencing
  // TargetTransformInfo.
  Instruction *Source;
  Value *NewElement;
  Value *Idx;
  if (!match(V: SI->getValueOperand(),
             P: m_InsertElt(Val: m_Instruction(I&: Source), Elt: m_Value(V&: NewElement),
                          Idx: m_Value(V&: Idx))))
    return false;

  if (auto *Load = dyn_cast<LoadInst>(Val: Source)) {
    auto VecTy = cast<VectorType>(Val: SI->getValueOperand()->getType());
    Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
    // Don't optimize for atomic/volatile load or store. Ensure memory is not
    // modified between, vector type matches store size, and index is inbounds.
    if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
        !DL->typeSizeEqualsStoreSize(Ty: Load->getType()->getScalarType()) ||
        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
      return false;

    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, CtxI: Load, AC, DT);
    if (ScalarizableIdx.isUnsafe() ||
        isMemModifiedBetween(Begin: Load->getIterator(), End: SI->getIterator(),
                             Loc: MemoryLocation::get(SI), AA))
      return false;

    // Ensure we add the load back to the worklist BEFORE its users so they can
    // be erased in the correct order.
    Worklist.push(I: Load);

    // Freeze a possibly-poison index before its range is restricted.
    if (ScalarizableIdx.isSafeWithFreeze())
      ScalarizableIdx.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
    Value *GEP = Builder.CreateInBoundsGEP(
        Ty: SI->getValueOperand()->getType(), Ptr: SI->getPointerOperand(),
        IdxList: {ConstantInt::get(Ty: Idx->getType(), V: 0), Idx});
    StoreInst *NSI = Builder.CreateStore(Val: NewElement, Ptr: GEP);
    NSI->copyMetadata(SrcInst: *SI);
    // Derive the scalar store alignment from the stricter of the original
    // load/store alignments and the element offset.
    Align ScalarOpAlignment = computeAlignmentAfterScalarization(
        VectorAlignment: std::max(a: SI->getAlign(), b: Load->getAlign()), ScalarType: NewElement->getType(), Idx,
        DL: *DL);
    NSI->setAlignment(ScalarOpAlignment);
    replaceValue(Old&: I, New&: *NSI);
    eraseInstruction(I);
    return true;
  }

  return false;
}
1998
/// Try to scalarize vector loads feeding extractelement or bitcast
/// instructions. Classifies the load's users and dispatches to
/// scalarizeLoadExtract / scalarizeLoadBitcast when all users are of one kind.
bool VectorCombine::scalarizeLoad(Instruction &I) {
  Value *Ptr;
  if (!match(V: &I, P: m_Load(Op: m_Value(V&: Ptr))))
    return false;

  auto *LI = cast<LoadInst>(Val: &I);
  auto *VecTy = cast<VectorType>(Val: LI->getType());
  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(Ty: VecTy->getScalarType()))
    return false;

  bool AllExtracts = true;
  bool AllBitcasts = true;
  // Track how far the block has been scanned so each instruction between the
  // load and its furthest user is examined at most once.
  Instruction *LastCheckedInst = LI;
  unsigned NumInstChecked = 0;

  // Check what type of users we have (must either all be extracts or
  // bitcasts) and ensure no memory modifications between the load and
  // its users.
  for (User *U : LI->users()) {
    auto *UI = dyn_cast<Instruction>(Val: U);
    if (!UI || UI->getParent() != LI->getParent())
      return false;

    // If any user is waiting to be erased, then bail out as this will
    // distort the cost calculation and possibly lead to infinite loops.
    if (UI->use_empty())
      return false;

    if (!isa<ExtractElementInst>(Val: UI))
      AllExtracts = false;
    if (!isa<BitCastInst>(Val: UI))
      AllBitcasts = false;

    // Check if any instruction between the load and the user may modify memory.
    if (LastCheckedInst->comesBefore(Other: UI)) {
      for (Instruction &I :
           make_range(x: std::next(x: LI->getIterator()), y: UI->getIterator())) {
        // Bail out if we reached the check limit or the instruction may write
        // to memory.
        if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
          return false;
        NumInstChecked++;
      }
      LastCheckedInst = UI;
    }
  }

  if (AllExtracts)
    return scalarizeLoadExtract(LI, VecTy, Ptr);
  if (AllBitcasts)
    return scalarizeLoadBitcast(LI, VecTy, Ptr);
  return false;
}
2054
/// Try to scalarize vector loads feeding extractelement instructions.
/// Replaces `load <N x T>` + `extractelement` users with per-element scalar
/// loads through an inbounds GEP, when the cost model says it is cheaper.
bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
                                         Value *Ptr) {
  if (!TTI.allowVectorElementIndexingUsingGEP())
    return false;

  // Extracts whose index must be frozen before the transform can be applied.
  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
  llvm::scope_exit FailureGuard([&]() {
    // If the transform is aborted, discard the ScalarizationResults.
    for (auto &Pair : NeedFreeze)
      Pair.second.discard();
  });

  InstructionCost OriginalCost =
      TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
                          AddressSpace: LI->getPointerAddressSpace(), CostKind);
  InstructionCost ScalarizedCost = 0;

  for (User *U : LI->users()) {
    auto *UI = cast<ExtractElementInst>(Val: U);

    auto ScalarIdx =
        canScalarizeAccess(VecTy, Idx: UI->getIndexOperand(), CtxI: LI, AC, DT);
    if (ScalarIdx.isUnsafe())
      return false;
    if (ScalarIdx.isSafeWithFreeze()) {
      NeedFreeze.try_emplace(Key: UI, Args&: ScalarIdx);
      ScalarIdx.discard();
    }

    // -1 (unknown lane) is used for a non-constant extract index.
    auto *Index = dyn_cast<ConstantInt>(Val: UI->getIndexOperand());
    OriginalCost +=
        TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: VecTy, CostKind,
                               Index: Index ? Index->getZExtValue() : -1);
    // Each extract becomes a scalar load (alignment conservatively 1 here;
    // the real alignment is computed when the load is created) plus its
    // address computation.
    ScalarizedCost +=
        TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy->getElementType(),
                            Alignment: Align(1), AddressSpace: LI->getPointerAddressSpace(), CostKind);
    ScalarizedCost += TTI.getAddressComputationCost(PtrTy: LI->getPointerOperandType(),
                                                    SE: nullptr, Ptr: nullptr, CostKind);
  }

  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                    << "\n  LoadExtractCost: " << OriginalCost
                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");

  if (ScalarizedCost >= OriginalCost)
    return false;

  // Ensure we add the load back to the worklist BEFORE its users so they can
  // be erased in the correct order.
  Worklist.push(I: LI);

  Type *ElemType = VecTy->getElementType();

  // Replace extracts with narrow scalar loads.
  for (User *U : LI->users()) {
    auto *EI = cast<ExtractElementInst>(Val: U);
    Value *Idx = EI->getIndexOperand();

    // Insert 'freeze' for poison indexes.
    auto It = NeedFreeze.find(Val: EI);
    if (It != NeedFreeze.end())
      It->second.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));

    Builder.SetInsertPoint(EI);
    Value *GEP =
        Builder.CreateInBoundsGEP(Ty: VecTy, Ptr, IdxList: {Builder.getInt32(C: 0), Idx});
    auto *NewLoad = cast<LoadInst>(
        Val: Builder.CreateLoad(Ty: ElemType, Ptr: GEP, Name: EI->getName() + ".scalar"));

    Align ScalarOpAlignment =
        computeAlignmentAfterScalarization(VectorAlignment: LI->getAlign(), ScalarType: ElemType, Idx, DL: *DL);
    NewLoad->setAlignment(ScalarOpAlignment);

    // With a constant index the AA metadata can be narrowed to the accessed
    // element's byte range.
    if (auto *ConstIdx = dyn_cast<ConstantInt>(Val: Idx)) {
      size_t Offset = ConstIdx->getZExtValue() * DL->getTypeStoreSize(Ty: ElemType);
      AAMDNodes OldAAMD = LI->getAAMetadata();
      NewLoad->setAAMetadata(OldAAMD.adjustForAccess(Offset, AccessTy: ElemType, DL: *DL));
    }

    replaceValue(Old&: *EI, New&: *NewLoad, Erase: false);
  }

  // Transform committed; keep the recorded freezes (already consumed above).
  FailureGuard.release();
  return true;
}
2141
/// Try to scalarize vector loads feeding bitcast instructions.
/// If every user bitcasts the loaded vector to the same equally-sized scalar
/// type, load that scalar type directly and reuse it for all bitcasts.
bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
                                         Value *Ptr) {
  InstructionCost OriginalCost =
      TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
                          AddressSpace: LI->getPointerAddressSpace(), CostKind);

  Type *TargetScalarType = nullptr;
  unsigned VecBitWidth = DL->getTypeSizeInBits(Ty: VecTy);

  for (User *U : LI->users()) {
    auto *BC = cast<BitCastInst>(Val: U);

    Type *DestTy = BC->getDestTy();
    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
      return false;

    // The bitcast must cover the whole vector, not a sub-part.
    unsigned DestBitWidth = DL->getTypeSizeInBits(Ty: DestTy);
    if (DestBitWidth != VecBitWidth)
      return false;

    // All bitcasts must target the same scalar type.
    if (!TargetScalarType)
      TargetScalarType = DestTy;
    else if (TargetScalarType != DestTy)
      return false;

    OriginalCost +=
        TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: TargetScalarType, Src: VecTy,
                             CCH: TTI.getCastContextHint(I: BC), CostKind, I: BC);
  }

  if (!TargetScalarType)
    return false;

  assert(!LI->user_empty() && "Unexpected load without bitcast users");
  InstructionCost ScalarizedCost =
      TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: TargetScalarType, Alignment: LI->getAlign(),
                          AddressSpace: LI->getPointerAddressSpace(), CostKind);

  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
                    << "\n  OriginalCost: " << OriginalCost
                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");

  if (ScalarizedCost >= OriginalCost)
    return false;

  // Ensure we add the load back to the worklist BEFORE its users so they can
  // be erased in the correct order.
  Worklist.push(I: LI);

  Builder.SetInsertPoint(LI);
  auto *ScalarLoad =
      Builder.CreateLoad(Ty: TargetScalarType, Ptr, Name: LI->getName() + ".scalar");
  ScalarLoad->setAlignment(LI->getAlign());
  ScalarLoad->copyMetadata(SrcInst: *LI);

  // Replace all bitcast users with the scalar load.
  for (User *U : LI->users()) {
    auto *BC = cast<BitCastInst>(Val: U);
    replaceValue(Old&: *BC, New&: *ScalarLoad, Erase: false);
  }

  return true;
}
2207
/// Try to replace a vector zext whose only users are extractelements with
/// scalar shift-and-mask operations on a bitcast of the (packed) source
/// vector.
bool VectorCombine::scalarizeExtExtract(Instruction &I) {
  if (!TTI.allowVectorElementIndexingUsingGEP())
    return false;
  auto *Ext = dyn_cast<ZExtInst>(Val: &I);
  if (!Ext)
    return false;

  // Try to convert a vector zext feeding only extracts to a set of scalar
  //   (Src >> (ExtIdx * EltBits)) & EltMask
  // operations, if profitable.
  auto *SrcTy = dyn_cast<FixedVectorType>(Val: Ext->getOperand(i_nocapture: 0)->getType());
  if (!SrcTy)
    return false;
  auto *DstTy = cast<FixedVectorType>(Val: Ext->getType());

  // The whole source vector must fit into one destination scalar so a single
  // bitcast-to-integer can hold all lanes.
  Type *ScalarDstTy = DstTy->getElementType();
  if (DL->getTypeSizeInBits(Ty: SrcTy) != DL->getTypeSizeInBits(Ty: ScalarDstTy))
    return false;

  InstructionCost VectorCost =
      TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: DstTy, Src: SrcTy,
                           CCH: TTI::CastContextHint::None, CostKind, I: Ext);
  unsigned ExtCnt = 0;
  bool ExtLane0 = false;
  for (User *U : Ext->users()) {
    uint64_t Idx;
    // Every user must be an extract with a constant lane index.
    if (!match(V: U, P: m_ExtractElt(Val: m_Value(), Idx: m_ConstantInt(V&: Idx))))
      return false;
    if (cast<Instruction>(Val: U)->use_empty())
      continue;
    ExtCnt += 1;
    // Lane 0 needs no shift, only the mask.
    ExtLane0 |= !Idx;
    VectorCost += TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: DstTy,
                                         CostKind, Index: Idx, Op0: U);
  }

  InstructionCost ScalarCost =
      ExtCnt * TTI.getArithmeticInstrCost(
                   Opcode: Instruction::And, Ty: ScalarDstTy, CostKind,
                   Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
                   Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None}) +
      (ExtCnt - ExtLane0) *
          TTI.getArithmeticInstrCost(
              Opcode: Instruction::LShr, Ty: ScalarDstTy, CostKind,
              Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
              Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None});
  if (ScalarCost > VectorCost)
    return false;

  Value *ScalarV = Ext->getOperand(i_nocapture: 0);
  if (!isGuaranteedNotToBePoison(V: ScalarV, AC: &AC, CtxI: dyn_cast<Instruction>(Val: ScalarV),
                                 DT: &DT)) {
    // Check whether all lanes are extracted, all extracts trigger UB
    // on poison, and the last extract (and hence all previous ones)
    // are guaranteed to execute if Ext executes. If so, we do not
    // need to insert a freeze.
    SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
    bool AllExtractsTriggerUB = true;
    ExtractElementInst *LastExtract = nullptr;
    BasicBlock *ExtBB = Ext->getParent();
    for (User *U : Ext->users()) {
      auto *Extract = cast<ExtractElementInst>(Val: U);
      if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Inst: Extract)) {
        AllExtractsTriggerUB = false;
        break;
      }
      ExtractedLanes.insert(V: cast<ConstantInt>(Val: Extract->getIndexOperand()));
      if (!LastExtract || LastExtract->comesBefore(Other: Extract))
        LastExtract = Extract;
    }
    if (ExtractedLanes.size() != DstTy->getNumElements() ||
        !AllExtractsTriggerUB ||
        !isGuaranteedToTransferExecutionToSuccessor(Begin: Ext->getIterator(),
                                                    End: LastExtract->getIterator()))
      ScalarV = Builder.CreateFreeze(V: ScalarV);
  }
  // Reinterpret the source vector as one wide integer.
  ScalarV = Builder.CreateBitCast(
      V: ScalarV,
      DestTy: IntegerType::get(C&: SrcTy->getContext(), NumBits: DL->getTypeSizeInBits(Ty: SrcTy)));
  uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(Ty: SrcTy->getElementType());
  uint64_t TotalBits = DL->getTypeSizeInBits(Ty: SrcTy);
  APInt EltBitMask = APInt::getLowBitsSet(numBits: TotalBits, loBitsSet: SrcEltSizeInBits);
  Type *PackedTy = IntegerType::get(C&: SrcTy->getContext(), NumBits: TotalBits);
  Value *Mask = ConstantInt::get(Ty: PackedTy, V: EltBitMask);
  for (User *U : Ext->users()) {
    auto *Extract = cast<ExtractElementInst>(Val: U);
    uint64_t Idx =
        cast<ConstantInt>(Val: Extract->getIndexOperand())->getZExtValue();
    // On big-endian targets lane 0 occupies the most significant bits.
    uint64_t ShiftAmt =
        DL->isBigEndian()
            ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
            : (Idx * SrcEltSizeInBits);
    Value *LShr = Builder.CreateLShr(LHS: ScalarV, RHS: ShiftAmt);
    Value *And = Builder.CreateAnd(LHS: LShr, RHS: Mask);
    U->replaceAllUsesWith(V: And);
  }
  return true;
}
2306
2307/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
2308/// to "(bitcast (concat X, Y))"
2309/// where X/Y are bitcasted from i1 mask vectors.
bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
  // The root of the pattern is a scalar integer 'or' of the zext'd masks.
  Type *Ty = I.getType();
  if (!Ty->isIntegerTy())
    return false;

  // TODO: Add big endian test coverage
  if (DL->isBigEndian())
    return false;

  // Restrict to disjoint cases so the mask vectors aren't overlapping.
  Instruction *X, *Y;
  if (!match(V: &I, P: m_DisjointOr(L: m_Instruction(I&: X), R: m_Instruction(I&: Y))))
    return false;

  // Allow both sources to contain shl, to handle more generic pattern:
  // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
  Value *SrcX;
  uint64_t ShAmtX = 0;
  if (!match(V: X, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX)))))) &&
      !match(V: X, P: m_OneUse(
                     SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX))))),
                           R: m_ConstantInt(V&: ShAmtX)))))
    return false;

  Value *SrcY;
  uint64_t ShAmtY = 0;
  if (!match(V: Y, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY)))))) &&
      !match(V: Y, P: m_OneUse(
                     SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY))))),
                           R: m_ConstantInt(V&: ShAmtY)))))
    return false;

  // Canonicalize larger shift to the RHS.
  if (ShAmtX > ShAmtY) {
    std::swap(a&: X, b&: Y);
    std::swap(a&: SrcX, b&: SrcY);
    std::swap(a&: ShAmtX, b&: ShAmtY);
  }

  // Ensure both sources are matching vXi1 bool mask types, and that the shift
  // difference is the mask width so they can be easily concatenated together.
  // With ShAmtX canonicalized as the smaller shift, SrcY's bits must sit
  // immediately above SrcX's, and both halves must fit in the result type.
  uint64_t ShAmtDiff = ShAmtY - ShAmtX;
  unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0); // number of shl's to remove
  unsigned BitWidth = Ty->getPrimitiveSizeInBits();
  auto *MaskTy = dyn_cast<FixedVectorType>(Val: SrcX->getType());
  if (!MaskTy || SrcX->getType() != SrcY->getType() ||
      !MaskTy->getElementType()->isIntegerTy(Bitwidth: 1) ||
      MaskTy->getNumElements() != ShAmtDiff ||
      MaskTy->getNumElements() > (BitWidth / 2))
    return false;

  // ConcatTy = <2N x i1>, ConcatIntTy = i2N, MaskIntTy = iN.
  auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(VTy: MaskTy);
  auto *ConcatIntTy =
      Type::getIntNTy(C&: Ty->getContext(), N: ConcatTy->getNumElements());
  auto *MaskIntTy = Type::getIntNTy(C&: Ty->getContext(), N: ShAmtDiff);

  // Identity concatenation mask: [0, 1, ..., 2N-1].
  SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
  std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);

  // TODO: Is it worth supporting multi use cases?
  // Old pattern: or + (0..2) shl + 2 x zext + 2 x bitcast.
  InstructionCost OldCost = 0;
  OldCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Or, Ty, CostKind);
  OldCost +=
      NumSHL * TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
  OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: MaskIntTy,
                                      CCH: TTI::CastContextHint::None, CostKind);
  OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: MaskIntTy, Src: MaskTy,
                                      CCH: TTI::CastContextHint::None, CostKind);

  // New pattern: concat shuffle + bitcast, plus a residual zext if the result
  // type is wider than i2N and a residual shl if a common shift remains.
  InstructionCost NewCost = 0;
  NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ConcatTy,
                                SrcTy: MaskTy, Mask: ConcatMask, CostKind);
  NewCost += TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: ConcatIntTy, Src: ConcatTy,
                                  CCH: TTI::CastContextHint::None, CostKind);
  if (Ty != ConcatIntTy)
    NewCost += TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: ConcatIntTy,
                                    CCH: TTI::CastContextHint::None, CostKind);
  if (ShAmtX > 0)
    NewCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);

  LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  // Build bool mask concatenation, bitcast back to scalar integer, and perform
  // any residual zero-extension or shifting.
  Value *Concat = Builder.CreateShuffleVector(V1: SrcX, V2: SrcY, Mask: ConcatMask);
  Worklist.pushValue(V: Concat);

  Value *Result = Builder.CreateBitCast(V: Concat, DestTy: ConcatIntTy);

  if (Ty != ConcatIntTy) {
    Worklist.pushValue(V: Result);
    Result = Builder.CreateZExt(V: Result, DestTy: Ty);
  }

  if (ShAmtX > 0) {
    Worklist.pushValue(V: Result);
    Result = Builder.CreateShl(LHS: Result, RHS: ShAmtX);
  }

  replaceValue(Old&: I, New&: *Result);
  return true;
}
2417
2418/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
2419/// --> "binop (shuffle), (shuffle)".
bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
  // Match the outer single-source permute of a binop result.
  BinaryOperator *BinOp;
  ArrayRef<int> OuterMask;
  if (!match(V: &I, P: m_Shuffle(v1: m_BinOp(I&: BinOp), v2: m_Undef(), mask: m_Mask(OuterMask))))
    return false;

  // Don't introduce poison into div/rem.
  if (BinOp->isIntDivRem() && llvm::is_contained(Range&: OuterMask, Element: PoisonMaskElem))
    return false;

  // Match inner shuffles on either (or both) binop operands.
  Value *Op00, *Op01, *Op10, *Op11;
  ArrayRef<int> Mask0, Mask1;
  bool Match0 = match(V: BinOp->getOperand(i_nocapture: 0),
                      P: m_Shuffle(v1: m_Value(V&: Op00), v2: m_Value(V&: Op01), mask: m_Mask(Mask0)));
  bool Match1 = match(V: BinOp->getOperand(i_nocapture: 1),
                      P: m_Shuffle(v1: m_Value(V&: Op10), v2: m_Value(V&: Op11), mask: m_Mask(Mask1)));
  if (!Match0 && !Match1)
    return false;

  // An unmatched operand is treated as both sources of an identity shuffle
  // of itself.
  Op00 = Match0 ? Op00 : BinOp->getOperand(i_nocapture: 0);
  Op01 = Match0 ? Op01 : BinOp->getOperand(i_nocapture: 0);
  Op10 = Match1 ? Op10 : BinOp->getOperand(i_nocapture: 1);
  Op11 = Match1 ? Op11 : BinOp->getOperand(i_nocapture: 1);

  Instruction::BinaryOps Opcode = BinOp->getOpcode();
  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *BinOpTy = dyn_cast<FixedVectorType>(Val: BinOp->getType());
  auto *Op0Ty = dyn_cast<FixedVectorType>(Val: Op00->getType());
  auto *Op1Ty = dyn_cast<FixedVectorType>(Val: Op10->getType());
  if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
    return false;

  unsigned NumSrcElts = BinOpTy->getNumElements();

  // Don't accept shuffles that reference the second operand in
  // div/rem or if it's an undef arg.
  if ((BinOp->isIntDivRem() || !isa<PoisonValue>(Val: I.getOperand(i: 1))) &&
      any_of(Range&: OuterMask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
    return false;

  // Merge outer / inner (or identity if no match) shuffles.
  SmallVector<int> NewMask0, NewMask1;
  for (int M : OuterMask) {
    if (M < 0 || M >= (int)NumSrcElts) {
      // Out-of-range / poison lanes stay poison in both merged masks.
      NewMask0.push_back(Elt: PoisonMaskElem);
      NewMask1.push_back(Elt: PoisonMaskElem);
    } else {
      NewMask0.push_back(Elt: Match0 ? Mask0[M] : M);
      NewMask1.push_back(Elt: Match1 ? Mask1[M] : M);
    }
  }

  // Detect merged masks that are no-ops on their (first) operand, so the
  // shuffle can be elided entirely on that side.
  unsigned NumOpElts = Op0Ty->getNumElements();
  bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
      all_of(Range&: NewMask0, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
      ShuffleVectorInst::isIdentityMask(Mask: NewMask0, NumSrcElts: NumOpElts);
  bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
      all_of(Range&: NewMask1, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
      ShuffleVectorInst::isIdentityMask(Mask: NewMask1, NumSrcElts: NumOpElts);

  InstructionCost NewCost = 0;
  // Try to merge shuffles across the binop if the new shuffles are not costly.
  InstructionCost BinOpCost =
      TTI.getArithmeticInstrCost(Opcode, Ty: BinOpTy, CostKind);
  InstructionCost OldCost =
      BinOpCost + TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                     DstTy: ShuffleDstTy, SrcTy: BinOpTy, Mask: OuterMask, CostKind,
                                     Index: 0, SubTp: nullptr, Args: {BinOp}, CxtI: &I);
  // A multi-use binop survives the fold, so its cost is paid on both sides.
  if (!BinOp->hasOneUse())
    NewCost += BinOpCost;

  if (Match0) {
    InstructionCost Shuf0Cost = TTI.getShuffleCost(
        Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op0Ty, Mask: Mask0, CostKind,
        Index: 0, SubTp: nullptr, Args: {Op00, Op01}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 0)));
    OldCost += Shuf0Cost;
    // The inner shuffle survives if the binop (and hence the shuffle) is kept.
    if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 0)->hasOneUse())
      NewCost += Shuf0Cost;
  }
  if (Match1) {
    InstructionCost Shuf1Cost = TTI.getShuffleCost(
        Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op1Ty, Mask: Mask1, CostKind,
        Index: 0, SubTp: nullptr, Args: {Op10, Op11}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 1)));
    OldCost += Shuf1Cost;
    if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 1)->hasOneUse())
      NewCost += Shuf1Cost;
  }

  NewCost += TTI.getArithmeticInstrCost(Opcode, Ty: ShuffleDstTy, CostKind);

  if (!IsIdentity0)
    NewCost +=
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
                           SrcTy: Op0Ty, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {Op00, Op01});
  if (!IsIdentity1)
    NewCost +=
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
                           SrcTy: Op1Ty, Mask: NewMask1, CostKind, Index: 0, SubTp: nullptr, Args: {Op10, Op11});

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  // If costs are equal, still fold as we reduce instruction count.
  if (NewCost > OldCost)
    return false;

  Value *LHS =
      IsIdentity0 ? Op00 : Builder.CreateShuffleVector(V1: Op00, V2: Op01, Mask: NewMask0);
  Value *RHS =
      IsIdentity1 ? Op10 : Builder.CreateShuffleVector(V1: Op10, V2: Op11, Mask: NewMask1);
  Value *NewBO = Builder.CreateBinOp(Opc: Opcode, LHS, RHS);

  // Intersect flags from the old binops.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO))
    NewInst->copyIRFlags(V: BinOp);

  Worklist.pushValue(V: LHS);
  Worklist.pushValue(V: RHS);
  replaceValue(Old&: I, New&: *NewBO);
  return true;
}
2542
2543/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
2544/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
  // Match a two-source shuffle of two instructions of the same opcode.
  ArrayRef<int> OldMask;
  Instruction *LHS, *RHS;
  if (!match(V: &I, P: m_Shuffle(v1: m_Instruction(I&: LHS), v2: m_Instruction(I&: RHS),
                             mask: m_Mask(OldMask))))
    return false;

  // TODO: Add support for addlike etc.
  if (LHS->getOpcode() != RHS->getOpcode())
    return false;

  // Decompose both sides into (X op Y) / (Z op W): either binops, or compares
  // with identical predicates.
  Value *X, *Y, *Z, *W;
  bool IsCommutative = false;
  CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
  CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
  if (match(V: LHS, P: m_BinOp(L: m_Value(V&: X), R: m_Value(V&: Y))) &&
      match(V: RHS, P: m_BinOp(L: m_Value(V&: Z), R: m_Value(V&: W)))) {
    auto *BO = cast<BinaryOperator>(Val: LHS);
    // Don't introduce poison into div/rem.
    if (llvm::is_contained(Range&: OldMask, Element: PoisonMaskElem) && BO->isIntDivRem())
      return false;
    IsCommutative = BinaryOperator::isCommutative(Opcode: BO->getOpcode());
  } else if (match(V: LHS, P: m_Cmp(Pred&: PredLHS, L: m_Value(V&: X), R: m_Value(V&: Y))) &&
             match(V: RHS, P: m_Cmp(Pred&: PredRHS, L: m_Value(V&: Z), R: m_Value(V&: W))) &&
             (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
    IsCommutative = cast<CmpInst>(Val: LHS)->isCommutative();
  } else
    return false;

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *BinResTy = dyn_cast<FixedVectorType>(Val: LHS->getType());
  auto *BinOpTy = dyn_cast<FixedVectorType>(Val: X->getType());
  if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
    return false;

  bool SameBinOp = LHS == RHS;
  unsigned NumSrcElts = BinOpTy->getNumElements();

  // If we have something like "add X, Y" and "add Z, X", swap ops to match.
  if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
    std::swap(a&: X, b&: Y);

  // Rebase a second-source mask index onto the first source.
  auto ConvertToUnary = [NumSrcElts](int &M) {
    if (M >= (int)NumSrcElts)
      M -= NumSrcElts;
  };

  // If both sides share the same first operand, the operand-0 shuffle becomes
  // a single-source permute.
  SmallVector<int> NewMask0(OldMask);
  TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
  TTI::OperandValueInfo Op0Info = TTI.commonOperandInfo(X, Y: Z);
  if (X == Z) {
    llvm::for_each(Range&: NewMask0, F: ConvertToUnary);
    SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
    Z = PoisonValue::get(T: BinOpTy);
  }

  // Likewise for a shared second operand.
  SmallVector<int> NewMask1(OldMask);
  TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
  TTI::OperandValueInfo Op1Info = TTI.commonOperandInfo(X: Y, Y: W);
  if (Y == W) {
    llvm::for_each(Range&: NewMask1, F: ConvertToUnary);
    SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
    W = PoisonValue::get(T: BinOpTy);
  }

  // Try to replace a binop with a shuffle if the shuffle is not costly.
  // When SameBinOp, only count the binop cost once.
  InstructionCost LHSCost = TTI.getInstructionCost(U: LHS, CostKind);
  InstructionCost RHSCost = TTI.getInstructionCost(U: RHS, CostKind);

  InstructionCost OldCost = LHSCost;
  if (!SameBinOp) {
    OldCost += RHSCost;
  }
  OldCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc,
                                DstTy: ShuffleDstTy, SrcTy: BinResTy, Mask: OldMask, CostKind, Index: 0,
                                SubTp: nullptr, Args: {LHS, RHS}, CxtI: &I);

  // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
  // where one use shuffles have gotten split across the binop/cmp. These
  // often allow a major reduction in total cost that wouldn't happen as
  // individual folds.
  // MergeInner folds a one-use inner shuffle into the merged mask, credits
  // its cost to OldCost, and redirects Op to the inner shuffle's source.
  auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
                        TTI::TargetCostKind CostKind) -> bool {
    Value *InnerOp;
    ArrayRef<int> InnerMask;
    if (match(V: Op, P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: InnerOp), v2: m_Undef(),
                                               mask: m_Mask(InnerMask)))) &&
        InnerOp->getType() == Op->getType() &&
        all_of(Range&: InnerMask,
               P: [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
      for (int &M : Mask)
        if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
          M = InnerMask[M - Offset];
          M = 0 <= M ? M + Offset : M;
        }
      OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: Op), CostKind);
      Op = InnerOp;
      return true;
    }
    return false;
  };
  bool ReducedInstCount = false;
  ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
  ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
  ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
  ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
  bool SingleSrcBinOp = (X == Y) && (Z == W) && (NewMask0 == NewMask1);
  // SingleSrcBinOp only reduces instruction count if we also eliminate the
  // original binop(s). If binops have multiple uses, they won't be eliminated.
  ReducedInstCount |= SingleSrcBinOp && LHS->hasOneUser() && RHS->hasOneUser();

  auto *ShuffleCmpTy =
      FixedVectorType::get(ElementType: BinOpTy->getElementType(), FVTy: ShuffleDstTy);
  InstructionCost NewCost = TTI.getShuffleCost(
      Kind: SK0, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {X, Z});
  // With identical operand shuffles, a single shuffle feeds both binop inputs.
  if (!SingleSrcBinOp)
    NewCost += TTI.getShuffleCost(Kind: SK1, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask1,
                                  CostKind, Index: 0, SubTp: nullptr, Args: {Y, W});

  if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
    NewCost += TTI.getArithmeticInstrCost(Opcode: LHS->getOpcode(), Ty: ShuffleDstTy,
                                          CostKind, Opd1Info: Op0Info, Opd2Info: Op1Info);
  } else {
    NewCost +=
        TTI.getCmpSelInstrCost(Opcode: LHS->getOpcode(), ValTy: ShuffleCmpTy, CondTy: ShuffleDstTy,
                               VecPred: PredLHS, CostKind, Op1Info: Op0Info, Op2Info: Op1Info);
  }
  // If LHS/RHS have other uses, we need to account for the cost of keeping
  // the original instructions. When SameBinOp, only add the cost once.
  if (!LHS->hasOneUser())
    NewCost += LHSCost;
  if (!SameBinOp && !RHS->hasOneUser())
    NewCost += RHSCost;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  // If either shuffle will constant fold away, then fold for the same cost as
  // we will reduce the instruction count.
  ReducedInstCount |= (isa<Constant>(Val: X) && isa<Constant>(Val: Z)) ||
                      (isa<Constant>(Val: Y) && isa<Constant>(Val: W));
  if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
    return false;

  Value *Shuf0 = Builder.CreateShuffleVector(V1: X, V2: Z, Mask: NewMask0);
  Value *Shuf1 =
      SingleSrcBinOp ? Shuf0 : Builder.CreateShuffleVector(V1: Y, V2: W, Mask: NewMask1);
  Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
                     ? Builder.CreateBinOp(
                           Opc: cast<BinaryOperator>(Val: LHS)->getOpcode(), LHS: Shuf0, RHS: Shuf1)
                     : Builder.CreateCmp(Pred: PredLHS, LHS: Shuf0, RHS: Shuf1);

  // Intersect flags from the old binops.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
    NewInst->copyIRFlags(V: LHS);
    NewInst->andIRFlags(V: RHS);
  }

  Worklist.pushValue(V: Shuf0);
  Worklist.pushValue(V: Shuf1);
  replaceValue(Old&: I, New&: *NewBO);
  return true;
}
2710
2711/// Try to convert,
2712/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
2713/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
2714bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
2715 ArrayRef<int> Mask;
2716 Value *C1, *T1, *F1, *C2, *T2, *F2;
2717 if (!match(V: &I, P: m_Shuffle(v1: m_Select(C: m_Value(V&: C1), L: m_Value(V&: T1), R: m_Value(V&: F1)),
2718 v2: m_Select(C: m_Value(V&: C2), L: m_Value(V&: T2), R: m_Value(V&: F2)),
2719 mask: m_Mask(Mask))))
2720 return false;
2721
2722 auto *Sel1 = cast<Instruction>(Val: I.getOperand(i: 0));
2723 auto *Sel2 = cast<Instruction>(Val: I.getOperand(i: 1));
2724
2725 auto *C1VecTy = dyn_cast<FixedVectorType>(Val: C1->getType());
2726 auto *C2VecTy = dyn_cast<FixedVectorType>(Val: C2->getType());
2727 if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
2728 return false;
2729
2730 auto *SI0FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 0));
2731 auto *SI1FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 1));
2732 // SelectInsts must have the same FMF.
2733 if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
2734 ((SI0FOp != nullptr) &&
2735 (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
2736 return false;
2737
2738 auto *SrcVecTy = cast<FixedVectorType>(Val: T1->getType());
2739 auto *DstVecTy = cast<FixedVectorType>(Val: I.getType());
2740 auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
2741 auto SelOp = Instruction::Select;
2742
2743 InstructionCost CostSel1 = TTI.getCmpSelInstrCost(
2744 Opcode: SelOp, ValTy: SrcVecTy, CondTy: C1VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2745 InstructionCost CostSel2 = TTI.getCmpSelInstrCost(
2746 Opcode: SelOp, ValTy: SrcVecTy, CondTy: C2VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2747
2748 InstructionCost OldCost =
2749 CostSel1 + CostSel2 +
2750 TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
2751 Args: {I.getOperand(i: 0), I.getOperand(i: 1)}, CxtI: &I);
2752
2753 InstructionCost NewCost = TTI.getShuffleCost(
2754 Kind: SK, DstTy: FixedVectorType::get(ElementType: C1VecTy->getScalarType(), NumElts: Mask.size()), SrcTy: C1VecTy,
2755 Mask, CostKind, Index: 0, SubTp: nullptr, Args: {C1, C2});
2756 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2757 SubTp: nullptr, Args: {T1, T2});
2758 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2759 SubTp: nullptr, Args: {F1, F2});
2760 auto *C1C2ShuffledVecTy = FixedVectorType::get(
2761 ElementType: Type::getInt1Ty(C&: I.getContext()), NumElts: DstVecTy->getNumElements());
2762 NewCost += TTI.getCmpSelInstrCost(Opcode: SelOp, ValTy: DstVecTy, CondTy: C1C2ShuffledVecTy,
2763 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2764
2765 if (!Sel1->hasOneUse())
2766 NewCost += CostSel1;
2767 if (!Sel2->hasOneUse())
2768 NewCost += CostSel2;
2769
2770 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
2771 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2772 << "\n");
2773 if (NewCost > OldCost)
2774 return false;
2775
2776 Value *ShuffleCmp = Builder.CreateShuffleVector(V1: C1, V2: C2, Mask);
2777 Value *ShuffleTrue = Builder.CreateShuffleVector(V1: T1, V2: T2, Mask);
2778 Value *ShuffleFalse = Builder.CreateShuffleVector(V1: F1, V2: F2, Mask);
2779 Value *NewSel;
2780 // We presuppose that the SelectInsts have the same FMF.
2781 if (SI0FOp)
2782 NewSel = Builder.CreateSelectFMF(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse,
2783 FMFSource: SI0FOp->getFastMathFlags());
2784 else
2785 NewSel = Builder.CreateSelect(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse);
2786
2787 Worklist.pushValue(V: ShuffleCmp);
2788 Worklist.pushValue(V: ShuffleTrue);
2789 Worklist.pushValue(V: ShuffleFalse);
2790 replaceValue(Old&: I, New&: *NewSel);
2791 return true;
2792}
2793
2794/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
2795/// into "castop (shuffle)".
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> OldMask;
  if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
    return false;

  // Check whether this is a binary shuffle.
  bool IsBinaryShuffle = !isa<UndefValue>(Val: V1);

  // The first operand must be a cast; for a binary shuffle so must the second.
  auto *C0 = dyn_cast<CastInst>(Val: V0);
  auto *C1 = dyn_cast<CastInst>(Val: V1);
  if (!C0 || (IsBinaryShuffle && !C1))
    return false;

  Instruction::CastOps Opcode = C0->getOpcode();

  // If this is allowed, foldShuffleOfCastops can get stuck in a loop
  // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
  if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
    return false;

  if (IsBinaryShuffle) {
    if (C0->getSrcTy() != C1->getSrcTy())
      return false;
    // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
    if (Opcode != C1->getOpcode()) {
      if (match(V: C0, P: m_SExtLike(Op: m_Value())) && match(V: C1, P: m_SExtLike(Op: m_Value())))
        Opcode = Instruction::SExt;
      else
        return false;
    }
  }

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *CastDstTy = dyn_cast<FixedVectorType>(Val: C0->getDestTy());
  auto *CastSrcTy = dyn_cast<FixedVectorType>(Val: C0->getSrcTy());
  if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
    return false;

  unsigned NumSrcElts = CastSrcTy->getNumElements();
  unsigned NumDstElts = CastDstTy->getNumElements();
  assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
         "Only bitcasts expected to alter src/dst element counts");

  // Check for bitcasting of unscalable vector types.
  // e.g. <32 x i40> -> <40 x i32>
  // (element counts that are not whole multiples of each other).
  if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
      (NumDstElts % NumSrcElts) != 0)
    return false;

  // Rescale the shuffle mask to operate on the cast's source element count.
  SmallVector<int, 16> NewMask;
  if (NumSrcElts >= NumDstElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumSrcElts / NumDstElts;
    narrowShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumDstElts / NumSrcElts;
    if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask))
      return false;
  }

  auto *NewShuffleDstTy =
      FixedVectorType::get(ElementType: CastSrcTy->getScalarType(), NumElts: NewMask.size());

  // Try to replace a castop with a shuffle if the shuffle is not costly.
  InstructionCost CostC0 =
      TTI.getCastInstrCost(Opcode: C0->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
                           CCH: TTI::CastContextHint::None, CostKind, I: C0);

  TargetTransformInfo::ShuffleKind ShuffleKind;
  if (IsBinaryShuffle)
    ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
  else
    ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;

  // Old: cast(s) then shuffle. New: shuffle of the cast sources, then one cast.
  InstructionCost OldCost = CostC0;
  OldCost += TTI.getShuffleCost(Kind: ShuffleKind, DstTy: ShuffleDstTy, SrcTy: CastDstTy, Mask: OldMask,
                                CostKind, Index: 0, SubTp: nullptr, Args: {}, CxtI: &I);

  InstructionCost NewCost = TTI.getShuffleCost(Kind: ShuffleKind, DstTy: NewShuffleDstTy,
                                               SrcTy: CastSrcTy, Mask: NewMask, CostKind);
  NewCost += TTI.getCastInstrCost(Opcode, Dst: ShuffleDstTy, Src: NewShuffleDstTy,
                                  CCH: TTI::CastContextHint::None, CostKind);
  // Multi-use casts survive the fold and must be paid for on the new side too.
  if (!C0->hasOneUse())
    NewCost += CostC0;
  if (IsBinaryShuffle) {
    InstructionCost CostC1 =
        TTI.getCastInstrCost(Opcode: C1->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
                             CCH: TTI::CastContextHint::None, CostKind, I: C1);
    OldCost += CostC1;
    if (!C1->hasOneUse())
      NewCost += CostC1;
  }

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  Value *Shuf;
  if (IsBinaryShuffle)
    Shuf = Builder.CreateShuffleVector(V1: C0->getOperand(i_nocapture: 0), V2: C1->getOperand(i_nocapture: 0),
                                       Mask: NewMask);
  else
    Shuf = Builder.CreateShuffleVector(V: C0->getOperand(i_nocapture: 0), Mask: NewMask);

  Value *Cast = Builder.CreateCast(Op: Opcode, V: Shuf, DestTy: ShuffleDstTy);

  // Intersect flags from the old casts.
  if (auto *NewInst = dyn_cast<Instruction>(Val: Cast)) {
    NewInst->copyIRFlags(V: C0);
    if (IsBinaryShuffle)
      NewInst->andIRFlags(V: C1);
  }

  Worklist.pushValue(V: Shuf);
  replaceValue(Old&: I, New&: *Cast);
  return true;
}
2921
2922/// Try to convert any of:
2923/// "shuffle (shuffle x, y), (shuffle y, x)"
2924/// "shuffle (shuffle x, undef), (shuffle y, undef)"
2925/// "shuffle (shuffle x, undef), y"
2926/// "shuffle x, (shuffle y, undef)"
2927/// into "shuffle x, y".
2928bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2929 ArrayRef<int> OuterMask;
2930 Value *OuterV0, *OuterV1;
2931 if (!match(V: &I,
2932 P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1), mask: m_Mask(OuterMask))))
2933 return false;
2934
2935 ArrayRef<int> InnerMask0, InnerMask1;
2936 Value *X0, *X1, *Y0, *Y1;
2937 bool Match0 =
2938 match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: X0), v2: m_Value(V&: Y0), mask: m_Mask(InnerMask0)));
2939 bool Match1 =
2940 match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: X1), v2: m_Value(V&: Y1), mask: m_Mask(InnerMask1)));
2941 if (!Match0 && !Match1)
2942 return false;
2943
2944 // If the outer shuffle is a permute, then create a fake inner all-poison
2945 // shuffle. This is easier than accounting for length-changing shuffles below.
2946 SmallVector<int, 16> PoisonMask1;
2947 if (!Match1 && isa<PoisonValue>(Val: OuterV1)) {
2948 X1 = X0;
2949 Y1 = Y0;
2950 PoisonMask1.append(NumInputs: InnerMask0.size(), Elt: PoisonMaskElem);
2951 InnerMask1 = PoisonMask1;
2952 Match1 = true; // fake match
2953 }
2954
2955 X0 = Match0 ? X0 : OuterV0;
2956 Y0 = Match0 ? Y0 : OuterV0;
2957 X1 = Match1 ? X1 : OuterV1;
2958 Y1 = Match1 ? Y1 : OuterV1;
2959 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2960 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(Val: X0->getType());
2961 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(Val: OuterV0->getType());
2962 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2963 X0->getType() != X1->getType())
2964 return false;
2965
2966 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2967 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2968
2969 // Attempt to merge shuffles, matching upto 2 source operands.
2970 // Replace index to a poison arg with PoisonMaskElem.
2971 // Bail if either inner masks reference an undef arg.
2972 SmallVector<int, 16> NewMask(OuterMask);
2973 Value *NewX = nullptr, *NewY = nullptr;
2974 for (int &M : NewMask) {
2975 Value *Src = nullptr;
2976 if (0 <= M && M < (int)NumImmElts) {
2977 Src = OuterV0;
2978 if (Match0) {
2979 M = InnerMask0[M];
2980 Src = M >= (int)NumSrcElts ? Y0 : X0;
2981 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2982 }
2983 } else if (M >= (int)NumImmElts) {
2984 Src = OuterV1;
2985 M -= NumImmElts;
2986 if (Match1) {
2987 M = InnerMask1[M];
2988 Src = M >= (int)NumSrcElts ? Y1 : X1;
2989 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2990 }
2991 }
2992 if (Src && M != PoisonMaskElem) {
2993 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
2994 if (isa<UndefValue>(Val: Src)) {
2995 // We've referenced an undef element - if its poison, update the shuffle
2996 // mask, else bail.
2997 if (!isa<PoisonValue>(Val: Src))
2998 return false;
2999 M = PoisonMaskElem;
3000 continue;
3001 }
3002 if (!NewX || NewX == Src) {
3003 NewX = Src;
3004 continue;
3005 }
3006 if (!NewY || NewY == Src) {
3007 M += NumSrcElts;
3008 NewY = Src;
3009 continue;
3010 }
3011 return false;
3012 }
3013 }
3014
3015 if (!NewX)
3016 return PoisonValue::get(T: ShuffleDstTy);
3017 if (!NewY)
3018 NewY = PoisonValue::get(T: ShuffleSrcTy);
3019
3020 // Have we folded to an Identity shuffle?
3021 if (ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts)) {
3022 replaceValue(Old&: I, New&: *NewX);
3023 return true;
3024 }
3025
3026 // Try to merge the shuffles if the new shuffle is not costly.
3027 InstructionCost InnerCost0 = 0;
3028 if (Match0)
3029 InnerCost0 = TTI.getInstructionCost(U: cast<User>(Val: OuterV0), CostKind);
3030
3031 InstructionCost InnerCost1 = 0;
3032 if (Match1)
3033 InnerCost1 = TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);
3034
3035 InstructionCost OuterCost = TTI.getInstructionCost(U: &I, CostKind);
3036
3037 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
3038
3039 bool IsUnary = all_of(Range&: NewMask, P: [&](int M) { return M < (int)NumSrcElts; });
3040 TargetTransformInfo::ShuffleKind SK =
3041 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
3042 : TargetTransformInfo::SK_PermuteTwoSrc;
3043 InstructionCost NewCost =
3044 TTI.getShuffleCost(Kind: SK, DstTy: ShuffleDstTy, SrcTy: ShuffleSrcTy, Mask: NewMask, CostKind, Index: 0,
3045 SubTp: nullptr, Args: {NewX, NewY});
3046 if (!OuterV0->hasOneUse())
3047 NewCost += InnerCost0;
3048 if (!OuterV1->hasOneUse())
3049 NewCost += InnerCost1;
3050
3051 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
3052 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3053 << "\n");
3054 if (NewCost > OldCost)
3055 return false;
3056
3057 Value *Shuf = Builder.CreateShuffleVector(V1: NewX, V2: NewY, Mask: NewMask);
3058 replaceValue(Old&: I, New&: *Shuf);
3059 return true;
3060}
3061
3062/// Try to convert a chain of length-preserving shuffles that are fed by
3063/// length-changing shuffles from the same source, e.g. a chain of length 3:
3064///
3065/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
3066/// (shuffle y, undef)),
///               (shuffle y, undef)"
3068///
3069/// into a single shuffle fed by a length-changing shuffle:
3070///
3071/// "shuffle x, (shuffle y, undef)"
3072///
3073/// Such chains arise e.g. from folding extract/insert sequences.
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
  FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(Val: I.getType());
  if (!TrunkType)
    return false;

  // State accumulated while walking up the chain:
  //   Mask  - combined trunk mask; indices >= NumTrunkElts select from the
  //           (combined) leaf shuffle of Y.
  //   YMask - combined mask applied to the common leaf source Y.
  //   Y     - the single non-undef value feeding every leaf shuffle.
  unsigned ChainLength = 0;
  SmallVector<int> Mask;
  SmallVector<int> YMask;
  InstructionCost OldCost = 0;
  InstructionCost NewCost = 0;
  Value *Trunk = &I;
  unsigned NumTrunkElts = TrunkType->getNumElements();
  Value *Y = nullptr;

  for (;;) {
    // Match the current trunk against (commutations of) the pattern
    // "shuffle trunk', (shuffle y, undef)"
    ArrayRef<int> OuterMask;
    Value *OuterV0, *OuterV1;
    // Every link except the chain root must be single-use, otherwise the
    // original shuffles cannot be removed.
    if (ChainLength != 0 && !Trunk->hasOneUse())
      break;
    if (!match(V: Trunk, P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1),
                              mask: m_Mask(OuterMask))))
      break;
    if (OuterV0->getType() != TrunkType) {
      // This shuffle is not length-preserving, so it cannot be part of the
      // chain.
      break;
    }

    ArrayRef<int> InnerMask0, InnerMask1;
    Value *A0, *A1, *B0, *B1;
    bool Match0 =
        match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: A0), v2: m_Value(V&: B0), mask: m_Mask(InnerMask0)));
    bool Match1 =
        match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: A1), v2: m_Value(V&: B1), mask: m_Mask(InnerMask1)));
    bool Match0Leaf = Match0 && A0->getType() != I.getType();
    bool Match1Leaf = Match1 && A1->getType() != I.getType();
    if (Match0Leaf == Match1Leaf) {
      // Only handle the case of exactly one leaf in each step. The "two leaves"
      // case is handled by foldShuffleOfShuffles.
      break;
    }

    // Canonicalize so that the leaf (length-changing) shuffle is OuterV1 and
    // the chain continues through OuterV0, swapping the operand halves of the
    // outer mask to compensate.
    SmallVector<int> CommutedOuterMask;
    if (Match0Leaf) {
      std::swap(a&: OuterV0, b&: OuterV1);
      std::swap(a&: InnerMask0, b&: InnerMask1);
      std::swap(a&: A0, b&: A1);
      std::swap(a&: B0, b&: B1);
      llvm::append_range(C&: CommutedOuterMask, R&: OuterMask);
      for (int &M : CommutedOuterMask) {
        if (M == PoisonMaskElem)
          continue;
        if (M < (int)NumTrunkElts)
          M += NumTrunkElts;
        else
          M -= NumTrunkElts;
      }
      OuterMask = CommutedOuterMask;
    }
    // The leaf shuffle itself must also be removable.
    if (!OuterV1->hasOneUse())
      break;

    // Every leaf shuffle in the chain must read from one and the same source
    // value Y (either operand may be undef).
    if (!isa<UndefValue>(Val: A1)) {
      if (!Y)
        Y = A1;
      else if (Y != A1)
        break;
    }
    if (!isa<UndefValue>(Val: B1)) {
      if (!Y)
        Y = B1;
      else if (Y != B1)
        break;
    }

    // Normalize the leaf mask to single-source form: indices into the second
    // operand map to the same lane of Y, since both operands are Y or undef.
    auto *YType = cast<FixedVectorType>(Val: A1->getType());
    int NumLeafElts = YType->getNumElements();
    SmallVector<int> LocalYMask(InnerMask1);
    for (int &M : LocalYMask) {
      if (M >= NumLeafElts)
        M -= NumLeafElts;
    }

    InstructionCost LocalOldCost =
        TTI.getInstructionCost(U: cast<User>(Val: Trunk), CostKind) +
        TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);

    // Handle the initial (start of chain) case.
    if (!ChainLength) {
      Mask.assign(AR: OuterMask);
      YMask.assign(RHS: LocalYMask);
      OldCost = NewCost = LocalOldCost;
      Trunk = OuterV0;
      ChainLength++;
      continue;
    }

    // For the non-root case, first attempt to combine masks.
    SmallVector<int> NewYMask(YMask);
    bool Valid = true;
    for (auto [CombinedM, LeafM] : llvm::zip(t&: NewYMask, u&: LocalYMask)) {
      if (LeafM == -1 || CombinedM == LeafM)
        continue;
      if (CombinedM == -1) {
        CombinedM = LeafM;
      } else {
        // Conflicting demands on the same Y lane - cannot merge this link.
        Valid = false;
        break;
      }
    }
    if (!Valid)
      break;

    // Compose the accumulated trunk mask with this link's outer mask.
    SmallVector<int> NewMask;
    NewMask.reserve(N: NumTrunkElts);
    for (int M : Mask) {
      if (M < 0 || M >= static_cast<int>(NumTrunkElts))
        NewMask.push_back(Elt: M);
      else
        NewMask.push_back(Elt: OuterMask[M]);
    }

    // Break the chain if adding this new step complicates the shuffles such
    // that it would increase the new cost by more than the old cost of this
    // step.
    InstructionCost LocalNewCost =
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: TrunkType,
                           SrcTy: YType, Mask: NewYMask, CostKind) +
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: TrunkType,
                           SrcTy: TrunkType, Mask: NewMask, CostKind);

    if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
      break;

    LLVM_DEBUG({
      if (ChainLength == 1) {
        dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
               << I << '\n';
      }
      dbgs() << "  next chain link: " << *Trunk << '\n'
             << "  old cost: " << (OldCost + LocalOldCost)
             << " new cost: " << LocalNewCost << '\n';
    });

    Mask = NewMask;
    YMask = NewYMask;
    OldCost += LocalOldCost;
    NewCost = LocalNewCost;
    Trunk = OuterV0;
    ChainLength++;
  }
  // A chain of length one is just a plain shuffle-of-shuffle; leave it for
  // the other folds.
  if (ChainLength <= 1)
    return false;

  if (llvm::all_of(Range&: Mask, P: [&](int M) {
        return M < 0 || M >= static_cast<int>(NumTrunkElts);
      })) {
    // Produce a canonical simplified form if all elements are sourced from Y.
    for (int &M : Mask) {
      if (M >= static_cast<int>(NumTrunkElts))
        M = YMask[M - NumTrunkElts];
    }
    Value *Root =
        Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask);
    replaceValue(Old&: I, New&: *Root);
    return true;
  }

  Value *Leaf =
      Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask: YMask);
  Value *Root = Builder.CreateShuffleVector(V1: Trunk, V2: Leaf, Mask);
  replaceValue(Old&: I, New&: *Root);
  return true;
}
3250
3251/// Try to convert
3252/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> OldMask;
  if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
    return false;

  // Both shuffle operands must be calls to the same intrinsic.
  auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
  auto *II1 = dyn_cast<IntrinsicInst>(Val: V1);
  if (!II0 || !II1)
    return false;

  Intrinsic::ID IID = II0->getIntrinsicID();
  if (IID != II1->getIntrinsicID())
    return false;
  InstructionCost CostII0 =
      TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
  InstructionCost CostII1 =
      TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II1), CostKind);

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *II0Ty = dyn_cast<FixedVectorType>(Val: II0->getType());
  if (!ShuffleDstTy || !II0Ty)
    return false;

  if (!isTriviallyVectorizable(ID: IID))
    return false;

  // Scalar operands cannot be shuffled; they must be identical in both calls
  // so a single value can feed the merged intrinsic.
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI) &&
        II0->getArgOperand(i: I) != II1->getArgOperand(i: I))
      return false;

  InstructionCost OldCost =
      CostII0 + CostII1 +
      TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
                         SrcTy: II0Ty, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {II0, II1}, CxtI: &I);

  // New cost: one shuffle per *distinct* pair of vector operands, plus the
  // merged intrinsic itself. Repeated operand pairs share a single shuffle
  // (see ShuffleCache in the codegen phase below), so only count them once.
  SmallVector<Type *> NewArgsTy;
  InstructionCost NewCost = 0;
  SmallDenseSet<std::pair<Value *, Value *>> SeenOperandPairs;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
    } else {
      auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
      auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
                                         NumElts: ShuffleDstTy->getNumElements());
      NewArgsTy.push_back(Elt: ArgTy);
      std::pair<Value *, Value *> OperandPair =
          std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
      if (!SeenOperandPairs.insert(V: OperandPair).second) {
        // We've already computed the cost for this operand pair.
        continue;
      }
      NewCost += TTI.getShuffleCost(
          Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ArgTy, SrcTy: VecTy, Mask: OldMask,
          CostKind, Index: 0, SubTp: nullptr, Args: {II0->getArgOperand(i: I), II1->getArgOperand(i: I)});
    }
  }
  IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);

  NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);
  // If the original intrinsics have other users they stay alive, so their
  // cost still has to be paid.
  if (!II0->hasOneUse())
    NewCost += CostII0;
  if (II1 != II0 && !II1->hasOneUse())
    NewCost += CostII1;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  // Codegen: shuffle each pair of vector operands, reusing shuffles for
  // repeated pairs, then emit the merged intrinsic on the shuffled operands.
  SmallVector<Value *> NewArgs;
  SmallDenseMap<std::pair<Value *, Value *>, Value *> ShuffleCache;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      NewArgs.push_back(Elt: II0->getArgOperand(i: I));
    } else {
      std::pair<Value *, Value *> OperandPair =
          std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
      auto It = ShuffleCache.find(Val: OperandPair);
      if (It != ShuffleCache.end()) {
        // Reuse previously created shuffle for this operand pair.
        NewArgs.push_back(Elt: It->second);
        continue;
      }
      Value *Shuf = Builder.CreateShuffleVector(V1: II0->getArgOperand(i: I),
                                                V2: II1->getArgOperand(i: I), Mask: OldMask);
      ShuffleCache[OperandPair] = Shuf;
      NewArgs.push_back(Elt: Shuf);
      Worklist.pushValue(V: Shuf);
    }
  Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);

  // Intersect flags from the old intrinsics.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic)) {
    NewInst->copyIRFlags(V: II0);
    NewInst->andIRFlags(V: II1);
  }

  replaceValue(Old&: I, New&: *NewIntrinsic);
  return true;
}
3358
3359/// Try to convert
3360/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
  Value *V0;
  ArrayRef<int> Mask;
  if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Undef(), mask: m_Mask(Mask))))
    return false;

  auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
  if (!II0)
    return false;

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(Val: II0->getType());
  if (!ShuffleDstTy || !IntrinsicSrcTy)
    return false;

  // Validate it's a pure permute, mask should only reference the first vector
  unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
  if (any_of(Range&: Mask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
    return false;

  Intrinsic::ID IID = II0->getIntrinsicID();
  if (!isTriviallyVectorizable(ID: IID))
    return false;

  // Cost analysis
  InstructionCost IntrinsicCost =
      TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
  InstructionCost OldCost =
      IntrinsicCost +
      TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: ShuffleDstTy,
                         SrcTy: IntrinsicSrcTy, Mask, CostKind, Index: 0, SubTp: nullptr, Args: {V0}, CxtI: &I);

  // New cost: permute each vector argument individually, then run the
  // intrinsic at the (possibly different) destination width.
  SmallVector<Type *> NewArgsTy;
  InstructionCost NewCost = 0;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      // Scalar operands are passed through unshuffled.
      NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
    } else {
      auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
      auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
                                         NumElts: ShuffleDstTy->getNumElements());
      NewArgsTy.push_back(Elt: ArgTy);
      NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                    DstTy: ArgTy, SrcTy: VecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
                                    Args: {II0->getArgOperand(i: I)});
    }
  }
  IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
  NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);

  // If the intrinsic has multiple uses, we need to account for the cost of
  // keeping the original intrinsic around.
  if (!II0->hasOneUse())
    NewCost += IntrinsicCost;

  LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Transform
  SmallVector<Value *> NewArgs;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      NewArgs.push_back(Elt: II0->getArgOperand(i: I));
    } else {
      Value *Shuf = Builder.CreateShuffleVector(V: II0->getArgOperand(i: I), Mask);
      NewArgs.push_back(Elt: Shuf);
      Worklist.pushValue(V: Shuf);
    }
  }

  Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);

  // Copy the IR flags from the original intrinsic to the replacement.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic))
    NewInst->copyIRFlags(V: II0);

  replaceValue(Old&: I, New&: *NewIntrinsic);
  return true;
}
3442
// A (source value, lane index) pair identifying where a given vector lane
// ultimately originates when looking through shuffles.
using InstLane = std::pair<Value *, int>;
3444
3445static InstLane lookThroughShuffles(Value *V, int Lane) {
3446 while (auto *SV = dyn_cast<ShuffleVectorInst>(Val: V)) {
3447 unsigned NumElts =
3448 cast<FixedVectorType>(Val: SV->getOperand(i_nocapture: 0)->getType())->getNumElements();
3449 int M = SV->getMaskValue(Elt: Lane);
3450 if (M < 0)
3451 return {nullptr, PoisonMaskElem};
3452 if (static_cast<unsigned>(M) < NumElts) {
3453 V = SV->getOperand(i_nocapture: 0);
3454 Lane = M;
3455 } else {
3456 V = SV->getOperand(i_nocapture: 1);
3457 Lane = M - NumElts;
3458 }
3459 }
3460 return InstLane{V, Lane};
3461}
3462
3463static SmallVector<InstLane>
3464generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
3465 SmallVector<InstLane> NItem;
3466 for (InstLane IL : Item) {
3467 auto [U, Lane] = IL;
3468 InstLane OpLane =
3469 U ? lookThroughShuffles(V: cast<Instruction>(Val: U)->getOperand(i: Op), Lane)
3470 : InstLane{nullptr, PoisonMaskElem};
3471 NItem.emplace_back(Args&: OpLane);
3472 }
3473 return NItem;
3474}
3475
3476/// Detect concat of multiple values into a vector
3477static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
3478 const TargetTransformInfo &TTI) {
3479 auto *Ty = cast<FixedVectorType>(Val: Item.front().first->getType());
3480 unsigned NumElts = Ty->getNumElements();
3481 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
3482 return false;
3483
3484 // Check that the concat is free, usually meaning that the type will be split
3485 // during legalization.
3486 SmallVector<int, 16> ConcatMask(NumElts * 2);
3487 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
3488 if (TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc,
3489 DstTy: FixedVectorType::get(ElementType: Ty->getScalarType(), NumElts: NumElts * 2),
3490 SrcTy: Ty, Mask: ConcatMask, CostKind) != 0)
3491 return false;
3492
3493 unsigned NumSlices = Item.size() / NumElts;
3494 // Currently we generate a tree of shuffles for the concats, which limits us
3495 // to a power2.
3496 if (!isPowerOf2_32(Value: NumSlices))
3497 return false;
3498 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
3499 Value *SliceV = Item[Slice * NumElts].first;
3500 if (!SliceV || SliceV->getType() != Ty)
3501 return false;
3502 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
3503 auto [V, Lane] = Item[Slice * NumElts + Elt];
3504 if (Lane != static_cast<int>(Elt) || SliceV != V)
3505 return false;
3506 }
3507 }
3508 return true;
3509}
3510
/// Recreate the tree of instructions described by \p Item with all vector
/// operands widened/narrowed to \p Ty. Leaves previously classified by
/// foldShuffleToIdentity are materialized directly (identity values, splat
/// shuffles or concat shuffle trees); interior nodes are rebuilt recursively.
static Value *
generateNewInstTree(ArrayRef<InstLane> Item, Use *From, FixedVectorType *Ty,
                    const DenseSet<std::pair<Value *, Use *>> &IdentityLeafs,
                    const DenseSet<std::pair<Value *, Use *>> &SplatLeafs,
                    const DenseSet<std::pair<Value *, Use *>> &ConcatLeafs,
                    IRBuilderBase &Builder, const TargetTransformInfo *TTI) {
  auto [FrontV, FrontLane] = Item.front();

  if (IdentityLeafs.contains(V: std::make_pair(x&: FrontV, y&: From))) {
    // Identity leaf - the value can be used directly, no shuffle needed.
    return FrontV;
  }
  if (SplatLeafs.contains(V: std::make_pair(x&: FrontV, y&: From))) {
    // Splat leaf - broadcast the front lane across the whole vector.
    SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
    return Builder.CreateShuffleVector(V: FrontV, Mask);
  }
  if (ConcatLeafs.contains(V: std::make_pair(x&: FrontV, y&: From))) {
    // Concat leaf - build the concatenation as a tree of two-source
    // shuffles, doubling the vector width at each level.
    unsigned NumElts =
        cast<FixedVectorType>(Val: FrontV->getType())->getNumElements();
    SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
    for (unsigned S = 0; S < Values.size(); ++S)
      Values[S] = Item[S * NumElts].first;

    while (Values.size() > 1) {
      NumElts *= 2;
      SmallVector<int, 16> Mask(NumElts, 0);
      std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
      SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
      for (unsigned S = 0; S < NewValues.size(); ++S)
        NewValues[S] =
            Builder.CreateShuffleVector(V1: Values[S * 2], V2: Values[S * 2 + 1], Mask);
      Values = NewValues;
    }
    return Values[0];
  }

  // Interior node: rebuild this instruction with recursively regenerated
  // operands. For intrinsic calls the last operand is the callee, so it is
  // excluded from the data-operand count.
  auto *I = cast<Instruction>(Val: FrontV);
  auto *II = dyn_cast<IntrinsicInst>(Val: I);
  unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
  SmallVector<Value *> Ops(NumOps);
  for (unsigned Idx = 0; Idx < NumOps; Idx++) {
    if (II &&
        isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI)) {
      // Scalar intrinsic operands are reused as-is rather than widened.
      Ops[Idx] = II->getOperand(i_nocapture: Idx);
      continue;
    }
    Ops[Idx] = generateNewInstTree(Item: generateInstLaneVectorFromOperand(Item, Op: Idx),
                                   From: &I->getOperandUse(i: Idx), Ty, IdentityLeafs,
                                   SplatLeafs, ConcatLeafs, Builder, TTI);
  }

  // Collect the per-lane source instructions so IR flags can be intersected
  // onto the new instruction via propagateIRFlags.
  SmallVector<Value *, 8> ValueList;
  for (const auto &Lane : Item)
    if (Lane.first)
      ValueList.push_back(Elt: Lane.first);

  Type *DstTy =
      FixedVectorType::get(ElementType: I->getType()->getScalarType(), NumElts: Ty->getNumElements());
  if (auto *BI = dyn_cast<BinaryOperator>(Val: I)) {
    auto *Value = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)BI->getOpcode(),
                                      LHS: Ops[0], RHS: Ops[1]);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (auto *CI = dyn_cast<CmpInst>(Val: I)) {
    auto *Value = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: Ops[0], RHS: Ops[1]);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
    auto *Value = Builder.CreateSelect(C: Ops[0], True: Ops[1], False: Ops[2], Name: "", MDFrom: SI);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (auto *CI = dyn_cast<CastInst>(Val: I)) {
    auto *Value = Builder.CreateCast(Op: CI->getOpcode(), V: Ops[0], DestTy: DstTy);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (II) {
    auto *Value = Builder.CreateIntrinsic(RetTy: DstTy, ID: II->getIntrinsicID(), Args: Ops);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
  auto *Value =
      Builder.CreateUnOp(Opc: (Instruction::UnaryOps)I->getOpcode(), V: Ops[0]);
  propagateIRFlags(I: Value, VL: ValueList);
  return Value;
}
3600
3601// Starting from a shuffle, look up through operands tracking the shuffled index
3602// of each lane. If we can simplify away the shuffles to identities then
3603// do so.
bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
  auto *Ty = dyn_cast<FixedVectorType>(Val: I.getType());
  if (!Ty || I.use_empty())
    return false;

  // Resolve, for every output lane of I, the ultimate (value, lane) source.
  SmallVector<InstLane> Start(Ty->getNumElements());
  for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
    Start[M] = lookThroughShuffles(V: &I, Lane: M);

  // Each worklist entry is a per-lane source vector plus the Use it feeds;
  // entries are classified as leaves or expanded into their operands.
  SmallVector<std::pair<SmallVector<InstLane>, Use *>> Worklist;
  Worklist.push_back(Elt: std::make_pair(x&: Start, y: &*I.use_begin()));
  DenseSet<std::pair<Value *, Use *>> IdentityLeafs, SplatLeafs, ConcatLeafs;
  unsigned NumVisited = 0;

  while (!Worklist.empty()) {
    if (++NumVisited > MaxInstrsToScan)
      return false;

    auto ItemFrom = Worklist.pop_back_val();
    auto Item = ItemFrom.first;
    auto From = ItemFrom.second;
    auto [FrontV, FrontLane] = Item.front();

    // If we found an undef first lane then bail out to keep things simple.
    if (!FrontV)
      return false;

    // Helper to peek through bitcasts to the same value.
    auto IsEquiv = [&](Value *X, Value *Y) {
      return X->getType() == Y->getType() &&
             peekThroughBitcasts(V: X) == peekThroughBitcasts(V: Y);
    };

    // Look for an identity value.
    if (FrontLane == 0 &&
        cast<FixedVectorType>(Val: FrontV->getType())->getNumElements() ==
            Ty->getNumElements() &&
        all_of(Range: drop_begin(RangeOrContainer: enumerate(First&: Item)), P: [IsEquiv, Item](const auto &E) {
          Value *FrontV = Item.front().first;
          return !E.value().first || (IsEquiv(E.value().first, FrontV) &&
                                      E.value().second == (int)E.index());
        })) {
      IdentityLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
      continue;
    }
    // Look for constants, for the moment only supporting constant splats.
    if (auto *C = dyn_cast<Constant>(Val: FrontV);
        C && C->getSplatValue() &&
        all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
          Value *FrontV = Item.front().first;
          Value *V = IL.first;
          return !V || (isa<Constant>(Val: V) &&
                        cast<Constant>(Val: V)->getSplatValue() ==
                            cast<Constant>(Val: FrontV)->getSplatValue());
        })) {
      SplatLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
      continue;
    }
    // Look for a splat value.
    if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
          auto [FrontV, FrontLane] = Item.front();
          auto [V, Lane] = IL;
          return !V || (V == FrontV && Lane == FrontLane);
        })) {
      SplatLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
      continue;
    }

    // We need each element to be the same type of value, and check that each
    // element has a single use.
    auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
      Value *FrontV = Item.front().first;
      if (!IL.first)
        return true;
      Value *V = IL.first;
      if (auto *I = dyn_cast<Instruction>(Val: V); I && !I->hasOneUser())
        return false;
      if (V->getValueID() != FrontV->getValueID())
        return false;
      if (auto *CI = dyn_cast<CmpInst>(Val: V))
        if (CI->getPredicate() != cast<CmpInst>(Val: FrontV)->getPredicate())
          return false;
      if (auto *CI = dyn_cast<CastInst>(Val: V))
        if (CI->getSrcTy()->getScalarType() !=
            cast<CastInst>(Val: FrontV)->getSrcTy()->getScalarType())
          return false;
      if (auto *SI = dyn_cast<SelectInst>(Val: V))
        if (!isa<VectorType>(Val: SI->getOperand(i_nocapture: 0)->getType()) ||
            SI->getOperand(i_nocapture: 0)->getType() !=
                cast<SelectInst>(Val: FrontV)->getOperand(i_nocapture: 0)->getType())
          return false;
      if (isa<CallInst>(Val: V) && !isa<IntrinsicInst>(Val: V))
        return false;
      auto *II = dyn_cast<IntrinsicInst>(Val: V);
      return !II || (isa<IntrinsicInst>(Val: FrontV) &&
                     II->getIntrinsicID() ==
                         cast<IntrinsicInst>(Val: FrontV)->getIntrinsicID() &&
                     !II->hasOperandBundles());
    };
    if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: CheckLaneIsEquivalentToFirst)) {
      // Check the operator is one that we support.
      if (isa<BinaryOperator, CmpInst>(Val: FrontV)) {
        // We exclude div/rem in case they hit UB from poison lanes.
        if (auto *BO = dyn_cast<BinaryOperator>(Val: FrontV);
            BO && BO->isIntDivRem())
          return false;
        Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
                              Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: 0));
        Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 1),
                              Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: 1));
        continue;
      } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
                     FPToUIInst, SIToFPInst, UIToFPInst>(Val: FrontV)) {
        Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
                              Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: 0));
        continue;
      } else if (auto *BitCast = dyn_cast<BitCastInst>(Val: FrontV)) {
        // TODO: Handle vector widening/narrowing bitcasts.
        auto *DstTy = dyn_cast<FixedVectorType>(Val: BitCast->getDestTy());
        auto *SrcTy = dyn_cast<FixedVectorType>(Val: BitCast->getSrcTy());
        if (DstTy && SrcTy &&
            SrcTy->getNumElements() == DstTy->getNumElements()) {
          Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
                                Args: &BitCast->getOperandUse(i: 0));
          continue;
        }
      } else if (auto *Sel = dyn_cast<SelectInst>(Val: FrontV)) {
        Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
                              Args: &Sel->getOperandUse(i: 0));
        Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 1),
                              Args: &Sel->getOperandUse(i: 1));
        Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 2),
                              Args: &Sel->getOperandUse(i: 2));
        continue;
      } else if (auto *II = dyn_cast<IntrinsicInst>(Val: FrontV);
                 II && isTriviallyVectorizable(ID: II->getIntrinsicID()) &&
                 !II->hasOperandBundles()) {
        for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
          // Scalar operands cannot be recombined per-lane; they must be
          // identical across all lanes for the fold to be valid.
          if (isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Op,
                                                 TTI: &TTI)) {
            if (!all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item, Op](InstLane &IL) {
                  Value *FrontV = Item.front().first;
                  Value *V = IL.first;
                  return !V || (cast<Instruction>(Val: V)->getOperand(i: Op) ==
                                cast<Instruction>(Val: FrontV)->getOperand(i: Op));
                }))
              return false;
            continue;
          }
          Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op),
                                Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: Op));
        }
        continue;
      }
    }

    if (isFreeConcat(Item, CostKind, TTI)) {
      ConcatLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
      continue;
    }

    return false;
  }

  // If only the root item was visited there were no shuffles to look through,
  // so there is nothing to simplify.
  if (NumVisited <= 1)
    return false;

  LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");

  // If we got this far, we know the shuffles are superfluous and can be
  // removed. Scan through again and generate the new tree of instructions.
  Builder.SetInsertPoint(&I);
  Value *V = generateNewInstTree(Item: Start, From: &*I.use_begin(), Ty, IdentityLeafs,
                                 SplatLeafs, ConcatLeafs, Builder, TTI: &TTI);
  replaceValue(Old&: I, New&: *V);
  return true;
}
3781
3782/// Given a commutative reduction, the order of the input lanes does not alter
3783/// the results. We can use this to remove certain shuffles feeding the
3784/// reduction, removing the need to shuffle at all.
3785bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
3786 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
3787 if (!II)
3788 return false;
3789 switch (II->getIntrinsicID()) {
3790 case Intrinsic::vector_reduce_add:
3791 case Intrinsic::vector_reduce_mul:
3792 case Intrinsic::vector_reduce_and:
3793 case Intrinsic::vector_reduce_or:
3794 case Intrinsic::vector_reduce_xor:
3795 case Intrinsic::vector_reduce_smin:
3796 case Intrinsic::vector_reduce_smax:
3797 case Intrinsic::vector_reduce_umin:
3798 case Intrinsic::vector_reduce_umax:
3799 break;
3800 default:
3801 return false;
3802 }
3803
3804 // Find all the inputs when looking through operations that do not alter the
3805 // lane order (binops, for example). Currently we look for a single shuffle,
3806 // and can ignore splat values.
3807 std::queue<Value *> Worklist;
3808 SmallPtrSet<Value *, 4> Visited;
3809 ShuffleVectorInst *Shuffle = nullptr;
3810 if (auto *Op = dyn_cast<Instruction>(Val: I.getOperand(i: 0)))
3811 Worklist.push(x: Op);
3812
3813 while (!Worklist.empty()) {
3814 Value *CV = Worklist.front();
3815 Worklist.pop();
3816 if (Visited.contains(Ptr: CV))
3817 continue;
3818
3819 // Splats don't change the order, so can be safely ignored.
3820 if (isSplatValue(V: CV))
3821 continue;
3822
3823 Visited.insert(Ptr: CV);
3824
3825 if (auto *CI = dyn_cast<Instruction>(Val: CV)) {
3826 if (CI->isBinaryOp()) {
3827 for (auto *Op : CI->operand_values())
3828 Worklist.push(x: Op);
3829 continue;
3830 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(Val: CI)) {
3831 if (Shuffle && Shuffle != SV)
3832 return false;
3833 Shuffle = SV;
3834 continue;
3835 }
3836 }
3837
3838 // Anything else is currently an unknown node.
3839 return false;
3840 }
3841
3842 if (!Shuffle)
3843 return false;
3844
3845 // Check all uses of the binary ops and shuffles are also included in the
3846 // lane-invariant operations (Visited should be the list of lanewise
3847 // instructions, including the shuffle that we found).
3848 for (auto *V : Visited)
3849 for (auto *U : V->users())
3850 if (!Visited.contains(Ptr: U) && U != &I)
3851 return false;
3852
3853 FixedVectorType *VecType =
3854 dyn_cast<FixedVectorType>(Val: II->getOperand(i_nocapture: 0)->getType());
3855 if (!VecType)
3856 return false;
3857 FixedVectorType *ShuffleInputType =
3858 dyn_cast<FixedVectorType>(Val: Shuffle->getOperand(i_nocapture: 0)->getType());
3859 if (!ShuffleInputType)
3860 return false;
3861 unsigned NumInputElts = ShuffleInputType->getNumElements();
3862
3863 // Find the mask from sorting the lanes into order. This is most likely to
  // become an identity or concat mask. Undef elements are pushed to the end.
3865 SmallVector<int> ConcatMask;
3866 Shuffle->getShuffleMask(Result&: ConcatMask);
3867 sort(C&: ConcatMask, Comp: [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3868 bool UsesSecondVec =
3869 any_of(Range&: ConcatMask, P: [&](int M) { return M >= (int)NumInputElts; });
3870
3871 InstructionCost OldCost = TTI.getShuffleCost(
3872 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3873 SrcTy: ShuffleInputType, Mask: Shuffle->getShuffleMask(), CostKind);
3874 InstructionCost NewCost = TTI.getShuffleCost(
3875 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3876 SrcTy: ShuffleInputType, Mask: ConcatMask, CostKind);
3877
3878 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3879 << "\n");
3880 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3881 << "\n");
3882 bool MadeChanges = false;
3883 if (NewCost < OldCost) {
3884 Builder.SetInsertPoint(Shuffle);
3885 Value *NewShuffle = Builder.CreateShuffleVector(
3886 V1: Shuffle->getOperand(i_nocapture: 0), V2: Shuffle->getOperand(i_nocapture: 1), Mask: ConcatMask);
3887 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3888 replaceValue(Old&: *Shuffle, New&: *NewShuffle);
3889 return true;
3890 }
3891
3892 // See if we can re-use foldSelectShuffle, getting it to reduce the size of
3893 // the shuffle into a nicer order, as it can ignore the order of the shuffles.
3894 MadeChanges |= foldSelectShuffle(I&: *Shuffle, FromReduction: true);
3895 return MadeChanges;
3896}
3897
3898/// For a given chain of patterns of the following form:
3899///
3900/// ```
3901/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
3902///
3903/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
3904/// ty1> %1)
3905/// OR
3906/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
3907///
3908/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3909/// ...
3910/// ...
3911/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
/// 3), <n x ty1> %(i - 2))
3913/// OR
3914/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3915///
3916/// %(i) = extractelement <n x ty1> %(i - 1), 0
3917/// ```
3918///
3919/// Where:
3920/// `mask` follows a partition pattern:
3921///
3922/// Ex:
3923/// [n = 8, p = poison]
3924///
3925/// 4 5 6 7 | p p p p
3926/// 2 3 | p p p p p p
3927/// 1 | p p p p p p p
3928///
3929/// For powers of 2, there's a consistent pattern, but for other cases
3930/// the parity of the current half value at each step decides the
3931/// next partition half (see `ExpectedParityMask` for more logical details
3932/// in generalising this).
3933///
3934/// Ex:
3935/// [n = 6]
3936///
3937/// 3 4 5 | p p p
3938/// 1 2 | p p p p
3939/// 1 | p p p p p
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
  // Going bottom-up for the pattern.
  std::queue<Value *> InstWorklist;
  InstructionCost OrigCost = 0;

  // Common instruction operation after each shuffle op.
  std::optional<unsigned int> CommonCallOp = std::nullopt;
  std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;

  bool IsFirstCallOrBinInst = true;
  bool ShouldBeCallOrBinInst = true;

  // This stores the last used instructions for shuffle/common op.
  //
  // PrevVecV[0] / PrevVecV[1] store the last two simultaneous
  // instructions from either shuffle/common op.
  SmallVector<Value *, 2> PrevVecV(2, nullptr);

  // The chain must be rooted at an extract of lane 0 of the final
  // partially-reduced vector.
  Value *VecOpEE;
  if (!match(V: &I, P: m_ExtractElt(Val: m_Value(V&: VecOpEE), Idx: m_Zero())))
    return false;

  auto *FVT = dyn_cast<FixedVectorType>(Val: VecOpEE->getType());
  if (!FVT)
    return false;

  int64_t VecSize = FVT->getNumElements();
  if (VecSize < 2)
    return false;

  // Number of levels would be ~log2(n), considering we always partition
  // by half for this fold pattern.
  unsigned int NumLevels = Log2_64_Ceil(Value: VecSize), VisitedCnt = 0;
  int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;

  // This is how we generalise for all element sizes.
  // At each step, if vector size is odd, we need non-poison
  // values to cover the dominant half so we don't miss out on any element.
  //
  // This mask will help us retrieve this as we go from bottom to top:
  //
  // Mask Set -> N = N * 2 - 1
  // Mask Unset -> N = N * 2
  for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
       Cur = (Cur + 1) / 2, --Mask) {
    if (Cur & 1)
      ExpectedParityMask |= (1ll << Mask);
  }

  InstWorklist.push(x: VecOpEE);

  // Walk the chain upwards, strictly alternating between a call/bin-op level
  // and a shuffle level (tracked by ShouldBeCallOrBinInst).
  while (!InstWorklist.empty()) {
    Value *CI = InstWorklist.front();
    InstWorklist.pop();

    if (auto *II = dyn_cast<IntrinsicInst>(Val: CI)) {
      if (!ShouldBeCallOrBinInst)
        return false;

      if (!IsFirstCallOrBinInst && any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
        return false;

      // For the first found call/bin op, the vector has to come from the
      // extract element op.
      if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
        return false;
      IsFirstCallOrBinInst = false;

      // Every call op level of the chain must use the same intrinsic.
      if (!CommonCallOp)
        CommonCallOp = II->getIntrinsicID();
      if (II->getIntrinsicID() != *CommonCallOp)
        return false;

      switch (II->getIntrinsicID()) {
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::smin:
      case Intrinsic::smax: {
        auto *Op0 = II->getOperand(i_nocapture: 0);
        auto *Op1 = II->getOperand(i_nocapture: 1);
        PrevVecV[0] = Op0;
        PrevVecV[1] = Op1;
        break;
      }
      default:
        return false;
      }
      ShouldBeCallOrBinInst ^= 1;

      // Accumulate the cost of the instruction this fold would remove.
      IntrinsicCostAttributes ICA(
          *CommonCallOp, II->getType(),
          {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
      OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);

      // We may need a swap here since it can be (a, b) or (b, a)
      // and accordingly change as we go up.
      if (!isa<ShuffleVectorInst>(Val: PrevVecV[1]))
        std::swap(a&: PrevVecV[0], b&: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[0]);
    } else if (auto *BinOp = dyn_cast<BinaryOperator>(Val: CI)) {
      // Similar logic for bin ops.

      if (!ShouldBeCallOrBinInst)
        return false;

      if (!IsFirstCallOrBinInst && any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
        return false;

      if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
        return false;
      IsFirstCallOrBinInst = false;

      // Every bin op level of the chain must use the same opcode.
      if (!CommonBinOp)
        CommonBinOp = BinOp->getOpcode();

      if (BinOp->getOpcode() != *CommonBinOp)
        return false;

      switch (*CommonBinOp) {
      case BinaryOperator::Add:
      case BinaryOperator::Mul:
      case BinaryOperator::Or:
      case BinaryOperator::And:
      case BinaryOperator::Xor: {
        auto *Op0 = BinOp->getOperand(i_nocapture: 0);
        auto *Op1 = BinOp->getOperand(i_nocapture: 1);
        PrevVecV[0] = Op0;
        PrevVecV[1] = Op1;
        break;
      }
      default:
        return false;
      }
      ShouldBeCallOrBinInst ^= 1;

      OrigCost +=
          TTI.getArithmeticInstrCost(Opcode: *CommonBinOp, Ty: BinOp->getType(), CostKind);

      // Canonicalize so PrevVecV[1] is the shuffle operand (if any).
      if (!isa<ShuffleVectorInst>(Val: PrevVecV[1]))
        std::swap(a&: PrevVecV[0], b&: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[0]);
    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(Val: CI)) {
      // We shouldn't have any null values in the previous vectors;
      // if so, there was a mismatch in pattern.
      if (ShouldBeCallOrBinInst || any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
        return false;

      if (SVInst != PrevVecV[1])
        return false;

      // The shuffle must read from the other operand of the preceding
      // call/bin op and have a poison second vector.
      ArrayRef<int> CurMask;
      if (!match(V: SVInst, P: m_Shuffle(v1: m_Specific(V: PrevVecV[0]), v2: m_Poison(),
                                    mask: m_Mask(CurMask))))
        return false;

      // Subtract the parity mask when checking the condition.
      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
        if (Mask < ShuffleMaskHalf &&
            CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
          return false;
        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
          return false;
      }

      // Update mask values.
      ShuffleMaskHalf *= 2;
      ShuffleMaskHalf -= (ExpectedParityMask & 1);
      ExpectedParityMask >>= 1;

      OrigCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                     DstTy: SVInst->getType(), SrcTy: SVInst->getType(),
                                     Mask: CurMask, CostKind);

      // Done once every level has been visited and the parity mask is
      // fully consumed.
      VisitedCnt += 1;
      if (!ExpectedParityMask && VisitedCnt == NumLevels)
        break;

      ShouldBeCallOrBinInst ^= 1;
    } else {
      return false;
    }
  }

  // Pattern should end with a shuffle op.
  if (ShouldBeCallOrBinInst)
    return false;

  assert(VecSize != -1 && "Expected Match for Vector Size");

  // The vector feeding the topmost shuffle is the reduction input.
  Value *FinalVecV = PrevVecV[0];
  if (!FinalVecV)
    return false;

  auto *FinalVecVTy = cast<FixedVectorType>(Val: FinalVecV->getType());

  // Map the common op onto the equivalent vector reduction intrinsic.
  Intrinsic::ID ReducedOp =
      (CommonCallOp ? getMinMaxReductionIntrinsicID(IID: *CommonCallOp)
                    : getReductionForBinop(Opc: *CommonBinOp));
  if (!ReducedOp)
    return false;

  IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
  InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

  // Only fold when the single reduction intrinsic is strictly cheaper than
  // the whole shuffle/op chain it replaces.
  if (NewCost >= OrigCost)
    return false;

  auto *ReducedResult =
      Builder.CreateIntrinsic(ID: ReducedOp, Types: {FinalVecV->getType()}, Args: {FinalVecV});
  replaceValue(Old&: I, New&: *ReducedResult);

  return true;
}
4155
4156/// Determine if its more efficient to fold:
4157/// reduce(trunc(x)) -> trunc(reduce(x)).
4158/// reduce(sext(x)) -> sext(reduce(x)).
4159/// reduce(zext(x)) -> zext(reduce(x)).
4160bool VectorCombine::foldCastFromReductions(Instruction &I) {
4161 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
4162 if (!II)
4163 return false;
4164
4165 bool TruncOnly = false;
4166 Intrinsic::ID IID = II->getIntrinsicID();
4167 switch (IID) {
4168 case Intrinsic::vector_reduce_add:
4169 case Intrinsic::vector_reduce_mul:
4170 TruncOnly = true;
4171 break;
4172 case Intrinsic::vector_reduce_and:
4173 case Intrinsic::vector_reduce_or:
4174 case Intrinsic::vector_reduce_xor:
4175 break;
4176 default:
4177 return false;
4178 }
4179
4180 unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
4181 Value *ReductionSrc = I.getOperand(i: 0);
4182
4183 Value *Src;
4184 if (!match(V: ReductionSrc, P: m_OneUse(SubPattern: m_Trunc(Op: m_Value(V&: Src)))) &&
4185 (TruncOnly || !match(V: ReductionSrc, P: m_OneUse(SubPattern: m_ZExtOrSExt(Op: m_Value(V&: Src))))))
4186 return false;
4187
4188 auto CastOpc =
4189 (Instruction::CastOps)cast<Instruction>(Val: ReductionSrc)->getOpcode();
4190
4191 auto *SrcTy = cast<VectorType>(Val: Src->getType());
4192 auto *ReductionSrcTy = cast<VectorType>(Val: ReductionSrc->getType());
4193 Type *ResultTy = I.getType();
4194
4195 InstructionCost OldCost = TTI.getArithmeticReductionCost(
4196 Opcode: ReductionOpc, Ty: ReductionSrcTy, FMF: std::nullopt, CostKind);
4197 OldCost += TTI.getCastInstrCost(Opcode: CastOpc, Dst: ReductionSrcTy, Src: SrcTy,
4198 CCH: TTI::CastContextHint::None, CostKind,
4199 I: cast<CastInst>(Val: ReductionSrc));
4200 InstructionCost NewCost =
4201 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: SrcTy, FMF: std::nullopt,
4202 CostKind) +
4203 TTI.getCastInstrCost(Opcode: CastOpc, Dst: ResultTy, Src: ReductionSrcTy->getScalarType(),
4204 CCH: TTI::CastContextHint::None, CostKind);
4205
4206 if (OldCost <= NewCost || !NewCost.isValid())
4207 return false;
4208
4209 Value *NewReduction = Builder.CreateIntrinsic(RetTy: SrcTy->getScalarType(),
4210 ID: II->getIntrinsicID(), Args: {Src});
4211 Value *NewCast = Builder.CreateCast(Op: CastOpc, V: NewReduction, DestTy: ResultTy);
4212 replaceValue(Old&: I, New&: *NewCast);
4213 return true;
4214}
4215
4216/// Fold:
4217/// icmp pred (reduce.{add,or,and,umax,umin}(signbit_extract(x))), C
4218/// into:
4219/// icmp sgt/slt (reduce.{or,umax,and,umin}(x)), -1/0
4220///
4221/// Sign-bit reductions produce values with known semantics:
4222/// - reduce.{or,umax}: 0 if no element is negative, 1 if any is
4223/// - reduce.{and,umin}: 1 if all elements are negative, 0 if any isn't
4224/// - reduce.add: count of negative elements (0 to NumElts)
4225///
4226/// Both lshr and ashr are supported:
4227/// - lshr produces 0 or 1, so reduce.add range is [0, N]
4228/// - ashr produces 0 or -1, so reduce.add range is [-N, 0]
4229///
4230/// The fold generalizes to multiple source vectors combined with the same
4231/// operation as the reduction. For example:
4232/// reduce.or(or(shr A, shr B)) conceptually extends the vector
4233/// For reduce.add, this changes the count to M*N where M is the number of
4234/// source vectors.
4235///
4236/// We transform to a direct sign check on the original vector using
4237/// reduce.{or,umax} or reduce.{and,umin}.
4238///
4239/// In spirit, it's similar to foldSignBitCheck in InstCombine.
bool VectorCombine::foldSignBitReductionCmp(Instruction &I) {
  CmpPredicate Pred;
  IntrinsicInst *ReduceOp;
  const APInt *CmpVal;
  if (!match(V: &I,
             P: m_ICmp(Pred, L: m_OneUse(SubPattern: m_AnyIntrinsic(I&: ReduceOp)), R: m_APInt(Res&: CmpVal))))
    return false;

  // Only the reductions with the sign-bit semantics documented above.
  Intrinsic::ID OrigIID = ReduceOp->getIntrinsicID();
  switch (OrigIID) {
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_umin:
  case Intrinsic::vector_reduce_add:
    break;
  default:
    return false;
  }

  Value *ReductionSrc = ReduceOp->getArgOperand(i: 0);
  auto *VecTy = dyn_cast<FixedVectorType>(Val: ReductionSrc->getType());
  if (!VecTy)
    return false;

  // For i1 elements "shr X, BitWidth-1" degenerates to a shift by zero,
  // so there is no meaningful sign-bit extraction to match.
  unsigned BitWidth = VecTy->getScalarSizeInBits();
  if (BitWidth == 1)
    return false;

  unsigned NumElts = VecTy->getNumElements();

  // Determine the expected tree opcode for multi-vector patterns.
  // The tree opcode must match the reduction's underlying operation.
  //
  // TODO: for pairs of equivalent operators, we should match both,
  // not only the most common.
  Instruction::BinaryOps TreeOpcode;
  switch (OrigIID) {
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umax:
    TreeOpcode = Instruction::Or;
    break;
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_umin:
    TreeOpcode = Instruction::And;
    break;
  case Intrinsic::vector_reduce_add:
    TreeOpcode = Instruction::Add;
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  // Collect sign-bit extraction leaves from an associative tree of TreeOpcode.
  // The tree conceptually extends the vector being reduced.
  SmallVector<Value *, 8> Worklist;
  SmallVector<Value *, 8> Sources; // Original vectors (X in shr X, BW-1)
  Worklist.push_back(Elt: ReductionSrc);
  std::optional<bool> IsAShr;
  constexpr unsigned MaxSources = 8;

  // Calculate old cost: all shifts + tree ops + reduction
  InstructionCost OldCost = TTI.getInstructionCost(U: ReduceOp, CostKind);

  // DFS over the tree; the size guards bound the work on adversarial input
  // (over-large trees are rejected by the post-loop check below).
  while (!Worklist.empty() && Worklist.size() <= MaxSources &&
         Sources.size() <= MaxSources) {
    Value *V = Worklist.pop_back_val();

    // Try to match sign-bit extraction: shr X, (bitwidth-1)
    Value *X;
    if (match(V, P: m_OneUse(SubPattern: m_Shr(L: m_Value(V&: X), R: m_SpecificInt(V: BitWidth - 1))))) {
      auto *Shr = cast<Instruction>(Val: V);

      // All shifts must be the same type (all lshr or all ashr)
      bool ThisIsAShr = Shr->getOpcode() == Instruction::AShr;
      if (!IsAShr)
        IsAShr = ThisIsAShr;
      else if (*IsAShr != ThisIsAShr)
        return false;

      Sources.push_back(Elt: X);

      // As part of the fold, we remove all of the shifts, so we need to keep
      // track of their costs.
      OldCost += TTI.getInstructionCost(U: Shr, CostKind);

      continue;
    }

    // Try to extend through a tree node of the expected opcode
    Value *A, *B;
    if (!match(V, P: m_OneUse(SubPattern: m_BinOp(Opcode: TreeOpcode, L: m_Value(V&: A), R: m_Value(V&: B)))))
      return false;

    // We are potentially replacing these operations as well, so we add them
    // to the costs.
    OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: V), CostKind);

    Worklist.push_back(Elt: A);
    Worklist.push_back(Elt: B);
  }

  // Must have at least one source and not exceed limit
  if (Sources.empty() || Sources.size() > MaxSources ||
      Worklist.size() > MaxSources || !IsAShr)
    return false;

  unsigned NumSources = Sources.size();

  // For reduce.add, the total count must fit as a signed integer.
  // Range is [0, M*N] for lshr or [-M*N, 0] for ashr.
  if (OrigIID == Intrinsic::vector_reduce_add &&
      !isIntN(N: BitWidth, x: NumSources * NumElts))
    return false;

  // Compute the boundary value when all elements are negative:
  // - Per-element contribution: 1 for lshr, -1 for ashr
  // - For add: M*N (total elements across all sources); for others: just 1
  unsigned Count =
      (OrigIID == Intrinsic::vector_reduce_add) ? NumSources * NumElts : 1;
  APInt NegativeVal(CmpVal->getBitWidth(), Count);
  if (*IsAShr)
    NegativeVal.negate();

  // Range is [min(0, AllNegVal), max(0, AllNegVal)]
  APInt Zero = APInt::getZero(numBits: CmpVal->getBitWidth());
  APInt RangeLow = APIntOps::smin(A: Zero, B: NegativeVal);
  APInt RangeHigh = APIntOps::smax(A: Zero, B: NegativeVal);

  // Determine comparison semantics:
  // - IsEq: true for equality test, false for inequality
  // - TestsNegative: true if testing against AllNegVal, false for zero
  //
  // In addition to EQ/NE against 0 or AllNegVal, we support inequalities
  // that fold to boundary tests given the narrow value range:
  //   < RangeHigh    -> != RangeHigh
  //   > RangeHigh-1  -> == RangeHigh
  //   > RangeLow     -> != RangeLow
  //   < RangeLow+1   -> == RangeLow
  //
  // For inequalities, we work with signed predicates only. Unsigned predicates
  // are canonicalized to signed when the range is non-negative (where they are
  // equivalent). When the range includes negative values, unsigned predicates
  // would have different semantics due to wrap-around, so we reject them.
  if (!ICmpInst::isEquality(P: Pred) && !ICmpInst::isSigned(Pred)) {
    if (RangeLow.isNegative())
      return false;
    Pred = ICmpInst::getSignedPredicate(Pred);
  }

  bool IsEq;
  bool TestsNegative;
  if (ICmpInst::isEquality(P: Pred)) {
    if (CmpVal->isZero()) {
      TestsNegative = false;
    } else if (*CmpVal == NegativeVal) {
      TestsNegative = true;
    } else {
      return false;
    }
    IsEq = Pred == ICmpInst::ICMP_EQ;
  } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeHigh) {
    IsEq = false;
    TestsNegative = (RangeHigh == NegativeVal);
  } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeHigh - 1) {
    IsEq = true;
    TestsNegative = (RangeHigh == NegativeVal);
  } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeLow) {
    IsEq = false;
    TestsNegative = (RangeLow == NegativeVal);
  } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeLow + 1) {
    IsEq = true;
    TestsNegative = (RangeLow == NegativeVal);
  } else {
    return false;
  }

  // For this fold we support four types of checks:
  //
  // 1. All lanes are negative         - AllNeg
  // 2. All lanes are non-negative     - AllNonNeg
  // 3. At least one negative lane     - AnyNeg
  // 4. At least one non-negative lane - AnyNonNeg
  //
  // For each case, we can generate the following code:
  //
  // 1. AllNeg    - reduce.and/umin(X) < 0
  // 2. AllNonNeg - reduce.or/umax(X) > -1
  // 3. AnyNeg    - reduce.or/umax(X) < 0
  // 4. AnyNonNeg - reduce.and/umin(X) > -1
  //
  // The table below shows the aggregation of all supported cases
  // using these four cases.
  //
  //  Reduction  |   == 0    |   != 0    |  == MAX   |  != MAX
  // ------------+-----------+-----------+-----------+-----------
  //  or/umax    | AllNonNeg |  AnyNeg   |  AnyNeg   | AllNonNeg
  //  and/umin   | AnyNonNeg |  AllNeg   |  AllNeg   | AnyNonNeg
  //  add        | AllNonNeg |  AnyNeg   |  AllNeg   | AnyNonNeg
  //
  // NOTE: MAX = 1 for or/and/umax/umin, and the vector size N for add
  //
  // For easier codegen and check inversion, we use the following encoding:
  //
  // 1. Bit-3 === requires or/umax (1) or and/umin (0) check
  // 2. Bit-2 === checks < 0 (1) or > -1 (0)
  // 3. Bit-1 === universal (1) or existential (0) check
  //
  // AnyNeg    = 0b110: uses or/umax, checks negative, any-check
  // AllNonNeg = 0b101: uses or/umax, checks non-neg, all-check
  // AnyNonNeg = 0b000: uses and/umin, checks non-neg, any-check
  // AllNeg    = 0b011: uses and/umin, checks negative, all-check
  //
  // XOR with 0b011 inverts the check (swaps all/any and neg/non-neg).
  //
  enum CheckKind : unsigned {
    AnyNonNeg = 0b000,
    AllNeg = 0b011,
    AllNonNeg = 0b101,
    AnyNeg = 0b110,
  };
  // Return true if we fold this check into or/umax and false for and/umin
  auto RequiresOr = [](CheckKind C) -> bool { return C & 0b100; };
  // Return true if we should check if result is negative and false otherwise
  auto IsNegativeCheck = [](CheckKind C) -> bool { return C & 0b010; };
  // Logically invert the check
  auto Invert = [](CheckKind C) { return CheckKind(C ^ 0b011); };

  CheckKind Base;
  switch (OrigIID) {
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umax:
    Base = TestsNegative ? AnyNeg : AllNonNeg;
    break;
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_umin:
    Base = TestsNegative ? AllNeg : AnyNonNeg;
    break;
  case Intrinsic::vector_reduce_add:
    Base = TestsNegative ? AllNeg : AllNonNeg;
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  CheckKind Check = IsEq ? Base : Invert(Base);

  // Pick whichever equivalent reduction (arithmetic vs min/max) is
  // cheaper on this target.
  auto PickCheaper = [&](Intrinsic::ID Arith, Intrinsic::ID MinMax) {
    InstructionCost ArithCost =
        TTI.getArithmeticReductionCost(Opcode: getArithmeticReductionInstruction(RdxID: Arith),
                                       Ty: VecTy, FMF: std::nullopt, CostKind);
    InstructionCost MinMaxCost =
        TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: MinMax), Ty: VecTy,
                                   FMF: FastMathFlags(), CostKind);
    return ArithCost <= MinMaxCost ? std::make_pair(x&: Arith, y&: ArithCost)
                                   : std::make_pair(x&: MinMax, y&: MinMaxCost);
  };

  // Choose output reduction based on encoding's MSB
  auto [NewIID, NewCost] = RequiresOr(Check)
                               ? PickCheaper(Intrinsic::vector_reduce_or,
                                             Intrinsic::vector_reduce_umax)
                               : PickCheaper(Intrinsic::vector_reduce_and,
                                             Intrinsic::vector_reduce_umin);

  // Add cost of combining multiple sources with or/and
  if (NumSources > 1) {
    unsigned CombineOpc =
        RequiresOr(Check) ? Instruction::Or : Instruction::And;
    NewCost += TTI.getArithmeticInstrCost(Opcode: CombineOpc, Ty: VecTy, CostKind) *
               (NumSources - 1);
  }

  LLVM_DEBUG(dbgs() << "Found sign-bit reduction cmp: " << I << "\n  OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Generate the combined input and reduction
  Builder.SetInsertPoint(&I);
  Type *ScalarTy = VecTy->getScalarType();

  Value *Input;
  if (NumSources == 1) {
    Input = Sources[0];
  } else {
    // Combine sources with or/and based on check type
    Input = RequiresOr(Check) ? Builder.CreateOr(Ops: Sources)
                              : Builder.CreateAnd(Ops: Sources);
  }

  Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: NewIID, Args: {Input});
  Value *NewCmp = IsNegativeCheck(Check) ? Builder.CreateIsNeg(Arg: NewReduce)
                                         : Builder.CreateIsNotNeg(Arg: NewReduce);
  replaceValue(Old&: I, New&: *NewCmp);
  return true;
}
4538
4539/// vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
4540///
4541/// We can prove it for cases when:
4542///
4543/// 1. OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
4544/// 1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
4545/// 2. f(x) == 0 <=> x == 0
4546///
4547/// From 1 and 2 (or 1' and 2), we can infer that
4548///
4549/// OP f(X_i) == 0 <=> OP X_i == 0.
4550///
4551/// (1)
4552/// OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0
4553/// (2)
4554/// <=> \forall i \in [1, N] X_i == 0
4555/// (1)
4556/// <=> OP(X_i) == 0
4557///
4558/// For some of the OP's and f's, we need to have domain constraints on X
4559/// to ensure properties 1 (or 1') and 2.
bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
  CmpPredicate Pred;
  Value *Op;
  // Only (icmp eq/ne X, 0) roots are handled.
  if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: Op), R: m_Zero())) ||
      !ICmpInst::isEquality(P: Pred))
    return false;

  auto *II = dyn_cast<IntrinsicInst>(Val: Op);
  if (!II)
    return false;

  // Reductions for which property 1 (or 1') can hold; domain constraints
  // for smin/smax/add are verified further below.
  switch (II->getIntrinsicID()) {
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umin:
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax:
    break;
  default:
    return false;
  }

  Value *InnerOp = II->getArgOperand(i: 0);

  // TODO: fixed vector type might be too restrictive
  if (!II->hasOneUse() || !isa<FixedVectorType>(Val: InnerOp->getType()))
    return false;

  Value *X = nullptr;

  // Check for zero-preserving operations where f(x) = 0 <=> x = 0
  //
  // 1. f(x) = shl nuw x, y for arbitrary y
  // 2. f(x) = mul nuw x, c for defined c != 0
  // 3. f(x) = zext x
  // 4. f(x) = sext x
  // 5. f(x) = neg x
  //
  if (!(match(V: InnerOp, P: m_NUWShl(L: m_Value(V&: X), R: m_Value())) ||      // Case 1
        match(V: InnerOp, P: m_NUWMul(L: m_Value(V&: X), R: m_NonZeroInt())) || // Case 2
        match(V: InnerOp, P: m_ZExt(Op: m_Value(V&: X))) ||                   // Case 3
        match(V: InnerOp, P: m_SExt(Op: m_Value(V&: X))) ||                   // Case 4
        match(V: InnerOp, P: m_Neg(V: m_Value(V&: X)))                       // Case 5
        ))
    return false;

  SimplifyQuery S = SQ.getWithInstruction(I: &I);
  auto *XTy = cast<FixedVectorType>(Val: X->getType());

  // Check for domain constraints for all supported reductions.
  //
  // a. OR X_i   - has property 1 for every X
  // b. UMAX X_i - has property 1 for every X
  // c. UMIN X_i - has property 1' for every X
  // d. SMAX X_i - has property 1 for X >= 0
  // e. SMIN X_i - has property 1' for X >= 0
  // f. ADD X_i  - has property 1 for X >= 0 && ADD X_i doesn't sign wrap
  //
  // In order for the proof to work, we need 1 (or 1') to be true for both
  // OP f(X_i) and OP X_i and that's why below we check constraints twice.
  //
  // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
  //       X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
  //       of known bits, we can't reasonably hold knowledge of "either 0
  //       or negative".
  switch (II->getIntrinsicID()) {
  case Intrinsic::vector_reduce_add: {
    // We need to check that both X_i and f(X_i) have enough leading
    // zeros to not overflow.
    KnownBits KnownX = computeKnownBits(V: X, Q: S);
    KnownBits KnownFX = computeKnownBits(V: InnerOp, Q: S);
    unsigned NumElems = XTy->getNumElements();
    // Adding N elements loses at most ceil(log2(N)) leading bits.
    unsigned LostBits = Log2_32_Ceil(Value: NumElems);
    unsigned LeadingZerosX = KnownX.countMinLeadingZeros();
    unsigned LeadingZerosFX = KnownFX.countMinLeadingZeros();
    // Need at least one leading zero left after summation to ensure no overflow
    if (LeadingZerosX <= LostBits || LeadingZerosFX <= LostBits)
      return false;

    // We are not checking whether X or f(X) are positive explicitly because
    // we implicitly checked for it when we checked if both cases have enough
    // leading zeros to not wrap addition.
    break;
  }
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax:
    // Check whether X >= 0 and f(X) >= 0
    if (!isKnownNonNegative(V: InnerOp, SQ: S) || !isKnownNonNegative(V: X, SQ: S))
      return false;

    break;
  // or/umin/umax need no extra domain constraints.
  default:
    break;
  };

  LLVM_DEBUG(dbgs() << "Found a reduction to 0 comparison with removable op: "
                    << *II << "\n");

  // For zext/sext, check if the transform is profitable using cost model.
  // For other operations (shl, mul, neg), we're removing an instruction
  // while keeping the same reduction type, so it's always profitable.
  if (isa<ZExtInst>(Val: InnerOp) || isa<SExtInst>(Val: InnerOp)) {
    auto *FXTy = cast<FixedVectorType>(Val: InnerOp->getType());
    Intrinsic::ID IID = II->getIntrinsicID();

    InstructionCost ExtCost = TTI.getCastInstrCost(
        Opcode: cast<CastInst>(Val: InnerOp)->getOpcode(), Dst: FXTy, Src: XTy,
        CCH: TTI::CastContextHint::None, CostKind, I: cast<CastInst>(Val: InnerOp));

    // Compare reducing the wide (f(X)) type against reducing X's type.
    InstructionCost OldReduceCost, NewReduceCost;
    switch (IID) {
    case Intrinsic::vector_reduce_add:
    case Intrinsic::vector_reduce_or:
      OldReduceCost = TTI.getArithmeticReductionCost(
          Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: FXTy, FMF: std::nullopt, CostKind);
      NewReduceCost = TTI.getArithmeticReductionCost(
          Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: XTy, FMF: std::nullopt, CostKind);
      break;
    case Intrinsic::vector_reduce_umin:
    case Intrinsic::vector_reduce_umax:
    case Intrinsic::vector_reduce_smin:
    case Intrinsic::vector_reduce_smax:
      OldReduceCost = TTI.getMinMaxReductionCost(
          IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: FXTy, FMF: FastMathFlags(), CostKind);
      NewReduceCost = TTI.getMinMaxReductionCost(
          IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: XTy, FMF: FastMathFlags(), CostKind);
      break;
    default:
      llvm_unreachable("Unexpected reduction");
    }

    // The extension only disappears if the reduction was its sole user.
    InstructionCost OldCost = OldReduceCost + ExtCost;
    InstructionCost NewCost =
        NewReduceCost + (InnerOp->hasOneUse() ? 0 : ExtCost);

    LLVM_DEBUG(dbgs() << "Found a removable extension before reduction: "
                      << *InnerOp << "\n  OldCost: " << OldCost
                      << " vs NewCost: " << NewCost << "\n");

    // We consider transformation to still be potentially beneficial even
    // when the costs are the same because we might remove a use from f(X)
    // and unlock other optimizations. Equal costs would just mean that we
    // didn't make it worse in the worst case.
    if (NewCost > OldCost)
      return false;
  }

  // Since we support zext and sext as f, we might change the scalar type
  // of the intrinsic.
  Type *Ty = XTy->getScalarType();
  Value *NewReduce = Builder.CreateIntrinsic(RetTy: Ty, ID: II->getIntrinsicID(), Args: {X});
  Value *NewCmp =
      Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::getNullValue(Ty));
  replaceValue(Old&: I, New&: *NewCmp);
  return true;
}
4718
/// Fold comparisons of reduce.or/reduce.and with reduce.umax/reduce.umin
/// based on cost, preserving the comparison semantics.
///
/// We use three fundamental properties for each pair:
///
/// For or/umax:
/// 1. or(X) == 0 <=> umax(X) == 0
/// 2. or(X) == 1 <=> umax(X) == 1
/// 3. sign(or(X)) == sign(umax(X))
///
/// For and/umin:
/// 1. and(X) == -1 <=> umin(X) == -1
/// 2. and(X) == -2 <=> umin(X) == -2
/// 3. sign(and(X)) == sign(umin(X))
///
/// From these we can infer the following transformations:
/// a. or(X) ==/!= 0 <-> umax(X) ==/!= 0
/// b. or(X) s< 0 <-> umax(X) s< 0
/// c. or(X) s> -1 <-> umax(X) s> -1
/// d. or(X) s< 1 <-> umax(X) s< 1
/// e. or(X) ==/!= 1 <-> umax(X) ==/!= 1
/// f. or(X) s< 2 <-> umax(X) s< 2
/// g. and(X) ==/!= -1 <-> umin(X) ==/!= -1
/// h. and(X) s< 0 <-> umin(X) s< 0
/// i. and(X) s> -1 <-> umin(X) s> -1
/// j. and(X) s> -2 <-> umin(X) s> -2
/// k. and(X) ==/!= -2 <-> umin(X) ==/!= -2
/// l. and(X) s> -3 <-> umin(X) s> -3
///
bool VectorCombine::foldEquivalentReductionCmp(Instruction &I) {
  // Match `icmp pred (reduction), C` where C is a constant integer splat.
  CmpPredicate Pred;
  Value *ReduceOp;
  const APInt *CmpVal;
  if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: ReduceOp), R: m_APInt(Res&: CmpVal))))
    return false;

  // The reduction must be an intrinsic with the icmp as its only user, so we
  // can replace it without duplicating the reduction.
  auto *II = dyn_cast<IntrinsicInst>(Val: ReduceOp);
  if (!II || !II->hasOneUse())
    return false;

  // Returns true if (Pred, CmpVal) is one of the comparison forms where
  // or(X) and umax(X) are interchangeable (cases a-f in the header comment).
  const auto IsValidOrUmaxCmp = [&]() {
    // or === umax for i1
    if (CmpVal->getBitWidth() == 1)
      return true;

    // Cases a and e
    bool IsEquality =
        (CmpVal->isZero() || CmpVal->isOne()) && ICmpInst::isEquality(P: Pred);
    // Case c
    bool IsPositive = CmpVal->isAllOnes() && Pred == ICmpInst::ICMP_SGT;
    // Cases b, d, and f
    bool IsNegative = (CmpVal->isZero() || CmpVal->isOne() || *CmpVal == 2) &&
                      Pred == ICmpInst::ICMP_SLT;
    return IsEquality || IsPositive || IsNegative;
  };

  // Returns true if (Pred, CmpVal) is one of the comparison forms where
  // and(X) and umin(X) are interchangeable (cases g-l in the header comment).
  const auto IsValidAndUminCmp = [&]() {
    // and === umin for i1
    if (CmpVal->getBitWidth() == 1)
      return true;

    const auto LeadingOnes = CmpVal->countl_one();

    // Cases g and k: CmpVal is -1, or -2 (all-ones except the lowest bit).
    bool IsEquality =
        (CmpVal->isAllOnes() || LeadingOnes + 1 == CmpVal->getBitWidth()) &&
        ICmpInst::isEquality(P: Pred);
    // Case h
    bool IsNegative = CmpVal->isZero() && Pred == ICmpInst::ICMP_SLT;
    // Cases i, j, and l
    bool IsPositive =
        // if the number has at least N - 2 leading ones
        // and the two LSBs are:
        // - 1 x 1 -> -1
        // - 1 x 0 -> -2
        // - 0 x 1 -> -3
        LeadingOnes + 2 >= CmpVal->getBitWidth() &&
        ((*CmpVal)[0] || (*CmpVal)[1]) && Pred == ICmpInst::ICMP_SGT;
    return IsEquality || IsNegative || IsPositive;
  };

  Intrinsic::ID OriginalIID = II->getIntrinsicID();
  Intrinsic::ID AlternativeIID;

  // Check if this is a valid comparison pattern and determine the alternate
  // reduction intrinsic. The fold works in both directions (or <-> umax,
  // and <-> umin); the cost model below picks the cheaper form.
  switch (OriginalIID) {
  case Intrinsic::vector_reduce_or:
    if (!IsValidOrUmaxCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_umax;
    break;
  case Intrinsic::vector_reduce_umax:
    if (!IsValidOrUmaxCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_or;
    break;
  case Intrinsic::vector_reduce_and:
    if (!IsValidAndUminCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_umin;
    break;
  case Intrinsic::vector_reduce_umin:
    if (!IsValidAndUminCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_and;
    break;
  default:
    return false;
  }

  // Only fixed-width vectors: the cost queries below need a concrete type.
  Value *X = II->getArgOperand(i: 0);
  auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
  if (!VecTy)
    return false;

  // Query the cost of a reduction intrinsic via the appropriate TTI hook:
  // min/max reductions are modeled separately from arithmetic ones.
  const auto GetReductionCost = [&](Intrinsic::ID IID) -> InstructionCost {
    unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
    if (ReductionOpc != Instruction::ICmp)
      return TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecTy, FMF: std::nullopt,
                                            CostKind);
    return TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: VecTy,
                                      FMF: FastMathFlags(), CostKind);
  };

  InstructionCost OrigCost = GetReductionCost(OriginalIID);
  InstructionCost AltCost = GetReductionCost(AlternativeIID);

  LLVM_DEBUG(dbgs() << "Found equivalent reduction cmp: " << I
                    << "\n  OrigCost: " << OrigCost
                    << " vs AltCost: " << AltCost << "\n");

  // Only rewrite if the alternative is strictly cheaper, to avoid flip-flop
  // between the two equivalent forms on repeated pass runs.
  if (AltCost >= OrigCost)
    return false;

  // Build the alternative reduction and re-emit the compare against the same
  // constant; the predicate and constant are unchanged by construction.
  Builder.SetInsertPoint(&I);
  Type *ScalarTy = VecTy->getScalarType();
  Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: AlternativeIID, Args: {X});
  Value *NewCmp =
      Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::get(Ty: ScalarTy, V: *CmpVal));

  replaceValue(Old&: I, New&: *NewCmp);
  return true;
}
4861
4862/// Returns true if this ShuffleVectorInst eventually feeds into a
4863/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
4864/// chains of shuffles and binary operators (in any combination/order).
4865/// The search does not go deeper than the given Depth.
4866static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI) {
4867 constexpr unsigned MaxVisited = 32;
4868 SmallPtrSet<Instruction *, 8> Visited;
4869 SmallVector<Instruction *, 4> WorkList;
4870 bool FoundReduction = false;
4871
4872 WorkList.push_back(Elt: SVI);
4873 while (!WorkList.empty()) {
4874 Instruction *I = WorkList.pop_back_val();
4875 for (User *U : I->users()) {
4876 auto *UI = cast<Instruction>(Val: U);
4877 if (!UI || !Visited.insert(Ptr: UI).second)
4878 continue;
4879 if (Visited.size() > MaxVisited)
4880 return false;
4881 if (auto *II = dyn_cast<IntrinsicInst>(Val: UI)) {
4882 // More than one reduction reached
4883 if (FoundReduction)
4884 return false;
4885 switch (II->getIntrinsicID()) {
4886 case Intrinsic::vector_reduce_add:
4887 case Intrinsic::vector_reduce_mul:
4888 case Intrinsic::vector_reduce_and:
4889 case Intrinsic::vector_reduce_or:
4890 case Intrinsic::vector_reduce_xor:
4891 case Intrinsic::vector_reduce_smin:
4892 case Intrinsic::vector_reduce_smax:
4893 case Intrinsic::vector_reduce_umin:
4894 case Intrinsic::vector_reduce_umax:
4895 FoundReduction = true;
4896 continue;
4897 default:
4898 return false;
4899 }
4900 }
4901
4902 if (!isa<BinaryOperator>(Val: UI) && !isa<ShuffleVectorInst>(Val: UI))
4903 return false;
4904
4905 WorkList.emplace_back(Args&: UI);
4906 }
4907 }
4908 return FoundReduction;
4909}
4910
/// This method looks for groups of shuffles acting on binops, of the form:
///  %x = shuffle ...
///  %y = shuffle ...
///  %a = binop %x, %y
///  %b = binop %x, %y
///  shuffle %a, %b, selectmask
/// We may, especially if the shuffle is wider than legal, be able to convert
/// the shuffle to a form where only parts of a and b need to be computed. On
/// architectures with no obvious "select" shuffle, this can reduce the total
/// number of operations if the target reports them as cheaper.
bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
  // Match the root pattern: a shuffle whose two operands are distinct binops
  // of the same fixed vector type.
  auto *SVI = cast<ShuffleVectorInst>(Val: &I);
  auto *VT = cast<FixedVectorType>(Val: I.getType());
  auto *Op0 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 0));
  auto *Op1 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 1));
  if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
      VT != Op0->getType())
    return false;

  // The (up to four) inputs to the two binops, expected to be shuffles.
  auto *SVI0A = dyn_cast<Instruction>(Val: Op0->getOperand(i: 0));
  auto *SVI0B = dyn_cast<Instruction>(Val: Op0->getOperand(i: 1));
  auto *SVI1A = dyn_cast<Instruction>(Val: Op1->getOperand(i: 0));
  auto *SVI1B = dyn_cast<Instruction>(Val: Op1->getOperand(i: 1));
  SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
  // Reject an input shuffle if it has uses outside the binops (other than
  // other input shuffles or trivially dead instructions), since rewriting it
  // would then change other users too.
  auto checkSVNonOpUses = [&](Instruction *I) {
    if (!I || I->getOperand(i: 0)->getType() != VT)
      return true;
    return any_of(Range: I->users(), P: [&](User *U) {
      return U != Op0 && U != Op1 &&
             !(isa<ShuffleVectorInst>(Val: U) &&
               (InputShuffles.contains(Ptr: cast<Instruction>(Val: U)) ||
                isInstructionTriviallyDead(I: cast<Instruction>(Val: U))));
    });
  };
  if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
      checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
    return false;

  // Collect all the uses that are shuffles that we can transform together. We
  // may not have a single shuffle, but a group that can all be transformed
  // together profitably.
  SmallVector<ShuffleVectorInst *> Shuffles;
  auto collectShuffles = [&](Instruction *I) {
    for (auto *U : I->users()) {
      auto *SV = dyn_cast<ShuffleVectorInst>(Val: U);
      if (!SV || SV->getType() != VT)
        return false;
      if ((SV->getOperand(i_nocapture: 0) != Op0 && SV->getOperand(i_nocapture: 0) != Op1) ||
          (SV->getOperand(i_nocapture: 1) != Op0 && SV->getOperand(i_nocapture: 1) != Op1))
        return false;
      if (!llvm::is_contained(Range&: Shuffles, Element: SV))
        Shuffles.push_back(Elt: SV);
    }
    return true;
  };
  if (!collectShuffles(Op0) || !collectShuffles(Op1))
    return false;
  // From a reduction, we need to be processing a single shuffle, otherwise the
  // other uses will not be lane-invariant.
  if (FromReduction && Shuffles.size() > 1)
    return false;

  // Add any shuffle uses for the shuffles we have found, to include them in our
  // cost calculations.
  if (!FromReduction) {
    for (ShuffleVectorInst *SV : Shuffles) {
      for (auto *U : SV->users()) {
        ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(Val: U);
        if (SSV && isa<UndefValue>(Val: SSV->getOperand(i_nocapture: 1)) && SSV->getType() == VT)
          Shuffles.push_back(Elt: SSV);
      }
    }
  }

  // For each of the output shuffles, we try to sort all the first vector
  // elements to the beginning, followed by the second array elements at the
  // end. If the binops are legalized to smaller vectors, this may reduce total
  // number of binops. We compute the ReconstructMask mask needed to convert
  // back to the original lane order.
  // V1/V2 map "original element index" -> "packed element index";
  // MaxV1Elt/MaxV2Elt track the highest original index actually used.
  SmallVector<std::pair<int, int>> V1, V2;
  SmallVector<SmallVector<int>> OrigReconstructMasks;
  int MaxV1Elt = 0, MaxV2Elt = 0;
  unsigned NumElts = VT->getNumElements();
  for (ShuffleVectorInst *SVN : Shuffles) {
    SmallVector<int> Mask;
    SVN->getShuffleMask(Result&: Mask);

    // Check the operands are the same as the original, or reversed (in which
    // case we need to commute the mask).
    Value *SVOp0 = SVN->getOperand(i_nocapture: 0);
    Value *SVOp1 = SVN->getOperand(i_nocapture: 1);
    if (isa<UndefValue>(Val: SVOp1)) {
      // A single-source shuffle of one of our shuffles: fold its mask through
      // so we can treat it as if it operated on Op0/Op1 directly.
      auto *SSV = cast<ShuffleVectorInst>(Val: SVOp0);
      SVOp0 = SSV->getOperand(i_nocapture: 0);
      SVOp1 = SSV->getOperand(i_nocapture: 1);
      for (int &Elem : Mask) {
        if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
          return false;
        Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elt: Elem);
      }
    }
    if (SVOp0 == Op1 && SVOp1 == Op0) {
      std::swap(a&: SVOp0, b&: SVOp1);
      ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumElts);
    }
    if (SVOp0 != Op0 || SVOp1 != Op1)
      return false;

    // Calculate the reconstruction mask for this shuffle, as the mask needed to
    // take the packed values from Op0/Op1 and reconstructing to the original
    // order.
    SmallVector<int> ReconstructMask;
    for (unsigned I = 0; I < Mask.size(); I++) {
      if (Mask[I] < 0) {
        ReconstructMask.push_back(Elt: -1);
      } else if (Mask[I] < static_cast<int>(NumElts)) {
        MaxV1Elt = std::max(a: MaxV1Elt, b: Mask[I]);
        // Reuse an existing packed slot for this element, or allocate the next
        // one.
        auto It = find_if(Range&: V1, P: [&](const std::pair<int, int> &A) {
          return Mask[I] == A.first;
        });
        if (It != V1.end())
          ReconstructMask.push_back(Elt: It - V1.begin());
        else {
          ReconstructMask.push_back(Elt: V1.size());
          V1.emplace_back(Args&: Mask[I], Args: V1.size());
        }
      } else {
        MaxV2Elt = std::max<int>(a: MaxV2Elt, b: Mask[I] - NumElts);
        auto It = find_if(Range&: V2, P: [&](const std::pair<int, int> &A) {
          return Mask[I] - static_cast<int>(NumElts) == A.first;
        });
        if (It != V2.end())
          ReconstructMask.push_back(Elt: NumElts + It - V2.begin());
        else {
          ReconstructMask.push_back(Elt: NumElts + V2.size());
          V2.emplace_back(Args: Mask[I] - NumElts, Args: NumElts + V2.size());
        }
      }
    }

    // For reductions, we know that the lane ordering out doesn't alter the
    // result. In-order can help simplify the shuffle away.
    if (FromReduction)
      sort(C&: ReconstructMask);
    OrigReconstructMasks.push_back(Elt: std::move(ReconstructMask));
  }

  // If the Maximum element used from V1 and V2 are not larger than the new
  // vectors, the vectors are already packed and performing the optimization
  // again will likely not help any further. This also prevents us from getting
  // stuck in a cycle in case the costs do not also rule it out.
  if (V1.empty() || V2.empty() ||
      (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
       MaxV2Elt == static_cast<int>(V2.size()) - 1))
    return false;

  // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
  // shuffle of another shuffle, or not a shuffle (that is treated like an
  // identity shuffle).
  auto GetBaseMaskValue = [&](Instruction *I, int M) {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return M;
    if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
      if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
        if (InputShuffles.contains(Ptr: SSV))
          return SSV->getMaskValue(Elt: SV->getMaskValue(Elt: M));
    return SV->getMaskValue(Elt: M);
  };

  // Attempt to sort the inputs by ascending mask values to make simpler input
  // shuffles and push complex shuffles down to the uses. We sort on the first
  // of the two input shuffle orders, to try and get at least one input into a
  // nice order.
  auto SortBase = [&](Instruction *A, std::pair<int, int> X,
                      std::pair<int, int> Y) {
    int MXA = GetBaseMaskValue(A, X.first);
    int MYA = GetBaseMaskValue(A, Y.first);
    return MXA < MYA;
  };
  stable_sort(Range&: V1, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
    return SortBase(SVI0A, A, B);
  });
  stable_sort(Range&: V2, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
    return SortBase(SVI1A, A, B);
  });
  // Calculate our ReconstructMasks from the OrigReconstructMasks and the
  // modified order of the input shuffles.
  SmallVector<SmallVector<int>> ReconstructMasks;
  for (const auto &Mask : OrigReconstructMasks) {
    SmallVector<int> ReconstructMask;
    for (int M : Mask) {
      auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
        auto It = find_if(Range: V, P: [M](auto A) { return A.second == M; });
        assert(It != V.end() && "Expected all entries in Mask");
        return std::distance(first: V.begin(), last: It);
      };
      if (M < 0)
        ReconstructMask.push_back(Elt: -1);
      else if (M < static_cast<int>(NumElts)) {
        ReconstructMask.push_back(Elt: FindIndex(V1, M));
      } else {
        ReconstructMask.push_back(Elt: NumElts + FindIndex(V2, M));
      }
    }
    ReconstructMasks.push_back(Elt: std::move(ReconstructMask));
  }

  // Calculate the masks needed for the new input shuffles, which get padded
  // with undef
  SmallVector<int> V1A, V1B, V2A, V2B;
  for (unsigned I = 0; I < V1.size(); I++) {
    V1A.push_back(Elt: GetBaseMaskValue(SVI0A, V1[I].first));
    V1B.push_back(Elt: GetBaseMaskValue(SVI0B, V1[I].first));
  }
  for (unsigned I = 0; I < V2.size(); I++) {
    V2A.push_back(Elt: GetBaseMaskValue(SVI1A, V2[I].first));
    V2B.push_back(Elt: GetBaseMaskValue(SVI1B, V2[I].first));
  }
  while (V1A.size() < NumElts) {
    V1A.push_back(Elt: PoisonMaskElem);
    V1B.push_back(Elt: PoisonMaskElem);
  }
  while (V2A.size() < NumElts) {
    V2A.push_back(Elt: PoisonMaskElem);
    V2B.push_back(Elt: PoisonMaskElem);
  }

  // Adds the cost of I to C if I is a shuffle (non-shuffles contribute 0).
  auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return C;
    return C + TTI.getShuffleCost(Kind: isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1))
                                      ? TTI::SK_PermuteSingleSrc
                                      : TTI::SK_PermuteTwoSrc,
                                  DstTy: VT, SrcTy: VT, Mask: SV->getShuffleMask(), CostKind);
  };
  auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
    return C +
           TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
  };

  unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
  unsigned MaxVectorSize =
      TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector);
  unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
  if (MaxElementsInVector == 0)
    return false;
  // When there are multiple shufflevector operations on the same input,
  // especially when the vector length is larger than the register size,
  // identical shuffle patterns may occur across different groups of elements.
  // To avoid overestimating the cost by counting these repeated shuffles more
  // than once, we only account for unique shuffle patterns. This adjustment
  // prevents inflated costs in the cost model for wide vectors split into
  // several register-sized groups.
  std::set<SmallVector<int, 4>> UniqueShuffles;
  auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
    // Compute the cost for performing the shuffle over the full vector.
    auto ShuffleCost =
        TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
    unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
    if (NumFullVectors < 2)
      return C + ShuffleCost;
    SmallVector<int, 4> SubShuffle(MaxElementsInVector);
    unsigned NumUniqueGroups = 0;
    unsigned NumGroups = Mask.size() / MaxElementsInVector;
    // For each group of MaxElementsInVector contiguous elements,
    // collect their shuffle pattern and insert into the set of unique patterns.
    for (unsigned I = 0; I < NumFullVectors; ++I) {
      for (unsigned J = 0; J < MaxElementsInVector; ++J)
        SubShuffle[J] = Mask[MaxElementsInVector * I + J];
      if (UniqueShuffles.insert(x: SubShuffle).second)
        NumUniqueGroups += 1;
    }
    // Scale the full-vector cost by the fraction of unique groups.
    return C + ShuffleCost * NumUniqueGroups / NumGroups;
  };
  auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return C;
    SmallVector<int, 16> Mask;
    SV->getShuffleMask(Result&: Mask);
    return AddShuffleMaskAdjustedCost(C, Mask);
  };
  // Check that input consists of ShuffleVectors applied to the same input
  auto AllShufflesHaveSameOperands =
      [](SmallPtrSetImpl<Instruction *> &InputShuffles) {
        if (InputShuffles.size() < 2)
          return false;
        ShuffleVectorInst *FirstSV =
            dyn_cast<ShuffleVectorInst>(Val: *InputShuffles.begin());
        if (!FirstSV)
          return false;

        Value *In0 = FirstSV->getOperand(i_nocapture: 0), *In1 = FirstSV->getOperand(i_nocapture: 1);
        return std::all_of(
            first: std::next(x: InputShuffles.begin()), last: InputShuffles.end(),
            pred: [&](Instruction *I) {
              ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(Val: I);
              return SV && SV->getOperand(i_nocapture: 0) == In0 && SV->getOperand(i_nocapture: 1) == In1;
            });
      };

  // Get the costs of the shuffles + binops before and after with the new
  // shuffle masks.
  InstructionCost CostBefore =
      TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: VT, CostKind) +
      TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: VT, CostKind);
  CostBefore += std::accumulate(first: Shuffles.begin(), last: Shuffles.end(),
                                init: InstructionCost(0), binary_op: AddShuffleCost);
  if (AllShufflesHaveSameOperands(InputShuffles)) {
    UniqueShuffles.clear();
    CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
                                  init: InstructionCost(0), binary_op: AddShuffleAdjustedCost);
  } else {
    CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
                                  init: InstructionCost(0), binary_op: AddShuffleCost);
  }

  // The new binops will be unused for lanes past the used shuffle lengths.
  // These types attempt to get the correct cost for that from the target.
  FixedVectorType *Op0SmallVT =
      FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V1.size());
  FixedVectorType *Op1SmallVT =
      FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V2.size());
  InstructionCost CostAfter =
      TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: Op0SmallVT, CostKind) +
      TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: Op1SmallVT, CostKind);
  UniqueShuffles.clear();
  CostAfter += std::accumulate(first: ReconstructMasks.begin(), last: ReconstructMasks.end(),
                               init: InstructionCost(0), binary_op: AddShuffleMaskAdjustedCost);
  // Using a set de-duplicates identical input masks so they are costed once.
  std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
  CostAfter +=
      std::accumulate(first: OutputShuffleMasks.begin(), last: OutputShuffleMasks.end(),
                      init: InstructionCost(0), binary_op: AddShuffleMaskCost);

  LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
  LLVM_DEBUG(dbgs() << "  CostBefore: " << CostBefore
                    << " vs CostAfter: " << CostAfter << "\n");
  if (CostBefore < CostAfter ||
      (CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
    return false;

  // The cost model has passed, create the new instructions.
  // GetShuffleOperand looks through an input shuffle (or shuffle-of-shuffle)
  // to find the underlying operand to feed the new packed shuffles.
  auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return I;
    if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
      if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
        if (InputShuffles.contains(Ptr: SSV))
          return SSV->getOperand(i_nocapture: Op);
    return SV->getOperand(i_nocapture: Op);
  };
  Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
  Value *NSV0A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0A, 0),
                                             V2: GetShuffleOperand(SVI0A, 1), Mask: V1A);
  Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
  Value *NSV0B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0B, 0),
                                             V2: GetShuffleOperand(SVI0B, 1), Mask: V1B);
  Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
  Value *NSV1A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1A, 0),
                                             V2: GetShuffleOperand(SVI1A, 1), Mask: V2A);
  Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
  Value *NSV1B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1B, 0),
                                             V2: GetShuffleOperand(SVI1B, 1), Mask: V2B);
  Builder.SetInsertPoint(Op0);
  Value *NOp0 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op0->getOpcode(),
                                    LHS: NSV0A, RHS: NSV0B);
  if (auto *I = dyn_cast<Instruction>(Val: NOp0))
    I->copyIRFlags(V: Op0, IncludeWrapFlags: true);
  Builder.SetInsertPoint(Op1);
  Value *NOp1 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op1->getOpcode(),
                                    LHS: NSV1A, RHS: NSV1B);
  if (auto *I = dyn_cast<Instruction>(Val: NOp1))
    I->copyIRFlags(V: Op1, IncludeWrapFlags: true);

  // Re-create each original output shuffle from the packed binop results.
  for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
    Builder.SetInsertPoint(Shuffles[S]);
    Value *NSV = Builder.CreateShuffleVector(V1: NOp0, V2: NOp1, Mask: ReconstructMasks[S]);
    replaceValue(Old&: *Shuffles[S], New&: *NSV, Erase: false);
  }

  Worklist.pushValue(V: NSV0A);
  Worklist.pushValue(V: NSV0B);
  Worklist.pushValue(V: NSV1A);
  Worklist.pushValue(V: NSV1B);
  return true;
}
5300
/// Check if instruction depends on ZExt and this ZExt can be moved after the
/// instruction. Move ZExt if it is profitable. For example:
///     logic(zext(x),y) -> zext(logic(x,trunc(y)))
///     lshr((zext(x),y) -> zext(lshr(x,trunc(y)))
/// Cost model calculations takes into account if zext(x) has other users and
/// whether it can be propagated through them too.
bool VectorCombine::shrinkType(Instruction &I) {
  // Match a bitwise-logic op (either operand order) or an lshr whose LHS is a
  // zext.
  Value *ZExted, *OtherOperand;
  if (!match(V: &I, P: m_c_BitwiseLogic(L: m_ZExt(Op: m_Value(V&: ZExted)),
                                    R: m_Value(V&: OtherOperand))) &&
      !match(V: &I, P: m_LShr(L: m_ZExt(Op: m_Value(V&: ZExted)), R: m_Value(V&: OtherOperand))))
    return false;

  // Recover which operand of I is the zext itself (the commutative match may
  // have bound either side).
  Value *ZExtOperand = I.getOperand(i: I.getOperand(i: 0) == OtherOperand ? 1 : 0);

  auto *BigTy = cast<FixedVectorType>(Val: I.getType());
  auto *SmallTy = cast<FixedVectorType>(Val: ZExted->getType());
  // BW = element width of the narrow (pre-zext) type.
  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();

  if (I.getOpcode() == Instruction::LShr) {
    // Check that the shift amount is less than the number of bits in the
    // smaller type. Otherwise, the smaller lshr will return a poison value.
    KnownBits ShAmtKB = computeKnownBits(V: I.getOperand(i: 1), DL: *DL);
    if (ShAmtKB.getMaxValue().uge(RHS: BW))
      return false;
  } else {
    // Check that the expression overall uses at most the same number of bits as
    // ZExted
    KnownBits KB = computeKnownBits(V: &I, DL: *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;
  }

  // Calculate costs of leaving current IR as it is and moving ZExt operation
  // later, along with adding truncates if needed
  InstructionCost ZExtCost = TTI.getCastInstrCost(
      Opcode: Instruction::ZExt, Dst: BigTy, Src: SmallTy,
      CCH: TargetTransformInfo::CastContextHint::None, CostKind);
  InstructionCost CurrentCost = ZExtCost;
  InstructionCost ShrinkCost = 0;

  // Calculate total cost and check that we can propagate through all ZExt users
  for (User *U : ZExtOperand->users()) {
    auto *UI = cast<Instruction>(Val: U);
    if (UI == &I) {
      // I itself: wide binop now vs narrow binop + one zext afterwards.
      CurrentCost +=
          TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
      ShrinkCost +=
          TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
      ShrinkCost += ZExtCost;
      continue;
    }

    if (!Instruction::isBinaryOp(Opcode: UI->getOpcode()))
      return false;

    // Check if we can propagate ZExt through its other users
    KnownBits KB = computeKnownBits(V: UI, DL: *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;

    CurrentCost += TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
    ShrinkCost +=
        TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
    ShrinkCost += ZExtCost;
  }

  // If the other instruction operand is not a constant, we'll need to
  // generate a truncate instruction. So we have to adjust cost
  // (constants can simply be narrowed in place).
  if (!isa<Constant>(Val: OtherOperand))
    ShrinkCost += TTI.getCastInstrCost(
        Opcode: Instruction::Trunc, Dst: SmallTy, Src: BigTy,
        CCH: TargetTransformInfo::CastContextHint::None, CostKind);

  // If the cost of shrinking types and leaving the IR is the same, we'll lean
  // towards modifying the IR because shrinking opens opportunities for other
  // shrinking optimisations.
  if (ShrinkCost > CurrentCost)
    return false;

  // Build the narrow binop, then zext the result back to the original type.
  Builder.SetInsertPoint(&I);
  Value *Op0 = ZExted;
  Value *Op1 = Builder.CreateTrunc(V: OtherOperand, DestTy: SmallTy);
  // Keep the order of operands the same
  if (I.getOperand(i: 0) == OtherOperand)
    std::swap(a&: Op0, b&: Op1);
  Value *NewBinOp =
      Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(), LHS: Op0, RHS: Op1);
  cast<Instruction>(Val: NewBinOp)->copyIRFlags(V: &I);
  cast<Instruction>(Val: NewBinOp)->copyMetadata(SrcInst: I);
  Value *NewZExtr = Builder.CreateZExt(V: NewBinOp, DestTy: BigTy);
  replaceValue(Old&: I, New&: *NewZExtr);
  return true;
}
5395
/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
/// shuffle (DstVec, SrcVec, Mask)
bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
  Value *DstVec, *SrcVec;
  uint64_t ExtIdx, InsIdx;
  if (!match(V: &I,
             P: m_InsertElt(Val: m_Value(V&: DstVec),
                          Elt: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx)),
                          Idx: m_ConstantInt(V&: InsIdx))))
    return false;

  auto *DstVecTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
  // We can try combining vectors with different element *counts*, but the
  // element types must match exactly.
  if (!DstVecTy || !SrcVecTy ||
      SrcVecTy->getElementType() != DstVecTy->getElementType())
    return false;

  // Reject out-of-range indices and the degenerate single-element destination.
  unsigned NumDstElts = DstVecTy->getNumElements();
  unsigned NumSrcElts = SrcVecTy->getNumElements();
  if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
    return false;

  // Insertion into poison is a cheaper single operand shuffle.
  TargetTransformInfo::ShuffleKind SK;
  SmallVector<int> Mask(NumDstElts, PoisonMaskElem);

  bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
  bool NeedDstSrcSwap = isa<PoisonValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec);
  if (NeedDstSrcSwap) {
    // Destination is poison: a single-source shuffle of SrcVec suffices.
    SK = TargetTransformInfo::SK_PermuteSingleSrc;
    Mask[InsIdx] = ExtIdx % NumDstElts;
    std::swap(a&: DstVec, b&: SrcVec);
  } else {
    // Identity mask over DstVec, with the inserted lane taken from SrcVec
    // (indices >= NumDstElts select from the second shuffle operand).
    SK = TargetTransformInfo::SK_PermuteTwoSrc;
    std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
    Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
  }

  // Cost
  auto *Ins = cast<InsertElementInst>(Val: &I);
  auto *Ext = cast<ExtractElementInst>(Val: I.getOperand(i: 1));
  InstructionCost InsCost =
      TTI.getVectorInstrCost(I: *Ins, Val: DstVecTy, CostKind, Index: InsIdx);
  InstructionCost ExtCost =
      TTI.getVectorInstrCost(I: *Ext, Val: DstVecTy, CostKind, Index: ExtIdx);
  InstructionCost OldCost = ExtCost + InsCost;

  InstructionCost NewCost = 0;
  SmallVector<int> ExtToVecMask;
  if (!NeedExpOrNarrow) {
    // Ignore 'free' identity insertion shuffle.
    // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
    if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
      NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind, Index: 0,
                                    SubTp: nullptr, Args: {DstVec, SrcVec});
  } else {
    // When creating a length-changing-vector, always try to keep the relevant
    // element in an equivalent position, so that bulk shuffles are more likely
    // to be useful.
    ExtToVecMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
    ExtToVecMask[ExtIdx % NumDstElts] = ExtIdx;
    // Add cost for expanding or narrowing
    NewCost = TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: ExtToVecMask, CostKind);
    NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind);
  }

  // If the extract has other users it must stay, so its cost is still paid.
  if (!Ext->hasOneUse())
    NewCost += ExtCost;

  LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (OldCost < NewCost)
    return false;

  if (NeedExpOrNarrow) {
    // Resize whichever operand ended up as the "source" side to NumDstElts.
    if (!NeedDstSrcSwap)
      SrcVec = Builder.CreateShuffleVector(V: SrcVec, Mask: ExtToVecMask);
    else
      DstVec = Builder.CreateShuffleVector(V: DstVec, Mask: ExtToVecMask);
  }

  // Canonicalize undef param to RHS to help further folds.
  if (isa<UndefValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec)) {
    ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumDstElts);
    std::swap(a&: DstVec, b&: SrcVec);
  }

  Value *Shuf = Builder.CreateShuffleVector(V1: DstVec, V2: SrcVec, Mask);
  replaceValue(Old&: I, New&: *Shuf);

  return true;
}
5492
5493/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
5494/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
5495/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
5496/// before casting it back into `<vscale x 16 x i32>`.
5497bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
5498 const APInt *SplatVal0, *SplatVal1;
5499 if (!match(V: &I, P: m_Intrinsic<Intrinsic::vector_interleave2>(
5500 Op0: m_APInt(Res&: SplatVal0), Op1: m_APInt(Res&: SplatVal1))))
5501 return false;
5502
5503 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
5504 << "\n");
5505
5506 auto *VTy =
5507 cast<VectorType>(Val: cast<IntrinsicInst>(Val&: I).getArgOperand(i: 0)->getType());
5508 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
5509 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
5510
5511 // Just in case the cost of interleave2 intrinsic and bitcast are both
5512 // invalid, in which case we want to bail out, we use <= rather
5513 // than < here. Even they both have valid and equal costs, it's probably
5514 // not a good idea to emit a high-cost constant splat.
5515 if (TTI.getInstructionCost(U: &I, CostKind) <=
5516 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: I.getType(), Src: ExtVTy,
5517 CCH: TTI::CastContextHint::None, CostKind)) {
5518 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
5519 << *I.getType() << " is too high.\n");
5520 return false;
5521 }
5522
5523 APInt NewSplatVal = SplatVal1->zext(width: Width * 2);
5524 NewSplatVal <<= Width;
5525 NewSplatVal |= SplatVal0->zext(width: Width * 2);
5526 auto *NewSplat = ConstantVector::getSplat(
5527 EC: ExtVTy->getElementCount(), Elt: ConstantInt::get(Context&: F.getContext(), V: NewSplatVal));
5528
5529 IRBuilder<> Builder(&I);
5530 replaceValue(Old&: I, New&: *Builder.CreateBitCast(V: NewSplat, DestTy: I.getType()));
5531 return true;
5532}
5533
// Attempt to shrink loads that are only used by shufflevector instructions:
// if no shuffle reads past element K of the loaded vector, load only the
// first K+1 elements instead and rewrite the shuffles against the narrower
// load (cost-model permitting).
bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
  // Only simple (non-volatile, non-atomic) loads of fixed-width vectors.
  auto *OldLoad = dyn_cast<LoadInst>(Val: &I);
  if (!OldLoad || !OldLoad->isSimple())
    return false;

  auto *OldLoadTy = dyn_cast<FixedVectorType>(Val: OldLoad->getType());
  if (!OldLoadTy)
    return false;

  unsigned const OldNumElements = OldLoadTy->getNumElements();

  // Search all uses of load. If all uses are shufflevector instructions, and
  // the second operands are all poison values, find the minimum and maximum
  // indices of the vector elements referenced by all shuffle masks.
  // Otherwise return `std::nullopt`.
  using IndexRange = std::pair<int, int>;
  auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
    // Start with an inverted (empty) range; it stays inverted if no shuffle
    // references an in-bounds element.
    IndexRange OutputRange = IndexRange(OldNumElements, -1);
    for (llvm::Use &Use : I.uses()) {
      // Ensure all uses match the required pattern.
      User *Shuffle = Use.getUser();
      ArrayRef<int> Mask;

      if (!match(V: Shuffle,
                 P: m_Shuffle(v1: m_Specific(V: OldLoad), v2: m_Undef(), mask: m_Mask(Mask))))
        return std::nullopt;

      // Ignore shufflevector instructions that have no uses.
      if (Shuffle->use_empty())
        continue;

      // Find the min and max indices used by the shufflevector instruction.
      // Out-of-bounds / poison mask elements are ignored here.
      for (int Index : Mask) {
        if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
          OutputRange.first = std::min(a: Index, b: OutputRange.first);
          OutputRange.second = std::max(a: Index, b: OutputRange.second);
        }
      }
    }

    // Still inverted => no live shuffle referenced an in-bounds element.
    if (OutputRange.second < OutputRange.first)
      return std::nullopt;

    return OutputRange;
  };

  // Get the range of vector elements used by shufflevector instructions.
  if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
    // The narrow load must cover elements [0, max index].
    unsigned const NewNumElements = Indices->second + 1u;

    // If the range of vector elements is smaller than the full load, attempt
    // to create a smaller load.
    if (NewNumElements < OldNumElements) {
      IRBuilder Builder(&I);
      Builder.SetCurrentDebugLocation(I.getDebugLoc());

      // Calculate costs of old and new ops.
      Type *ElemTy = OldLoadTy->getElementType();
      FixedVectorType *NewLoadTy = FixedVectorType::get(ElementType: ElemTy, NumElts: NewNumElements);
      Value *PtrOp = OldLoad->getPointerOperand();

      InstructionCost OldCost = TTI.getMemoryOpCost(
          Opcode: Instruction::Load, Src: OldLoad->getType(), Alignment: OldLoad->getAlign(),
          AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);
      InstructionCost NewCost =
          TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: NewLoadTy, Alignment: OldLoad->getAlign(),
                              AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);

      using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
      SmallVector<UseEntry, 4u> NewUses;
      // Mask indices must stay in bounds for a shuffle of two NewLoadTy
      // vectors (indices in [0, 2 * NewNumElements)).
      unsigned const MaxIndex = NewNumElements * 2u;

      for (llvm::Use &Use : I.uses()) {
        // GetIndexRangeInShuffles already proved every user is a shuffle.
        auto *Shuffle = cast<ShuffleVectorInst>(Val: Use.getUser());

        // Ignore shufflevector instructions that have no uses.
        if (Shuffle->use_empty())
          continue;

        ArrayRef<int> OldMask = Shuffle->getShuffleMask();

        // Create entry for new use.
        NewUses.push_back(Elt: {Shuffle, OldMask});

        // Validate mask indices.
        for (int Index : OldMask) {
          if (Index >= static_cast<int>(MaxIndex))
            return false;
        }

        // Update costs.
        OldCost +=
            TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
                               SrcTy: OldLoadTy, Mask: OldMask, CostKind);
        NewCost +=
            TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
                               SrcTy: NewLoadTy, Mask: OldMask, CostKind);
      }

      LLVM_DEBUG(
          dbgs() << "Found a load used only by shufflevector instructions: "
                 << I << "\n OldCost: " << OldCost
                 << " vs NewCost: " << NewCost << "\n");

      // Bail out if not profitable (or the new cost is invalid).
      if (OldCost < NewCost || !NewCost.isValid())
        return false;

      // Create new load of smaller vector. Alignment of the original load is
      // reused; metadata is copied over.
      auto *NewLoad = cast<LoadInst>(
          Val: Builder.CreateAlignedLoad(Ty: NewLoadTy, Ptr: PtrOp, Align: OldLoad->getAlign()));
      NewLoad->copyMetadata(SrcInst: I);

      // Replace all uses.
      for (UseEntry &Use : NewUses) {
        ShuffleVectorInst *Shuffle = Use.first;
        std::vector<int> &NewMask = Use.second;

        Builder.SetInsertPoint(Shuffle);
        Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
        Value *NewShuffle = Builder.CreateShuffleVector(
            V1: NewLoad, V2: PoisonValue::get(T: NewLoadTy), Mask: NewMask);

        // Erase=false: the old shuffles are left for later DCE.
        replaceValue(Old&: *Shuffle, New&: *NewShuffle, Erase: false);
      }

      return true;
    }
  }
  return false;
}
5665
// Attempt to narrow a phi of shufflevector instructions where the two incoming
// values have the same operands but different masks. If the two shuffle masks
// are offsets of one another we can use one branch to rotate the incoming
// vector and perform one larger shuffle after the phi.
bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
  // Only handle two-way phis.
  auto *Phi = dyn_cast<PHINode>(Val: &I);
  if (!Phi || Phi->getNumIncomingValues() != 2u)
    return false;

  Value *Op = nullptr;
  ArrayRef<int> Mask0;
  ArrayRef<int> Mask1;

  // Both incoming values must be single-use, single-source shuffles of the
  // same operand (second operand poison), differing only in their masks.
  if (!match(V: Phi->getOperand(i_nocapture: 0u),
             P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: Op), v2: m_Poison(), mask: m_Mask(Mask0)))) ||
      !match(V: Phi->getOperand(i_nocapture: 1u),
             P: m_OneUse(SubPattern: m_Shuffle(v1: m_Specific(V: Op), v2: m_Poison(), mask: m_Mask(Mask1)))))
    return false;

  auto *Shuf = cast<ShuffleVectorInst>(Val: Phi->getOperand(i_nocapture: 0u));

  // Ensure result vectors are wider than the argument vector.
  auto *InputVT = cast<FixedVectorType>(Val: Op->getType());
  auto *ResultVT = cast<FixedVectorType>(Val: Shuf->getType());
  auto const InputNumElements = InputVT->getNumElements();

  if (InputNumElements >= ResultVT->getNumElements())
    return false;

  // Take the difference of the two shuffle masks at each index. Ignore poison
  // values at the same index in both masks.
  SmallVector<int, 16> NewMask;
  NewMask.reserve(N: Mask0.size());

  for (auto [M0, M1] : zip(t&: Mask0, u&: Mask1)) {
    if (M0 >= 0 && M1 >= 0)
      NewMask.push_back(Elt: M0 - M1);
    else if (M0 == -1 && M1 == -1)
      continue;
    else
      // Poison in only one of the two masks: not a uniform offset.
      return false;
  }

  // Ensure all elements of the new mask are equal. If the difference between
  // the incoming mask elements is the same, the two must be constant offsets
  // of one another.
  if (NewMask.empty() || !all_equal(Range&: NewMask))
    return false;

  // Create new mask using difference of the two incoming masks.
  // Normalize the (possibly negative) offset into [0, InputNumElements) and
  // build a rotate-by-offset mask for the input vector. NewMask is reused.
  int MaskOffset = NewMask[0u];
  unsigned Index = (InputNumElements + MaskOffset) % InputNumElements;
  NewMask.clear();

  for (unsigned I = 0u; I < InputNumElements; ++I) {
    NewMask.push_back(Elt: Index);
    Index = (Index + 1u) % InputNumElements;
  }

  // Calculate costs for worst cases and compare.
  // Old: worst of the two original widening shuffles (one per incoming edge).
  // New: the narrow rotate plus the single widening shuffle after the phi.
  auto const Kind = TTI::SK_PermuteSingleSrc;
  auto OldCost =
      std::max(a: TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask0, CostKind),
               b: TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask1, CostKind));
  auto NewCost = TTI.getShuffleCost(Kind, DstTy: InputVT, SrcTy: InputVT, Mask: NewMask, CostKind) +
                 TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask1, CostKind);

  LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  // Create new shuffles and narrowed phi.
  // The rotate is emitted at the position of the first incoming shuffle.
  auto Builder = IRBuilder(Shuf);
  Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
  auto *PoisonVal = PoisonValue::get(T: InputVT);
  auto *NewShuf0 = Builder.CreateShuffleVector(V1: Op, V2: PoisonVal, Mask: NewMask);
  Worklist.push(I: cast<Instruction>(Val: NewShuf0));

  // The narrow phi takes the rotated vector on edge 0 and the raw input on
  // edge 1; the remaining offset is folded into the final Mask1 shuffle.
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
  auto *NewPhi = Builder.CreatePHI(Ty: NewShuf0->getType(), NumReservedValues: 2u);
  NewPhi->addIncoming(V: NewShuf0, BB: Phi->getIncomingBlock(i: 0u));
  NewPhi->addIncoming(V: Op, BB: Phi->getIncomingBlock(i: 1u));

  // Emit the widening shuffle right after the phi's definition point (phis
  // must stay grouped at the block start).
  Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
  PoisonVal = PoisonValue::get(T: NewPhi->getType());
  auto *NewShuf1 = Builder.CreateShuffleVector(V1: NewPhi, V2: PoisonVal, Mask: Mask1);

  replaceValue(Old&: *Phi, New&: *NewShuf1);
  return true;
}
5760
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function. Returns true if any instruction
/// in the function was changed.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(ClassID: TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");

  // Try the fold routines on one instruction; the folds are attempted in a
  // fixed order and the first one that fires wins for this visit.
  auto FoldInst = [this](Instruction &I) {
    Builder.SetInsertPoint(&I);
    bool IsVectorType = isa<VectorType>(Val: I.getType());
    bool IsFixedVectorType = isa<FixedVectorType>(Val: I.getType());
    auto Opcode = I.getOpcode();

    LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');

    // These folds should be beneficial regardless of when this pass is run
    // in the optimization pipeline.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        if (vectorizeLoadInsert(I))
          return true;
        break;
      case Instruction::ShuffleVector:
        if (widenSubvectorLoad(I))
          return true;
        break;
      default:
        break;
      }
    }

    // This transform works with scalable and fixed vectors
    // TODO: Identify and allow other scalable transforms
    if (IsVectorType) {
      if (scalarizeOpOrCmp(I))
        return true;
      if (scalarizeLoad(I))
        return true;
      if (scalarizeExtExtract(I))
        return true;
      if (scalarizeVPIntrinsic(I))
        return true;
      if (foldInterleaveIntrinsics(I))
        return true;
    }

    if (Opcode == Instruction::Store)
      if (foldSingleElementStore(I))
        return true;

    // If this is an early pipeline invocation of this pass, we are done.
    if (TryEarlyFoldsOnly)
      return false;

    // Otherwise, try folds that improve codegen but may interfere with
    // early IR canonicalizations.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        if (foldInsExtFNeg(I))
          return true;
        if (foldInsExtBinop(I))
          return true;
        if (foldInsExtVectorToShuffle(I))
          return true;
        break;
      case Instruction::ShuffleVector:
        if (foldPermuteOfBinops(I))
          return true;
        if (foldShuffleOfBinops(I))
          return true;
        if (foldShuffleOfSelects(I))
          return true;
        if (foldShuffleOfCastops(I))
          return true;
        if (foldShuffleOfShuffles(I))
          return true;
        if (foldPermuteOfIntrinsic(I))
          return true;
        if (foldShufflesOfLengthChangingShuffles(I))
          return true;
        if (foldShuffleOfIntrinsics(I))
          return true;
        if (foldSelectShuffle(I))
          return true;
        if (foldShuffleToIdentity(I))
          return true;
        break;
      case Instruction::Load:
        if (shrinkLoadForShuffles(I))
          return true;
        break;
      case Instruction::BitCast:
        if (foldBitcastShuffle(I))
          return true;
        if (foldSelectsFromBitcast(I))
          return true;
        break;
      case Instruction::And:
      case Instruction::Or:
      case Instruction::Xor:
        if (foldBitOpOfCastops(I))
          return true;
        if (foldBitOpOfCastConstant(I))
          return true;
        break;
      case Instruction::PHI:
        if (shrinkPhiOfShuffles(I))
          return true;
        break;
      default:
        if (shrinkType(I))
          return true;
        break;
      }
    } else {
      // Scalar-result (or scalable-vector) instructions: mostly folds that
      // shrink reductions and extract/compare patterns.
      switch (Opcode) {
      case Instruction::Call:
        if (foldShuffleFromReductions(I))
          return true;
        if (foldCastFromReductions(I))
          return true;
        break;
      case Instruction::ExtractElement:
        if (foldShuffleChainsToReduce(I))
          return true;
        break;
      case Instruction::ICmp:
        if (foldSignBitReductionCmp(I))
          return true;
        if (foldICmpEqZeroVectorReduce(I))
          return true;
        if (foldEquivalentReductionCmp(I))
          return true;
        // ICmp also gets the shared compare folds below.
        [[fallthrough]];
      case Instruction::FCmp:
        if (foldExtractExtract(I))
          return true;
        break;
      case Instruction::Or:
        if (foldConcatOfBoolMasks(I))
          return true;
        // Or is also a binary op, so it gets the generic binop folds below.
        [[fallthrough]];
      default:
        if (Instruction::isBinaryOp(Opcode)) {
          if (foldExtractExtract(I))
            return true;
          if (foldExtractedCmps(I))
            return true;
          if (foldBinopOfReductions(I))
            return true;
        }
        break;
      }
    }
    return false;
  };

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(A: &BB))
      continue;
    // Use early increment range so that we can erase instructions in loop.
    // make_early_inc_range is not applicable here, as the next iterator may
    // be invalidated by RecursivelyDeleteTriviallyDeadInstructions.
    // We manually maintain the next instruction and update it when it is about
    // to be deleted.
    Instruction *I = &BB.front();
    while (I) {
      NextInst = I->getNextNode();
      if (!I->isDebugOrPseudoInst())
        MadeChange |= FoldInst(*I);
      I = NextInst;
    }
  }

  // Finished the function walk; reset the saved next-instruction pointer.
  NextInst = nullptr;

  // Process instructions pushed onto the worklist by the folds above: erase
  // trivially dead ones and re-attempt folds on the rest.
  while (!Worklist.isEmpty()) {
    Instruction *I = Worklist.removeOne();
    if (!I)
      continue;

    if (isInstructionTriviallyDead(I)) {
      eraseInstruction(I&: *I);
      continue;
    }

    MadeChange |= FoldInst(*I);
  }

  return MadeChange;
}
5965
5966PreservedAnalyses VectorCombinePass::run(Function &F,
5967 FunctionAnalysisManager &FAM) {
5968 auto &AC = FAM.getResult<AssumptionAnalysis>(IR&: F);
5969 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
5970 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
5971 AAResults &AA = FAM.getResult<AAManager>(IR&: F);
5972 const DataLayout *DL = &F.getDataLayout();
5973 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
5974 TryEarlyFoldsOnly);
5975 if (!Combiner.run())
5976 return PreservedAnalyses::all();
5977 PreservedAnalyses PA;
5978 PA.preserveSet<CFGAnalyses>();
5979 return PA;
5980}
5981