//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/Vectorize/VectorCombine.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/Statistic.h"
21#include "llvm/Analysis/AssumptionCache.h"
22#include "llvm/Analysis/BasicAliasAnalysis.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/InstSimplifyFolder.h"
25#include "llvm/Analysis/Loads.h"
26#include "llvm/Analysis/TargetFolder.h"
27#include "llvm/Analysis/TargetTransformInfo.h"
28#include "llvm/Analysis/ValueTracking.h"
29#include "llvm/Analysis/VectorUtils.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Instructions.h"
34#include "llvm/IR/PatternMatch.h"
35#include "llvm/Support/CommandLine.h"
36#include "llvm/Support/MathExtras.h"
37#include "llvm/Transforms/Utils/Local.h"
38#include "llvm/Transforms/Utils/LoopUtils.h"
39#include <numeric>
40#include <optional>
41#include <queue>
42#include <set>
43
44#define DEBUG_TYPE "vector-combine"
45#include "llvm/Transforms/Utils/InstructionWorklist.h"
46
47using namespace llvm;
48using namespace llvm::PatternMatch;
49
50STATISTIC(NumVecLoad, "Number of vector loads formed");
51STATISTIC(NumVecCmp, "Number of vector compares formed");
52STATISTIC(NumVecBO, "Number of vector binops formed");
53STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
54STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
55STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
56STATISTIC(NumScalarCmp, "Number of scalar compares formed");
57STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
58
59static cl::opt<bool> DisableVectorCombine(
60 "disable-vector-combine", cl::init(Val: false), cl::Hidden,
61 cl::desc("Disable all vector combine transforms"));
62
63static cl::opt<bool> DisableBinopExtractShuffle(
64 "disable-binop-extract-shuffle", cl::init(Val: false), cl::Hidden,
65 cl::desc("Disable binop extract to shuffle transforms"));
66
67static cl::opt<unsigned> MaxInstrsToScan(
68 "vector-combine-max-scan-instrs", cl::init(Val: 30), cl::Hidden,
69 cl::desc("Max number of instructions to scan for vector combining."));
70
71static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
72
73namespace {
74class VectorCombine {
75public:
76 VectorCombine(Function &F, const TargetTransformInfo &TTI,
77 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
78 const DataLayout *DL, TTI::TargetCostKind CostKind,
79 bool TryEarlyFoldsOnly)
80 : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
81 DT(DT), AA(AA), AC(AC), DL(DL), CostKind(CostKind), SQ(*DL),
82 TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
83
84 bool run();
85
86private:
87 Function &F;
88 IRBuilder<InstSimplifyFolder> Builder;
89 const TargetTransformInfo &TTI;
90 const DominatorTree &DT;
91 AAResults &AA;
92 AssumptionCache &AC;
93 const DataLayout *DL;
94 TTI::TargetCostKind CostKind;
95 const SimplifyQuery SQ;
96
97 /// If true, only perform beneficial early IR transforms. Do not introduce new
98 /// vector operations.
99 bool TryEarlyFoldsOnly;
100
101 InstructionWorklist Worklist;
102
103 /// Next instruction to iterate. It will be updated when it is erased by
104 /// RecursivelyDeleteTriviallyDeadInstructions.
105 Instruction *NextInst;
106
107 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
108 // parameter. That should be updated to specific sub-classes because the
109 // run loop was changed to dispatch on opcode.
110 bool vectorizeLoadInsert(Instruction &I);
111 bool widenSubvectorLoad(Instruction &I);
112 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
113 ExtractElementInst *Ext1,
114 unsigned PreferredExtractIndex) const;
115 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
116 const Instruction &I,
117 ExtractElementInst *&ConvertToShuffle,
118 unsigned PreferredExtractIndex);
119 Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
120 Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
121 bool foldExtractExtract(Instruction &I);
122 bool foldInsExtFNeg(Instruction &I);
123 bool foldInsExtBinop(Instruction &I);
124 bool foldInsExtVectorToShuffle(Instruction &I);
125 bool foldBitOpOfCastops(Instruction &I);
126 bool foldBitOpOfCastConstant(Instruction &I);
127 bool foldBitcastShuffle(Instruction &I);
128 bool scalarizeOpOrCmp(Instruction &I);
129 bool scalarizeVPIntrinsic(Instruction &I);
130 bool foldExtractedCmps(Instruction &I);
131 bool foldSelectsFromBitcast(Instruction &I);
132 bool foldBinopOfReductions(Instruction &I);
133 bool foldSingleElementStore(Instruction &I);
134 bool scalarizeLoad(Instruction &I);
135 bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
136 bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
137 bool scalarizeExtExtract(Instruction &I);
138 bool foldConcatOfBoolMasks(Instruction &I);
139 bool foldPermuteOfBinops(Instruction &I);
140 bool foldShuffleOfBinops(Instruction &I);
141 bool foldShuffleOfSelects(Instruction &I);
142 bool foldShuffleOfCastops(Instruction &I);
143 bool foldShuffleOfShuffles(Instruction &I);
144 bool foldPermuteOfIntrinsic(Instruction &I);
145 bool foldShufflesOfLengthChangingShuffles(Instruction &I);
146 bool foldShuffleOfIntrinsics(Instruction &I);
147 bool foldShuffleToIdentity(Instruction &I);
148 bool foldShuffleFromReductions(Instruction &I);
149 bool foldShuffleChainsToReduce(Instruction &I);
150 bool foldCastFromReductions(Instruction &I);
151 bool foldSignBitReductionCmp(Instruction &I);
152 bool foldICmpEqZeroVectorReduce(Instruction &I);
153 bool foldEquivalentReductionCmp(Instruction &I);
154 bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
155 bool foldInterleaveIntrinsics(Instruction &I);
156 bool shrinkType(Instruction &I);
157 bool shrinkLoadForShuffles(Instruction &I);
158 bool shrinkPhiOfShuffles(Instruction &I);
159
160 void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
161 LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
162 LLVM_DEBUG(dbgs() << " With: " << New << '\n');
163 Old.replaceAllUsesWith(V: &New);
164 if (auto *NewI = dyn_cast<Instruction>(Val: &New)) {
165 New.takeName(V: &Old);
166 Worklist.pushUsersToWorkList(I&: *NewI);
167 Worklist.pushValue(V: NewI);
168 }
169 if (Erase && isInstructionTriviallyDead(I: &Old)) {
170 eraseInstruction(I&: Old);
171 } else {
172 Worklist.push(I: &Old);
173 }
174 }
175
176 void eraseInstruction(Instruction &I) {
177 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
178 SmallVector<Value *> Ops(I.operands());
179 Worklist.remove(I: &I);
180 I.eraseFromParent();
181
182 // Push remaining users of the operands and then the operand itself - allows
183 // further folds that were hindered by OneUse limits.
184 SmallPtrSet<Value *, 4> Visited;
185 for (Value *Op : Ops) {
186 if (!Visited.contains(Ptr: Op)) {
187 if (auto *OpI = dyn_cast<Instruction>(Val: Op)) {
188 if (RecursivelyDeleteTriviallyDeadInstructions(
189 V: OpI, TLI: nullptr, MSSAU: nullptr, AboutToDeleteCallback: [&](Value *V) {
190 if (auto *I = dyn_cast<Instruction>(Val: V)) {
191 LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
192 Worklist.remove(I);
193 if (I == NextInst)
194 NextInst = NextInst->getNextNode();
195 Visited.insert(Ptr: I);
196 }
197 }))
198 continue;
199 Worklist.pushUsersToWorkList(I&: *OpI);
200 Worklist.pushValue(V: OpI);
201 }
202 }
203 }
204 }
205};
206} // namespace
207
208/// Return the source operand of a potentially bitcasted value. If there is no
209/// bitcast, return the input value itself.
210static Value *peekThroughBitcasts(Value *V) {
211 while (auto *BitCast = dyn_cast<BitCastInst>(Val: V))
212 V = BitCast->getOperand(i_nocapture: 0);
213 return V;
214}
215
216static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
217 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
218 // The widened load may load data from dirty regions or create data races
219 // non-existent in the source.
220 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
221 Load->getFunction()->hasFnAttribute(Kind: Attribute::SanitizeMemTag) ||
222 mustSuppressSpeculation(LI: *Load))
223 return false;
224
225 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
226 // sure we have all of our type-based constraints in place for this target.
227 Type *ScalarTy = Load->getType()->getScalarType();
228 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
229 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
230 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
231 ScalarSize % 8 != 0)
232 return false;
233
234 return true;
235}
236
237bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
238 // Match insert into fixed vector of scalar value.
239 // TODO: Handle non-zero insert index.
240 Value *Scalar;
241 if (!match(V: &I,
242 P: m_InsertElt(Val: m_Poison(), Elt: m_OneUse(SubPattern: m_Value(V&: Scalar)), Idx: m_ZeroInt())))
243 return false;
244
245 // Optionally match an extract from another vector.
246 Value *X;
247 bool HasExtract = match(V: Scalar, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ZeroInt()));
248 if (!HasExtract)
249 X = Scalar;
250
251 auto *Load = dyn_cast<LoadInst>(Val: X);
252 if (!canWidenLoad(Load, TTI))
253 return false;
254
255 Type *ScalarTy = Scalar->getType();
256 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
257 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
258
259 // Check safety of replacing the scalar load with a larger vector load.
260 // We use minimal alignment (maximum flexibility) because we only care about
261 // the dereferenceable region. When calculating cost and creating a new op,
262 // we may use a larger value based on alignment attributes.
263 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
264 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
265
266 unsigned MinVecNumElts = MinVectorSize / ScalarSize;
267 auto *MinVecTy = VectorType::get(ElementType: ScalarTy, NumElements: MinVecNumElts, Scalable: false);
268 unsigned OffsetEltIndex = 0;
269 Align Alignment = Load->getAlign();
270 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty: MinVecTy, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: &AC,
271 DT: &DT)) {
272 // It is not safe to load directly from the pointer, but we can still peek
273 // through gep offsets and check if it safe to load from a base address with
274 // updated alignment. If it is, we can shuffle the element(s) into place
275 // after loading.
276 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(Ty: SrcPtr->getType());
277 APInt Offset(OffsetBitWidth, 0);
278 SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL: *DL, Offset);
279
280 // We want to shuffle the result down from a high element of a vector, so
281 // the offset must be positive.
282 if (Offset.isNegative())
283 return false;
284
285 // The offset must be a multiple of the scalar element to shuffle cleanly
286 // in the element's size.
287 uint64_t ScalarSizeInBytes = ScalarSize / 8;
288 if (Offset.urem(RHS: ScalarSizeInBytes) != 0)
289 return false;
290
291 // If we load MinVecNumElts, will our target element still be loaded?
292 OffsetEltIndex = Offset.udiv(RHS: ScalarSizeInBytes).getZExtValue();
293 if (OffsetEltIndex >= MinVecNumElts)
294 return false;
295
296 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty: MinVecTy, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: &AC,
297 DT: &DT))
298 return false;
299
300 // Update alignment with offset value. Note that the offset could be negated
301 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
302 // negation does not change the result of the alignment calculation.
303 Alignment = commonAlignment(A: Alignment, Offset: Offset.getZExtValue());
304 }
305
306 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
307 // Use the greater of the alignment on the load or its source pointer.
308 Alignment = std::max(a: SrcPtr->getPointerAlignment(DL: *DL), b: Alignment);
309 Type *LoadTy = Load->getType();
310 unsigned AS = Load->getPointerAddressSpace();
311 InstructionCost OldCost =
312 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: LoadTy, Alignment, AddressSpace: AS, CostKind);
313 APInt DemandedElts = APInt::getOneBitSet(numBits: MinVecNumElts, BitNo: 0);
314 OldCost +=
315 TTI.getScalarizationOverhead(Ty: MinVecTy, DemandedElts,
316 /* Insert */ true, Extract: HasExtract, CostKind);
317
318 // New pattern: load VecPtr
319 InstructionCost NewCost =
320 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: MinVecTy, Alignment, AddressSpace: AS, CostKind);
321 // Optionally, we are shuffling the loaded vector element(s) into place.
322 // For the mask set everything but element 0 to undef to prevent poison from
323 // propagating from the extra loaded memory. This will also optionally
324 // shrink/grow the vector from the loaded size to the output size.
325 // We assume this operation has no cost in codegen if there was no offset.
326 // Note that we could use freeze to avoid poison problems, but then we might
327 // still need a shuffle to change the vector size.
328 auto *Ty = cast<FixedVectorType>(Val: I.getType());
329 unsigned OutputNumElts = Ty->getNumElements();
330 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
331 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
332 Mask[0] = OffsetEltIndex;
333 if (OffsetEltIndex)
334 NewCost += TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Ty, SrcTy: MinVecTy, Mask,
335 CostKind);
336
337 // We can aggressively convert to the vector form because the backend can
338 // invert this transform if it does not result in a performance win.
339 if (OldCost < NewCost || !NewCost.isValid())
340 return false;
341
342 // It is safe and potentially profitable to load a vector directly:
343 // inselt undef, load Scalar, 0 --> load VecPtr
344 IRBuilder<> Builder(Load);
345 Value *CastedPtr =
346 Builder.CreatePointerBitCastOrAddrSpaceCast(V: SrcPtr, DestTy: Builder.getPtrTy(AddrSpace: AS));
347 Value *VecLd = Builder.CreateAlignedLoad(Ty: MinVecTy, Ptr: CastedPtr, Align: Alignment);
348 VecLd = Builder.CreateShuffleVector(V: VecLd, Mask);
349
350 replaceValue(Old&: I, New&: *VecLd);
351 ++NumVecLoad;
352 return true;
353}
354
355/// If we are loading a vector and then inserting it into a larger vector with
356/// undefined elements, try to load the larger vector and eliminate the insert.
357/// This removes a shuffle in IR and may allow combining of other loaded values.
358bool VectorCombine::widenSubvectorLoad(Instruction &I) {
359 // Match subvector insert of fixed vector.
360 auto *Shuf = cast<ShuffleVectorInst>(Val: &I);
361 if (!Shuf->isIdentityWithPadding())
362 return false;
363
364 // Allow a non-canonical shuffle mask that is choosing elements from op1.
365 unsigned NumOpElts =
366 cast<FixedVectorType>(Val: Shuf->getOperand(i_nocapture: 0)->getType())->getNumElements();
367 unsigned OpIndex = any_of(Range: Shuf->getShuffleMask(), P: [&NumOpElts](int M) {
368 return M >= (int)(NumOpElts);
369 });
370
371 auto *Load = dyn_cast<LoadInst>(Val: Shuf->getOperand(i_nocapture: OpIndex));
372 if (!canWidenLoad(Load, TTI))
373 return false;
374
375 // We use minimal alignment (maximum flexibility) because we only care about
376 // the dereferenceable region. When calculating cost and creating a new op,
377 // we may use a larger value based on alignment attributes.
378 auto *Ty = cast<FixedVectorType>(Val: I.getType());
379 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
380 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
381 Align Alignment = Load->getAlign();
382 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: &AC, DT: &DT))
383 return false;
384
385 Alignment = std::max(a: SrcPtr->getPointerAlignment(DL: *DL), b: Alignment);
386 Type *LoadTy = Load->getType();
387 unsigned AS = Load->getPointerAddressSpace();
388
389 // Original pattern: insert_subvector (load PtrOp)
390 // This conservatively assumes that the cost of a subvector insert into an
391 // undef value is 0. We could add that cost if the cost model accurately
392 // reflects the real cost of that operation.
393 InstructionCost OldCost =
394 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: LoadTy, Alignment, AddressSpace: AS, CostKind);
395
396 // New pattern: load PtrOp
397 InstructionCost NewCost =
398 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: Ty, Alignment, AddressSpace: AS, CostKind);
399
400 // We can aggressively convert to the vector form because the backend can
401 // invert this transform if it does not result in a performance win.
402 if (OldCost < NewCost || !NewCost.isValid())
403 return false;
404
405 IRBuilder<> Builder(Load);
406 Value *CastedPtr =
407 Builder.CreatePointerBitCastOrAddrSpaceCast(V: SrcPtr, DestTy: Builder.getPtrTy(AddrSpace: AS));
408 Value *VecLd = Builder.CreateAlignedLoad(Ty, Ptr: CastedPtr, Align: Alignment);
409 replaceValue(Old&: I, New&: *VecLd);
410 ++NumVecLoad;
411 return true;
412}
413
414/// Determine which, if any, of the inputs should be replaced by a shuffle
415/// followed by extract from a different index.
416ExtractElementInst *VectorCombine::getShuffleExtract(
417 ExtractElementInst *Ext0, ExtractElementInst *Ext1,
418 unsigned PreferredExtractIndex = InvalidIndex) const {
419 auto *Index0C = dyn_cast<ConstantInt>(Val: Ext0->getIndexOperand());
420 auto *Index1C = dyn_cast<ConstantInt>(Val: Ext1->getIndexOperand());
421 assert(Index0C && Index1C && "Expected constant extract indexes");
422
423 unsigned Index0 = Index0C->getZExtValue();
424 unsigned Index1 = Index1C->getZExtValue();
425
426 // If the extract indexes are identical, no shuffle is needed.
427 if (Index0 == Index1)
428 return nullptr;
429
430 Type *VecTy = Ext0->getVectorOperand()->getType();
431 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
432 InstructionCost Cost0 =
433 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
434 InstructionCost Cost1 =
435 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
436
437 // If both costs are invalid no shuffle is needed
438 if (!Cost0.isValid() && !Cost1.isValid())
439 return nullptr;
440
441 // We are extracting from 2 different indexes, so one operand must be shuffled
442 // before performing a vector operation and/or extract. The more expensive
443 // extract will be replaced by a shuffle.
444 if (Cost0 > Cost1)
445 return Ext0;
446 if (Cost1 > Cost0)
447 return Ext1;
448
449 // If the costs are equal and there is a preferred extract index, shuffle the
450 // opposite operand.
451 if (PreferredExtractIndex == Index0)
452 return Ext1;
453 if (PreferredExtractIndex == Index1)
454 return Ext0;
455
456 // Otherwise, replace the extract with the higher index.
457 return Index0 > Index1 ? Ext0 : Ext1;
458}
459
460/// Compare the relative costs of 2 extracts followed by scalar operation vs.
461/// vector operation(s) followed by extract. Return true if the existing
462/// instructions are cheaper than a vector alternative. Otherwise, return false
463/// and if one of the extracts should be transformed to a shufflevector, set
464/// \p ConvertToShuffle to that extract instruction.
465bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
466 ExtractElementInst *Ext1,
467 const Instruction &I,
468 ExtractElementInst *&ConvertToShuffle,
469 unsigned PreferredExtractIndex) {
470 auto *Ext0IndexC = dyn_cast<ConstantInt>(Val: Ext0->getIndexOperand());
471 auto *Ext1IndexC = dyn_cast<ConstantInt>(Val: Ext1->getIndexOperand());
472 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
473
474 unsigned Opcode = I.getOpcode();
475 Value *Ext0Src = Ext0->getVectorOperand();
476 Value *Ext1Src = Ext1->getVectorOperand();
477 Type *ScalarTy = Ext0->getType();
478 auto *VecTy = cast<VectorType>(Val: Ext0Src->getType());
479 InstructionCost ScalarOpCost, VectorOpCost;
480
481 // Get cost estimates for scalar and vector versions of the operation.
482 bool IsBinOp = Instruction::isBinaryOp(Opcode);
483 if (IsBinOp) {
484 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
485 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
486 } else {
487 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
488 "Expected a compare");
489 CmpInst::Predicate Pred = cast<CmpInst>(Val: I).getPredicate();
490 ScalarOpCost = TTI.getCmpSelInstrCost(
491 Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
492 VectorOpCost = TTI.getCmpSelInstrCost(
493 Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
494 }
495
496 // Get cost estimates for the extract elements. These costs will factor into
497 // both sequences.
498 unsigned Ext0Index = Ext0IndexC->getZExtValue();
499 unsigned Ext1Index = Ext1IndexC->getZExtValue();
500
501 InstructionCost Extract0Cost =
502 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Ext0Index);
503 InstructionCost Extract1Cost =
504 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Ext1Index);
505
506 // A more expensive extract will always be replaced by a splat shuffle.
507 // For example, if Ext0 is more expensive:
508 // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
509 // extelt (opcode (splat V0, Ext0), V1), Ext1
510 // TODO: Evaluate whether that always results in lowest cost. Alternatively,
511 // check the cost of creating a broadcast shuffle and shuffling both
512 // operands to element 0.
513 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
514 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
515 InstructionCost CheapExtractCost = std::min(a: Extract0Cost, b: Extract1Cost);
516
517 // Extra uses of the extracts mean that we include those costs in the
518 // vector total because those instructions will not be eliminated.
519 InstructionCost OldCost, NewCost;
520 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
521 // Handle a special case. If the 2 extracts are identical, adjust the
522 // formulas to account for that. The extra use charge allows for either the
523 // CSE'd pattern or an unoptimized form with identical values:
524 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
525 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(N: 2)
526 : !Ext0->hasOneUse() || !Ext1->hasOneUse();
527 OldCost = CheapExtractCost + ScalarOpCost;
528 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
529 } else {
530 // Handle the general case. Each extract is actually a different value:
531 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
532 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
533 NewCost = VectorOpCost + CheapExtractCost +
534 !Ext0->hasOneUse() * Extract0Cost +
535 !Ext1->hasOneUse() * Extract1Cost;
536 }
537
538 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
539 if (ConvertToShuffle) {
540 if (IsBinOp && DisableBinopExtractShuffle)
541 return true;
542
543 // If we are extracting from 2 different indexes, then one operand must be
544 // shuffled before performing the vector operation. The shuffle mask is
545 // poison except for 1 lane that is being translated to the remaining
546 // extraction lane. Therefore, it is a splat shuffle. Ex:
547 // ShufMask = { poison, poison, 0, poison }
548 // TODO: The cost model has an option for a "broadcast" shuffle
549 // (splat-from-element-0), but no option for a more general splat.
550 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(Val: VecTy)) {
551 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
552 PoisonMaskElem);
553 ShuffleMask[BestInsIndex] = BestExtIndex;
554 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
555 DstTy: VecTy, SrcTy: VecTy, Mask: ShuffleMask, CostKind, Index: 0,
556 SubTp: nullptr, Args: {ConvertToShuffle});
557 } else {
558 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
559 DstTy: VecTy, SrcTy: VecTy, Mask: {}, CostKind, Index: 0, SubTp: nullptr,
560 Args: {ConvertToShuffle});
561 }
562 }
563
564 // Aggressively form a vector op if the cost is equal because the transform
565 // may enable further optimization.
566 // Codegen can reverse this transform (scalarize) if it was not profitable.
567 return OldCost < NewCost;
568}
569
570/// Create a shuffle that translates (shifts) 1 element from the input vector
571/// to a new element location.
572static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
573 unsigned NewIndex, IRBuilderBase &Builder) {
574 // The shuffle mask is poison except for 1 lane that is being translated
575 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
576 // ShufMask = { 2, poison, poison, poison }
577 auto *VecTy = cast<FixedVectorType>(Val: Vec->getType());
578 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
579 ShufMask[NewIndex] = OldIndex;
580 return Builder.CreateShuffleVector(V: Vec, Mask: ShufMask, Name: "shift");
581}
582
583/// Given an extract element instruction with constant index operand, shuffle
584/// the source vector (shift the scalar element) to a NewIndex for extraction.
585/// Return null if the input can be constant folded, so that we are not creating
586/// unnecessary instructions.
587static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
588 IRBuilderBase &Builder) {
589 // Shufflevectors can only be created for fixed-width vectors.
590 Value *X = ExtElt->getVectorOperand();
591 if (!isa<FixedVectorType>(Val: X->getType()))
592 return nullptr;
593
594 // If the extract can be constant-folded, this code is unsimplified. Defer
595 // to other passes to handle that.
596 Value *C = ExtElt->getIndexOperand();
597 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
598 if (isa<Constant>(Val: X))
599 return nullptr;
600
601 Value *Shuf = createShiftShuffle(Vec: X, OldIndex: cast<ConstantInt>(Val: C)->getZExtValue(),
602 NewIndex, Builder);
603 return Shuf;
604}
605
606/// Try to reduce extract element costs by converting scalar compares to vector
607/// compares followed by extract.
608/// cmp (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
609Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
610 Instruction &I) {
611 assert(isa<CmpInst>(&I) && "Expected a compare");
612
613 // cmp Pred (extelt V0, ExtIndex), (extelt V1, ExtIndex)
614 // --> extelt (cmp Pred V0, V1), ExtIndex
615 ++NumVecCmp;
616 CmpInst::Predicate Pred = cast<CmpInst>(Val: &I)->getPredicate();
617 Value *VecCmp = Builder.CreateCmp(Pred, LHS: V0, RHS: V1);
618 return Builder.CreateExtractElement(Vec: VecCmp, Idx: ExtIndex, Name: "foldExtExtCmp");
619}
620
621/// Try to reduce extract element costs by converting scalar binops to vector
622/// binops followed by extract.
623/// bo (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
624Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
625 Instruction &I) {
626 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
627
628 // bo (extelt V0, ExtIndex), (extelt V1, ExtIndex)
629 // --> extelt (bo V0, V1), ExtIndex
630 ++NumVecBO;
631 Value *VecBO = Builder.CreateBinOp(Opc: cast<BinaryOperator>(Val: &I)->getOpcode(), LHS: V0,
632 RHS: V1, Name: "foldExtExtBinop");
633
634 // All IR flags are safe to back-propagate because any potential poison
635 // created in unused vector elements is discarded by the extract.
636 if (auto *VecBOInst = dyn_cast<Instruction>(Val: VecBO))
637 VecBOInst->copyIRFlags(V: &I);
638
639 return Builder.CreateExtractElement(Vec: VecBO, Idx: ExtIndex, Name: "foldExtExtBinop");
640}
641
642/// Match an instruction with extracted vector operands.
643bool VectorCombine::foldExtractExtract(Instruction &I) {
644 // It is not safe to transform things like div, urem, etc. because we may
645 // create undefined behavior when executing those on unknown vector elements.
646 if (!isSafeToSpeculativelyExecute(I: &I))
647 return false;
648
649 Instruction *I0, *I1;
650 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
651 if (!match(V: &I, P: m_Cmp(Pred, L: m_Instruction(I&: I0), R: m_Instruction(I&: I1))) &&
652 !match(V: &I, P: m_BinOp(L: m_Instruction(I&: I0), R: m_Instruction(I&: I1))))
653 return false;
654
655 Value *V0, *V1;
656 uint64_t C0, C1;
657 if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: V0), Idx: m_ConstantInt(V&: C0))) ||
658 !match(V: I1, P: m_ExtractElt(Val: m_Value(V&: V1), Idx: m_ConstantInt(V&: C1))) ||
659 V0->getType() != V1->getType())
660 return false;
661
662 // If the scalar value 'I' is going to be re-inserted into a vector, then try
663 // to create an extract to that same element. The extract/insert can be
664 // reduced to a "select shuffle".
665 // TODO: If we add a larger pattern match that starts from an insert, this
666 // probably becomes unnecessary.
667 auto *Ext0 = cast<ExtractElementInst>(Val: I0);
668 auto *Ext1 = cast<ExtractElementInst>(Val: I1);
669 uint64_t InsertIndex = InvalidIndex;
670 if (I.hasOneUse())
671 match(V: I.user_back(),
672 P: m_InsertElt(Val: m_Value(), Elt: m_Value(), Idx: m_ConstantInt(V&: InsertIndex)));
673
674 ExtractElementInst *ExtractToChange;
675 if (isExtractExtractCheap(Ext0, Ext1, I, ConvertToShuffle&: ExtractToChange, PreferredExtractIndex: InsertIndex))
676 return false;
677
678 Value *ExtOp0 = Ext0->getVectorOperand();
679 Value *ExtOp1 = Ext1->getVectorOperand();
680
681 if (ExtractToChange) {
682 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
683 Value *NewExtOp =
684 translateExtract(ExtElt: ExtractToChange, NewIndex: CheapExtractIdx, Builder);
685 if (!NewExtOp)
686 return false;
687 if (ExtractToChange == Ext0)
688 ExtOp0 = NewExtOp;
689 else
690 ExtOp1 = NewExtOp;
691 }
692
693 Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
694 : Ext0->getIndexOperand();
695 Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
696 ? foldExtExtCmp(V0: ExtOp0, V1: ExtOp1, ExtIndex, I)
697 : foldExtExtBinop(V0: ExtOp0, V1: ExtOp1, ExtIndex, I);
698 Worklist.push(I: Ext0);
699 Worklist.push(I: Ext1);
700 replaceValue(Old&: I, New&: *NewExt);
701 return true;
702}
703
704/// Try to replace an extract + scalar fneg + insert with a vector fneg +
705/// shuffle.
706bool VectorCombine::foldInsExtFNeg(Instruction &I) {
707 // Match an insert (op (extract)) pattern.
708 Value *DstVec;
709 uint64_t ExtIdx, InsIdx;
710 Instruction *FNeg;
711 if (!match(V: &I, P: m_InsertElt(Val: m_Value(V&: DstVec), Elt: m_OneUse(SubPattern: m_Instruction(I&: FNeg)),
712 Idx: m_ConstantInt(V&: InsIdx))))
713 return false;
714
715 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
716 Value *SrcVec;
717 Instruction *Extract;
718 if (!match(V: FNeg, P: m_FNeg(X: m_CombineAnd(
719 L: m_Instruction(I&: Extract),
720 R: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx))))))
721 return false;
722
723 auto *DstVecTy = cast<FixedVectorType>(Val: DstVec->getType());
724 auto *DstVecScalarTy = DstVecTy->getScalarType();
725 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
726 if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
727 return false;
728
729 // Ignore if insert/extract index is out of bounds or destination vector has
730 // one element
731 unsigned NumDstElts = DstVecTy->getNumElements();
732 unsigned NumSrcElts = SrcVecTy->getNumElements();
733 if (ExtIdx > NumSrcElts || InsIdx >= NumDstElts || NumDstElts == 1)
734 return false;
735
736 // We are inserting the negated element into the same lane that we extracted
737 // from. This is equivalent to a select-shuffle that chooses all but the
738 // negated element from the destination vector.
739 SmallVector<int> Mask(NumDstElts);
740 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
741 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
742 InstructionCost OldCost =
743 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: DstVecScalarTy, CostKind) +
744 TTI.getVectorInstrCost(I, Val: DstVecTy, CostKind, Index: InsIdx);
745
746 // If the extract has one use, it will be eliminated, so count it in the
747 // original cost. If it has more than one use, ignore the cost because it will
748 // be the same before/after.
749 if (Extract->hasOneUse())
750 OldCost += TTI.getVectorInstrCost(I: *Extract, Val: SrcVecTy, CostKind, Index: ExtIdx);
751
752 InstructionCost NewCost =
753 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: SrcVecTy, CostKind) +
754 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: DstVecTy,
755 SrcTy: DstVecTy, Mask, CostKind);
756
757 bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
758 // If the lengths of the two vectors are not equal,
759 // we need to add a length-change vector. Add this cost.
760 SmallVector<int> SrcMask;
761 if (NeedLenChg) {
762 SrcMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
763 SrcMask[ExtIdx % NumDstElts] = ExtIdx;
764 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
765 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: SrcMask, CostKind);
766 }
767
768 LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
769 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
770 << "\n");
771 if (NewCost > OldCost)
772 return false;
773
774 Value *NewShuf, *LenChgShuf = nullptr;
775 // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
776 Value *VecFNeg = Builder.CreateFNegFMF(V: SrcVec, FMFSource: FNeg);
777 if (NeedLenChg) {
778 // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
779 LenChgShuf = Builder.CreateShuffleVector(V: VecFNeg, Mask: SrcMask);
780 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: LenChgShuf, Mask);
781 Worklist.pushValue(V: LenChgShuf);
782 } else {
783 // shuffle DstVec, (fneg SrcVec), Mask
784 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: VecFNeg, Mask);
785 }
786
787 Worklist.pushValue(V: VecFNeg);
788 replaceValue(Old&: I, New&: *NewShuf);
789 return true;
790}
791
792/// Try to fold insert(binop(x,y),binop(a,b),idx)
793/// --> binop(insert(x,a,idx),insert(y,b,idx))
bool VectorCombine::foldInsExtBinop(Instruction &I) {
  // Match insert(binop(x,y), binop(a,b), const-index) where both binops are
  // single-use, so the fold deletes them rather than duplicating work.
  BinaryOperator *VecBinOp, *SclBinOp;
  uint64_t Index;
  if (!match(V: &I,
             P: m_InsertElt(Val: m_OneUse(SubPattern: m_BinOp(I&: VecBinOp)),
                          Elt: m_OneUse(SubPattern: m_BinOp(I&: SclBinOp)), Idx: m_ConstantInt(V&: Index))))
    return false;

  // The vector and scalar binops must use the same opcode to merge into one.
  // TODO: Add support for addlike etc.
  Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
  if (BinOpcode != SclBinOp->getOpcode())
    return false;

  // Only fixed-width vectors are handled (shuffle/insert cost modelling).
  auto *ResultTy = dyn_cast<FixedVectorType>(Val: I.getType());
  if (!ResultTy)
    return false;

  // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
  // shuffle?

  // Old: insert + vector binop + scalar binop.
  // New: one vector binop + two inserts (scalar operands into the vector
  // binop's operands).
  InstructionCost OldCost = TTI.getInstructionCost(U: &I, CostKind) +
                            TTI.getInstructionCost(U: VecBinOp, CostKind) +
                            TTI.getInstructionCost(U: SclBinOp, CostKind);
  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Opcode: BinOpcode, Ty: ResultTy, CostKind) +
      TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
                             Index, Op0: VecBinOp->getOperand(i_nocapture: 0),
                             Op1: SclBinOp->getOperand(i_nocapture: 0)) +
      TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
                             Index, Op0: VecBinOp->getOperand(i_nocapture: 1),
                             Op1: SclBinOp->getOperand(i_nocapture: 1));

  LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  // Equal cost still folds: the transform may enable further combines.
  if (NewCost > OldCost)
    return false;

  // Build: binop(insert(x,a,idx), insert(y,b,idx)).
  Value *NewIns0 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 0),
                                               NewElt: SclBinOp->getOperand(i_nocapture: 0), Idx: Index);
  Value *NewIns1 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 1),
                                               NewElt: SclBinOp->getOperand(i_nocapture: 1), Idx: Index);
  Value *NewBO = Builder.CreateBinOp(Opc: BinOpcode, LHS: NewIns0, RHS: NewIns1);

  // Intersect flags from the old binops: the merged op may only keep flags
  // that were valid on both originals.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
    NewInst->copyIRFlags(V: VecBinOp);
    NewInst->andIRFlags(V: SclBinOp);
  }

  Worklist.pushValue(V: NewIns0);
  Worklist.pushValue(V: NewIns1);
  replaceValue(Old&: I, New&: *NewBO);
  return true;
}
849
850/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
851/// Supports: bitcast, trunc, sext, zext
bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
  // Check if this is a bitwise logic operation (and/or/xor).
  auto *BinOp = dyn_cast<BinaryOperator>(Val: &I);
  if (!BinOp || !BinOp->isBitwiseLogicOp())
    return false;

  // Get the cast instructions feeding both operands.
  auto *LHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 0));
  auto *RHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 1));
  if (!LHSCast || !RHSCast) {
    LLVM_DEBUG(dbgs() << "  One or both operands are not cast instructions\n");
    return false;
  }

  // Both casts must be the same kind to hoist the bitop past a single cast.
  Instruction::CastOps CastOpcode = LHSCast->getOpcode();
  if (CastOpcode != RHSCast->getOpcode())
    return false;

  // Only handle supported cast operations.
  switch (CastOpcode) {
  case Instruction::BitCast:
  case Instruction::Trunc:
  case Instruction::SExt:
  case Instruction::ZExt:
    break;
  default:
    return false;
  }

  Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);
  Value *RHSSrc = RHSCast->getOperand(i_nocapture: 0);

  // Source types must match so a single bitop on the sources is well-typed.
  if (LHSSrc->getType() != RHSSrc->getType())
    return false;

  auto *SrcTy = LHSSrc->getType();
  auto *DstTy = I.getType();
  // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
  // Other casts only handle vector types with integer elements.
  if (CastOpcode != Instruction::BitCast &&
      (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
    return false;

  // Only integer scalar/vector values are legal for bitwise logic operations.
  if (!SrcTy->getScalarType()->isIntegerTy() ||
      !DstTy->getScalarType()->isIntegerTy())
    return false;

  // Cost Check :
  // OldCost = bitlogic + 2*casts
  // NewCost = bitlogic + cast

  // Calculate specific costs for each cast with instruction context.
  InstructionCost LHSCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);
  InstructionCost RHSCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: RHSCast);

  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: DstTy, CostKind) +
      LHSCastCost + RHSCastCost;

  // For new cost, we can't provide an instruction (it doesn't exist yet).
  InstructionCost GenericCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);

  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: SrcTy, CostKind) +
      GenericCastCost;

  // Account for multi-use casts using specific costs: a cast with other users
  // survives the transform, so its cost stays in the new sequence too.
  if (!LHSCast->hasOneUse())
    NewCost += LHSCastCost;
  if (!RHSCast->hasOneUse())
    NewCost += RHSCastCost;

  LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
                    << " NewCost=" << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Create the operation on the source type.
  Value *NewOp = Builder.CreateBinOp(Opc: BinOp->getOpcode(), LHS: LHSSrc, RHS: RHSSrc,
                                     Name: BinOp->getName() + ".inner");
  if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
    NewBinOp->copyIRFlags(V: BinOp);

  Worklist.pushValue(V: NewOp);

  // Create the cast operation directly to ensure we get a new instruction
  // (Builder cast helpers may constant-fold or reuse an existing value).
  Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());

  // Preserve cast instruction flags: intersection of both originals' flags.
  NewCast->copyIRFlags(V: LHSCast);
  NewCast->andIRFlags(V: RHSCast);

  // Insert the new instruction.
  Value *Result = Builder.Insert(I: NewCast);

  replaceValue(Old&: I, New&: *Result);
  return true;
}
957
/// Match:
// bitop(castop(x), C) ->
// bitop(castop(x), castop(InvC)) ->
// castop(bitop(x, InvC))
// Supports: bitcast, trunc, sext, zext
bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
  Instruction *LHS;
  Constant *C;

  // Check if this is a bitwise logic operation with a constant operand
  // (commutative match: the constant may be on either side).
  if (!match(V: &I, P: m_c_BitwiseLogic(L: m_Instruction(I&: LHS), R: m_Constant(C))))
    return false;

  // Get the cast instruction feeding the non-constant operand.
  auto *LHSCast = dyn_cast<CastInst>(Val: LHS);
  if (!LHSCast)
    return false;

  Instruction::CastOps CastOpcode = LHSCast->getOpcode();

  // Only handle supported cast operations.
  switch (CastOpcode) {
  case Instruction::BitCast:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::Trunc:
    break;
  default:
    return false;
  }

  Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);

  auto *SrcTy = LHSSrc->getType();
  auto *DstTy = I.getType();
  // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
  // Other casts only handle vector types with integer elements.
  if (CastOpcode != Instruction::BitCast &&
      (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
    return false;

  // Only integer scalar/vector values are legal for bitwise logic operations.
  if (!SrcTy->getScalarType()->isIntegerTy() ||
      !DstTy->getScalarType()->isIntegerTy())
    return false;

  // Find the constant InvC, such that castop(InvC) equals to C. If no lossless
  // inverse exists, the fold is not possible. RHSFlags records which cast
  // flags (nneg/nuw/nsw) the inverse cast may legally carry.
  PreservedCastFlags RHSFlags;
  Constant *InvC = getLosslessInvCast(C, InvCastTo: SrcTy, CastOp: CastOpcode, DL: *DL, Flags: &RHSFlags);
  if (!InvC)
    return false;

  // Cost Check :
  // OldCost = bitlogic + cast
  // NewCost = bitlogic + cast

  // Calculate specific costs for each cast with instruction context.
  InstructionCost LHSCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);

  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: DstTy, CostKind) + LHSCastCost;

  // For new cost, we can't provide an instruction (it doesn't exist yet).
  InstructionCost GenericCastCost = TTI.getCastInstrCost(
      Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);

  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: SrcTy, CostKind) +
      GenericCastCost;

  // Account for multi-use casts using specific costs: a multi-use cast
  // survives the transform, so keep its cost in the new sequence.
  if (!LHSCast->hasOneUse())
    NewCost += LHSCastCost;

  LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
                    << " NewCost=" << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Create the operation on the source type: bitop(x, InvC).
  Value *NewOp = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(),
                                     LHS: LHSSrc, RHS: InvC, Name: I.getName() + ".inner");
  if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
    NewBinOp->copyIRFlags(V: &I);

  Worklist.pushValue(V: NewOp);

  // Create the cast operation directly to ensure we get a new instruction
  // (Builder cast helpers may constant-fold or reuse an existing value).
  Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());

  // Preserve cast instruction flags: seed with the flags valid for the
  // inverse-constant cast, then intersect with the original cast's flags.
  if (RHSFlags.NNeg)
    NewCast->setNonNeg();
  if (RHSFlags.NUW)
    NewCast->setHasNoUnsignedWrap();
  if (RHSFlags.NSW)
    NewCast->setHasNoSignedWrap();

  NewCast->andIRFlags(V: LHSCast);

  // Insert the new instruction.
  Value *Result = Builder.Insert(I: NewCast);

  replaceValue(Old&: I, New&: *Result);
  return true;
}
1066
1067/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
1068/// destination type followed by shuffle. This can enable further transforms by
1069/// moving bitcasts or shuffles together.
bool VectorCombine::foldBitcastShuffle(Instruction &I) {
  // Match bitcast(shuffle(V0, V1, Mask)) where the shuffle is single-use.
  Value *V0, *V1;
  ArrayRef<int> Mask;
  if (!match(V: &I, P: m_BitCast(Op: m_OneUse(
                      SubPattern: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(Mask))))))
    return false;

  // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
  // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
  // mask for scalable type is a splat or not.
  // 2) Disallow non-vector casts.
  // TODO: We could allow any shuffle.
  auto *DestTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *SrcTy = dyn_cast<FixedVectorType>(Val: V0->getType());
  if (!DestTy || !SrcTy)
    return false;

  // The source vector's total bit width must divide evenly into destination
  // elements, otherwise no equivalent pre-shuffle bitcast exists.
  unsigned DestEltSize = DestTy->getScalarSizeInBits();
  unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
  if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
    return false;

  bool IsUnary = isa<UndefValue>(Val: V1);

  // For binary shuffles, only fold bitcast(shuffle(X,Y))
  // if it won't increase the number of bitcasts.
  if (!IsUnary) {
    // At least one source must already be (a bitcast of) the destination
    // element type, so one of the two new bitcasts folds away.
    auto *BCTy0 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V0)->getType());
    auto *BCTy1 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V1)->getType());
    if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
        !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
      return false;
  }

  SmallVector<int, 16> NewMask;
  if (DestEltSize <= SrcEltSize) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    if (SrcEltSize % DestEltSize != 0)
      return false;
    unsigned ScaleFactor = SrcEltSize / DestEltSize;
    narrowShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    if (DestEltSize % SrcEltSize != 0)
      return false;
    unsigned ScaleFactor = DestEltSize / SrcEltSize;
    if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask))
      return false;
  }

  // Bitcast the shuffle src - keep its original width but using the destination
  // scalar type.
  unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
  auto *NewShuffleTy =
      FixedVectorType::get(ElementType: DestTy->getScalarType(), NumElts: NumSrcElts);
  auto *OldShuffleTy =
      FixedVectorType::get(ElementType: SrcTy->getScalarType(), NumElts: Mask.size());
  unsigned NumOps = IsUnary ? 1 : 2;

  // The new shuffle must not cost more than the old shuffle.
  TargetTransformInfo::ShuffleKind SK =
      IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
              : TargetTransformInfo::SK_PermuteTwoSrc;

  InstructionCost NewCost =
      TTI.getShuffleCost(Kind: SK, DstTy: DestTy, SrcTy: NewShuffleTy, Mask: NewMask, CostKind) +
      (NumOps * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: NewShuffleTy, Src: SrcTy,
                                     CCH: TargetTransformInfo::CastContextHint::None,
                                     CostKind));
  InstructionCost OldCost =
      TTI.getShuffleCost(Kind: SK, DstTy: OldShuffleTy, SrcTy, Mask, CostKind) +
      TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DestTy, Src: OldShuffleTy,
                           CCH: TargetTransformInfo::CastContextHint::None,
                           CostKind);

  LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n  OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  if (NewCost > OldCost || !NewCost.isValid())
    return false;

  // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
  ++NumShufOfBitcast;
  Value *CastV0 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V0), DestTy: NewShuffleTy);
  Value *CastV1 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V1), DestTy: NewShuffleTy);
  Value *Shuf = Builder.CreateShuffleVector(V1: CastV0, V2: CastV1, Mask: NewMask);
  replaceValue(Old&: I, New&: *Shuf);
  return true;
}
1161
1162/// VP Intrinsics whose vector operands are both splat values may be simplified
1163/// into the scalar version of the operation and the result splatted. This
1164/// can lead to scalarization down the line.
bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
  if (!isa<VPIntrinsic>(Val: I))
    return false;
  VPIntrinsic &VPI = cast<VPIntrinsic>(Val&: I);
  Value *Op0 = VPI.getArgOperand(i: 0);
  Value *Op1 = VPI.getArgOperand(i: 1);

  // Both vector operands must be splats for the scalarized form to be
  // equivalent on every lane.
  if (!isSplatValue(V: Op0) || !isSplatValue(V: Op1))
    return false;

  // Check getSplatValue early in this function, to avoid doing unnecessary
  // work.
  Value *ScalarOp0 = getSplatValue(V: Op0);
  Value *ScalarOp1 = getSplatValue(V: Op1);
  if (!ScalarOp0 || !ScalarOp1)
    return false;

  // For the binary VP intrinsics supported here, the result on disabled lanes
  // is a poison value. For now, only do this simplification if all lanes
  // are active.
  // TODO: Relax the condition that all lanes are active by using insertelement
  // on inactive lanes.
  auto IsAllTrueMask = [](Value *MaskVal) {
    if (Value *SplattedVal = getSplatValue(V: MaskVal))
      if (auto *ConstValue = dyn_cast<Constant>(Val: SplattedVal))
        return ConstValue->isAllOnesValue();
    return false;
  };
  if (!IsAllTrueMask(VPI.getArgOperand(i: 2)))
    return false;

  // Check to make sure we support scalarization of the intrinsic.
  Intrinsic::ID IntrID = VPI.getIntrinsicID();
  if (!VPBinOpIntrinsic::isVPBinOp(ID: IntrID))
    return false;

  // Calculate cost of splatting both operands into vectors and the vector
  // intrinsic.
  VectorType *VecTy = cast<VectorType>(Val: VPI.getType());
  SmallVector<int> Mask;
  if (auto *FVTy = dyn_cast<FixedVectorType>(Val: VecTy))
    Mask.resize(N: FVTy->getNumElements(), NV: 0);
  // Splat cost = insert element into lane 0 + broadcast shuffle.
  InstructionCost SplatCost =
      TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: 0) +
      TTI.getShuffleCost(Kind: TargetTransformInfo::SK_Broadcast, DstTy: VecTy, SrcTy: VecTy, Mask,
                         CostKind);

  // Calculate the cost of the VP Intrinsic.
  SmallVector<Type *, 4> Args;
  for (Value *V : VPI.args())
    Args.push_back(Elt: V->getType());
  IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
  InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
  InstructionCost OldCost = 2 * SplatCost + VectorOpCost;

  // Determine scalar opcode: either a plain IR opcode or, failing that, a
  // scalar intrinsic equivalent.
  std::optional<unsigned> FunctionalOpcode =
      VPI.getFunctionalOpcode();
  std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
  if (!FunctionalOpcode) {
    ScalarIntrID = VPI.getFunctionalIntrinsicID();
    if (!ScalarIntrID)
      return false;
  }

  // Calculate cost of scalarizing.
  InstructionCost ScalarOpCost = 0;
  if (ScalarIntrID) {
    IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
    ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
  } else {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode: *FunctionalOpcode,
                                              Ty: VecTy->getScalarType(), CostKind);
  }

  // The existing splats may be kept around if other instructions use them.
  InstructionCost CostToKeepSplats =
      (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
  InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;

  LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
                    << "\n");
  LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
                    << ", Cost of scalarizing:" << NewCost << "\n");

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Scalarize the intrinsic.
  ElementCount EC = cast<VectorType>(Val: Op0->getType())->getElementCount();
  Value *EVL = VPI.getArgOperand(i: 3);

  // If the VP op might introduce UB or poison, we can scalarize it provided
  // that we know the EVL > 0: If the EVL is zero, then the original VP op
  // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
  // scalarizing it.
  bool SafeToSpeculate;
  if (ScalarIntrID)
    SafeToSpeculate = Intrinsic::getFnAttributes(C&: I.getContext(), id: *ScalarIntrID)
                          .hasAttribute(Kind: Attribute::AttrKind::Speculatable);
  else
    SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
        Opcode: *FunctionalOpcode, Inst: &VPI, CtxI: nullptr, AC: &AC, DT: &DT);
  if (!SafeToSpeculate &&
      !isKnownNonZero(V: EVL, Q: SimplifyQuery(*DL, &DT, &AC, &VPI)))
    return false;

  // Emit the scalar op (intrinsic call or binop) and re-splat the result.
  Value *ScalarVal =
      ScalarIntrID
          ? Builder.CreateIntrinsic(RetTy: VecTy->getScalarType(), ID: *ScalarIntrID,
                                    Args: {ScalarOp0, ScalarOp1})
          : Builder.CreateBinOp(Opc: (Instruction::BinaryOps)(*FunctionalOpcode),
                                LHS: ScalarOp0, RHS: ScalarOp1);

  replaceValue(Old&: VPI, New&: *Builder.CreateVectorSplat(EC, V: ScalarVal));
  return true;
}
1283
1284/// Match a vector op/compare/intrinsic with at least one
1285/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
1286/// by insertelement.
bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
  // Only unary ops, binops, compares, and trivially vectorizable intrinsics
  // are candidates; exactly one of these pointers will be non-null.
  auto *UO = dyn_cast<UnaryOperator>(Val: &I);
  auto *BO = dyn_cast<BinaryOperator>(Val: &I);
  auto *CI = dyn_cast<CmpInst>(Val: &I);
  auto *II = dyn_cast<IntrinsicInst>(Val: &I);
  if (!UO && !BO && !CI && !II)
    return false;

  // TODO: Allow intrinsics with different argument types
  if (II) {
    if (!isTriviallyVectorizable(ID: II->getIntrinsicID()))
      return false;
    for (auto [Idx, Arg] : enumerate(First: II->args()))
      if (Arg->getType() != II->getType() &&
          !isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI))
        return false;
  }

  // Do not convert the vector condition of a vector select into a scalar
  // condition. That may cause problems for codegen because of differences in
  // boolean formats and register-file transfers.
  // TODO: Can we account for that in the cost model?
  if (CI)
    for (User *U : I.users())
      if (match(V: U, P: m_Select(C: m_Specific(V: &I), L: m_Value(), R: m_Value())))
        return false;

  // Match constant vectors or scalars being inserted into constant vectors:
  // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
  // VecCs collects per-operand base vectors (or pass-through scalar args);
  // ScalarOps collects the inserted scalar for each operand, or nullptr when
  // the operand is a plain constant vector.
  SmallVector<Value *> VecCs, ScalarOps;
  std::optional<uint64_t> Index;

  auto Ops = II ? II->args() : I.operands();
  for (auto [OpNum, Op] : enumerate(First&: Ops)) {
    Constant *VecC;
    Value *V;
    uint64_t InsIdx = 0;
    if (match(V: Op.get(), P: m_InsertElt(Val: m_Constant(C&: VecC), Elt: m_Value(V),
                                     Idx: m_ConstantInt(V&: InsIdx)))) {
      // Bail if any inserts are out of bounds.
      VectorType *OpTy = cast<VectorType>(Val: Op->getType());
      if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
        return false;
      // All inserts must have the same index.
      // TODO: Deal with mismatched index constants and variable indexes?
      if (!Index)
        Index = InsIdx;
      else if (InsIdx != *Index)
        return false;
      VecCs.push_back(Elt: VecC);
      ScalarOps.push_back(Elt: V);
    } else if (II && isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(),
                                                        ScalarOpdIdx: OpNum, TTI: &TTI)) {
      // Operand stays scalar in both the vector and scalarized forms.
      VecCs.push_back(Elt: Op.get());
      ScalarOps.push_back(Elt: Op.get());
    } else if (match(V: Op.get(), P: m_Constant(C&: VecC))) {
      VecCs.push_back(Elt: VecC);
      ScalarOps.push_back(Elt: nullptr);
    } else {
      return false;
    }
  }

  // Bail if all operands are constant: Index is only set by an insertelement
  // match, so no index means nothing to scalarize.
  if (!Index.has_value())
    return false;

  VectorType *VecTy = cast<VectorType>(Val: I.getType());
  Type *ScalarTy = VecTy->getScalarType();
  assert(VecTy->isVectorTy() &&
         (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
          ScalarTy->isPointerTy()) &&
         "Unexpected types for insert element into binop or cmp");

  // Compute scalar vs. vector op costs per instruction category.
  unsigned Opcode = I.getOpcode();
  InstructionCost ScalarOpCost, VectorOpCost;
  if (CI) {
    CmpInst::Predicate Pred = CI->getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
  } else if (UO || BO) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
  } else {
    IntrinsicCostAttributes ScalarICA(
        II->getIntrinsicID(), ScalarTy,
        SmallVector<Type *>(II->arg_size(), ScalarTy));
    ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: ScalarICA, CostKind);
    IntrinsicCostAttributes VectorICA(
        II->getIntrinsicID(), VecTy,
        SmallVector<Type *>(II->arg_size(), VecTy));
    VectorOpCost = TTI.getIntrinsicInstrCost(ICA: VectorICA, CostKind);
  }

  // Fold the vector constants in the original vectors into a new base vector to
  // get more accurate cost modelling.
  Value *NewVecC = nullptr;
  if (CI)
    NewVecC = simplifyCmpInst(Predicate: CI->getPredicate(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
  else if (UO)
    NewVecC =
        simplifyUnOp(Opcode: UO->getOpcode(), Op: VecCs[0], FMF: UO->getFastMathFlags(), Q: SQ);
  else if (BO)
    NewVecC = simplifyBinOp(Opcode: BO->getOpcode(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
  else if (II)
    NewVecC = simplifyCall(Call: II, Callee: II->getCalledOperand(), Args: VecCs, Q: SQ);

  // Bail if the base-vector op does not constant fold to a single value.
  if (!NewVecC)
    return false;

  // Get cost estimate for the insert element. This cost will factor into
  // both sequences.
  InstructionCost OldCost = VectorOpCost;
  InstructionCost NewCost =
      ScalarOpCost + TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy,
                                            CostKind, Index: *Index, Op0: NewVecC);

  for (auto [Idx, Op, VecC, Scalar] : enumerate(First&: Ops, Rest&: VecCs, Rest&: ScalarOps)) {
    if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
                              ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI)))
      continue;
    InstructionCost InsertCost = TTI.getVectorInstrCost(
        Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: *Index, Op0: VecC, Op1: Scalar);
    OldCost += InsertCost;
    // Multi-use inserts survive the transform, so charge them to NewCost too.
    NewCost += !Op->hasOneUse() * InsertCost;
  }

  // We want to scalarize unless the vector variant actually has lower cost.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
  // inselt NewVecC, (scalar_op V0, V1), Index
  if (CI)
    ++NumScalarCmp;
  else if (UO || BO)
    ++NumScalarOps;
  else
    ++NumScalarIntrinsic;

  // For constant cases, extract the scalar element, this should constant fold.
  for (auto [OpIdx, Scalar, VecC] : enumerate(First&: ScalarOps, Rest&: VecCs))
    if (!Scalar)
      ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
          Vec: cast<Constant>(Val: VecC), Idx: Builder.getInt64(C: *Index));

  Value *Scalar;
  if (CI)
    Scalar = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: ScalarOps[0], RHS: ScalarOps[1]);
  else if (UO || BO)
    Scalar = Builder.CreateNAryOp(Opc: Opcode, Ops: ScalarOps);
  else
    Scalar = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: II->getIntrinsicID(), Args: ScalarOps);

  Scalar->setName(I.getName() + ".scalar");

  // All IR flags are safe to back-propagate. There is no potential for extra
  // poison to be created by the scalar instruction.
  if (auto *ScalarInst = dyn_cast<Instruction>(Val: Scalar))
    ScalarInst->copyIRFlags(V: &I);

  Value *Insert = Builder.CreateInsertElement(Vec: NewVecC, NewElt: Scalar, Idx: *Index);
  replaceValue(Old&: I, New&: *Insert);
  return true;
}
1454
1455/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
bool VectorCombine::foldExtractedCmps(Instruction &I) {
  auto *BI = dyn_cast<BinaryOperator>(Val: &I);

  // We are looking for a scalar binop of booleans.
  // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
  if (!BI || !I.getType()->isIntegerTy(Bitwidth: 1))
    return false;

  // The compare predicates should match, and each compare should have a
  // constant operand.
  Value *B0 = I.getOperand(i: 0), *B1 = I.getOperand(i: 1);
  Instruction *I0, *I1;
  Constant *C0, *C1;
  CmpPredicate P0, P1;
  if (!match(V: B0, P: m_Cmp(Pred&: P0, L: m_Instruction(I&: I0), R: m_Constant(C&: C0))) ||
      !match(V: B1, P: m_Cmp(Pred&: P1, L: m_Instruction(I&: I1), R: m_Constant(C&: C1))))
    return false;

  auto MatchingPred = CmpPredicate::getMatching(A: P0, B: P1);
  if (!MatchingPred)
    return false;

  // The compare operands must be extracts of the same vector with constant
  // extract indexes.
  Value *X;
  uint64_t Index0, Index1;
  if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ConstantInt(V&: Index0))) ||
      !match(V: I1, P: m_ExtractElt(Val: m_Specific(V: X), Idx: m_ConstantInt(V&: Index1))))
    return false;

  auto *Ext0 = cast<ExtractElementInst>(Val: I0);
  auto *Ext1 = cast<ExtractElementInst>(Val: I1);
  // Decide which of the two extracts is better replaced by a shuffle; the
  // other extract's lane becomes the "cheap" lane the result is read from.
  ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex: CostKind);
  if (!ConvertToShuf)
    return false;
  assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
         "Unknown ExtractElementInst");

  // The original scalar pattern is:
  // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
  CmpInst::Predicate Pred = *MatchingPred;
  unsigned CmpOpcode =
      CmpInst::isFPPredicate(P: Pred) ? Instruction::FCmp : Instruction::ICmp;
  auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
  if (!VecTy)
    return false;

  InstructionCost Ext0Cost =
      TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
  InstructionCost Ext1Cost =
      TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
  InstructionCost CmpCost = TTI.getCmpSelInstrCost(
      Opcode: CmpOpcode, ValTy: I0->getType(), CondTy: CmpInst::makeCmpResultType(opnd_type: I0->getType()), VecPred: Pred,
      CostKind);

  // Old form: 2 extracts + 2 scalar compares + 1 scalar binop.
  InstructionCost OldCost =
      Ext0Cost + Ext1Cost + CmpCost * 2 +
      TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: I.getType(), CostKind);

  // The proposed vector pattern is:
  // vcmp = cmp Pred X, VecC
  // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
  int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
  int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
  auto *CmpTy = cast<FixedVectorType>(Val: CmpInst::makeCmpResultType(opnd_type: VecTy));
  InstructionCost NewCost = TTI.getCmpSelInstrCost(
      Opcode: CmpOpcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
  // Single-source shuffle that moves the expensive lane onto the cheap lane;
  // all other lanes are don't-care (poison).
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
  ShufMask[CheapIndex] = ExpensiveIndex;
  NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: CmpTy,
                                SrcTy: CmpTy, Mask: ShufMask, CostKind);
  NewCost += TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: CmpTy, CostKind);
  NewCost += TTI.getVectorInstrCost(I: *Ext0, Val: CmpTy, CostKind, Index: CheapIndex);
  // If an extract has other users it must be kept, so its cost is not saved.
  NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
  NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;

  // Aggressively form vector ops if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // Create a vector constant from the 2 scalar constants.
  SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
                                   PoisonValue::get(T: VecTy->getElementType()));
  CmpC[Index0] = C0;
  CmpC[Index1] = C1;
  Value *VCmp = Builder.CreateCmp(Pred, LHS: X, RHS: ConstantVector::get(V: CmpC));
  Value *Shuf = createShiftShuffle(Vec: VCmp, OldIndex: ExpensiveIndex, NewIndex: CheapIndex, Builder);
  // Keep the shuffled compare on the side the original expensive extract was.
  Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
  Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
  Value *VecLogic = Builder.CreateBinOp(Opc: BI->getOpcode(), LHS, RHS);
  Value *NewExt = Builder.CreateExtractElement(Vec: VecLogic, Idx: CheapIndex);
  replaceValue(Old&: I, New&: *NewExt);
  ++NumVecCmpBO;
  return true;
}
1555
1556/// Try to fold scalar selects that select between extracted elements and zero
1557/// into extracting from a vector select. This is rooted at the bitcast.
1558///
1559/// This pattern arises when a vector is bitcast to a smaller element type,
1560/// elements are extracted, and then conditionally selected with zero:
1561///
1562/// %bc = bitcast <4 x i32> %src to <16 x i8>
1563/// %e0 = extractelement <16 x i8> %bc, i32 0
1564/// %s0 = select i1 %cond, i8 %e0, i8 0
1565/// %e1 = extractelement <16 x i8> %bc, i32 1
1566/// %s1 = select i1 %cond, i8 %e1, i8 0
1567/// ...
1568///
1569/// Transforms to:
1570/// %sel = select i1 %cond, <4 x i32> %src, <4 x i32> zeroinitializer
1571/// %bc = bitcast <4 x i32> %sel to <16 x i8>
1572/// %e0 = extractelement <16 x i8> %bc, i32 0
1573/// %e1 = extractelement <16 x i8> %bc, i32 1
1574/// ...
1575///
1576/// This is profitable because vector select on wider types produces fewer
1577/// select/cndmask instructions than scalar selects on each element.
1578bool VectorCombine::foldSelectsFromBitcast(Instruction &I) {
1579 auto *BC = dyn_cast<BitCastInst>(Val: &I);
1580 if (!BC)
1581 return false;
1582
1583 FixedVectorType *SrcVecTy = dyn_cast<FixedVectorType>(Val: BC->getSrcTy());
1584 FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(Val: BC->getDestTy());
1585 if (!SrcVecTy || !DstVecTy)
1586 return false;
1587
1588 // Source must be 32-bit or 64-bit elements, destination must be smaller
1589 // integer elements. Zero in all these types is all-bits-zero.
1590 Type *SrcEltTy = SrcVecTy->getElementType();
1591 Type *DstEltTy = DstVecTy->getElementType();
1592 unsigned SrcEltBits = SrcEltTy->getPrimitiveSizeInBits();
1593 unsigned DstEltBits = DstEltTy->getPrimitiveSizeInBits();
1594
1595 if (SrcEltBits != 32 && SrcEltBits != 64)
1596 return false;
1597
1598 if (!DstEltTy->isIntegerTy() || DstEltBits >= SrcEltBits)
1599 return false;
1600
1601 // Check profitability using TTI before collecting users.
1602 Type *CondTy = CmpInst::makeCmpResultType(opnd_type: DstEltTy);
1603 Type *VecCondTy = CmpInst::makeCmpResultType(opnd_type: SrcVecTy);
1604
1605 InstructionCost ScalarSelCost =
1606 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: DstEltTy, CondTy,
1607 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1608 InstructionCost VecSelCost =
1609 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: SrcVecTy, CondTy: VecCondTy,
1610 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1611
1612 // We need at least this many selects for vectorization to be profitable.
1613 // VecSelCost < ScalarSelCost * NumSelects => NumSelects > VecSelCost /
1614 // ScalarSelCost
1615 if (!ScalarSelCost.isValid() || ScalarSelCost == 0)
1616 return false;
1617
1618 unsigned MinSelects = (VecSelCost.getValue() / ScalarSelCost.getValue()) + 1;
1619
1620 // Quick check: if bitcast doesn't have enough users, bail early.
1621 if (!BC->hasNUsesOrMore(N: MinSelects))
1622 return false;
1623
1624 // Collect all select users that match the pattern, grouped by condition.
1625 // Pattern: select i1 %cond, (extractelement %bc, idx), 0
1626 DenseMap<Value *, SmallVector<SelectInst *, 8>> CondToSelects;
1627
1628 for (User *U : BC->users()) {
1629 auto *Ext = dyn_cast<ExtractElementInst>(Val: U);
1630 if (!Ext)
1631 continue;
1632
1633 for (User *ExtUser : Ext->users()) {
1634 Value *Cond;
1635 // Match: select i1 %cond, %ext, 0
1636 if (match(V: ExtUser, P: m_Select(C: m_Value(V&: Cond), L: m_Specific(V: Ext), R: m_Zero())) &&
1637 Cond->getType()->isIntegerTy(Bitwidth: 1))
1638 CondToSelects[Cond].push_back(Elt: cast<SelectInst>(Val: ExtUser));
1639 }
1640 }
1641
1642 if (CondToSelects.empty())
1643 return false;
1644
1645 bool MadeChange = false;
1646 Value *SrcVec = BC->getOperand(i_nocapture: 0);
1647
1648 // Process each group of selects with the same condition.
1649 for (auto [Cond, Selects] : CondToSelects) {
1650 // Only profitable if vector select cost < total scalar select cost.
1651 if (Selects.size() < MinSelects) {
1652 LLVM_DEBUG(dbgs() << "VectorCombine: foldSelectsFromBitcast not "
1653 << "profitable (VecCost=" << VecSelCost
1654 << ", ScalarCost=" << ScalarSelCost
1655 << ", NumSelects=" << Selects.size() << ")\n");
1656 continue;
1657 }
1658
1659 // Create the vector select and bitcast once for this condition.
1660 auto InsertPt = std::next(x: BC->getIterator());
1661
1662 if (auto *CondInst = dyn_cast<Instruction>(Val: Cond))
1663 if (DT.dominates(Def: BC, User: CondInst))
1664 InsertPt = std::next(x: CondInst->getIterator());
1665
1666 Builder.SetInsertPoint(InsertPt);
1667 Value *VecSel =
1668 Builder.CreateSelect(C: Cond, True: SrcVec, False: Constant::getNullValue(Ty: SrcVecTy));
1669 Value *NewBC = Builder.CreateBitCast(V: VecSel, DestTy: DstVecTy);
1670
1671 // Replace each scalar select with an extract from the new bitcast.
1672 for (SelectInst *Sel : Selects) {
1673 auto *Ext = cast<ExtractElementInst>(Val: Sel->getTrueValue());
1674 Value *Idx = Ext->getIndexOperand();
1675
1676 Builder.SetInsertPoint(Sel);
1677 Value *NewExt = Builder.CreateExtractElement(Vec: NewBC, Idx);
1678 replaceValue(Old&: *Sel, New&: *NewExt);
1679 MadeChange = true;
1680 }
1681
1682 LLVM_DEBUG(dbgs() << "VectorCombine: folded " << Selects.size()
1683 << " selects into vector select\n");
1684 }
1685
1686 return MadeChange;
1687}
1688
/// Split the cost of the vector reduction intrinsic \p II into the cost of
/// the code feeding the reduction (\p CostBeforeReduction) and the cost of
/// the reduction itself (\p CostAfterReduction). Recognizes reduce(ext(...))
/// and reduce.add(ext(mul(ext(A), ext(B)))) so targets with fused
/// extended/multiply-accumulate reductions are costed accurately.
static void analyzeCostOfVecReduction(const IntrinsicInst &II,
                                      TTI::TargetCostKind CostKind,
                                      const TargetTransformInfo &TTI,
                                      InstructionCost &CostBeforeReduction,
                                      InstructionCost &CostAfterReduction) {
  Instruction *Op0, *Op1;
  auto *RedOp = dyn_cast<Instruction>(Val: II.getOperand(i_nocapture: 0));
  auto *VecRedTy = cast<VectorType>(Val: II.getOperand(i_nocapture: 0)->getType());
  unsigned ReductionOpc =
      getArithmeticReductionInstruction(RdxID: II.getIntrinsicID());
  // reduce(zext/sext(X)): cost the extend separately, and use the target's
  // extended-reduction cost for the reduction itself.
  if (RedOp && match(V: RedOp, P: m_ZExtOrSExt(Op: m_Value()))) {
    bool IsUnsigned = isa<ZExtInst>(Val: RedOp);
    auto *ExtType = cast<VectorType>(Val: RedOp->getOperand(i: 0)->getType());

    CostBeforeReduction =
        TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: ExtType,
                             CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
    CostAfterReduction =
        TTI.getExtendedReductionCost(Opcode: ReductionOpc, IsUnsigned, ResTy: II.getType(),
                                     Ty: ExtType, FMF: FastMathFlags(), CostKind);
    return;
  }
  // reduce.add(ext(mul(ext(A), ext(B)))): both inner extends must agree in
  // opcode and source type, and the outer extend must match the inner ones
  // (or the multiply squares a single value, Op0 == Op1).
  if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
      match(V: RedOp,
            P: m_ZExtOrSExt(Op: m_Mul(L: m_Instruction(I&: Op0), R: m_Instruction(I&: Op1)))) &&
      match(V: Op0, P: m_ZExtOrSExt(Op: m_Value())) &&
      Op0->getOpcode() == Op1->getOpcode() &&
      Op0->getOperand(i: 0)->getType() == Op1->getOperand(i: 0)->getType() &&
      (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
    // Matched reduce.add(ext(mul(ext(A), ext(B)))
    bool IsUnsigned = isa<ZExtInst>(Val: Op0);
    auto *ExtType = cast<VectorType>(Val: Op0->getOperand(i: 0)->getType());
    VectorType *MulType = VectorType::get(ElementType: Op0->getType(), Other: VecRedTy);

    InstructionCost ExtCost =
        TTI.getCastInstrCost(Opcode: Op0->getOpcode(), Dst: MulType, Src: ExtType,
                             CCH: TTI::CastContextHint::None, CostKind, I: Op0);
    InstructionCost MulCost =
        TTI.getArithmeticInstrCost(Opcode: Instruction::Mul, Ty: MulType, CostKind);
    InstructionCost Ext2Cost =
        TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: MulType,
                             CCH: TTI::CastContextHint::None, CostKind, I: RedOp);

    // Two inner extends, the multiply, and the outer extend feed the
    // reduction; the reduction itself is costed as a fused mul-acc.
    CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
    CostAfterReduction = TTI.getMulAccReductionCost(
        IsUnsigned, RedOpcode: ReductionOpc, ResTy: II.getType(), Ty: ExtType, CostKind);
    return;
  }
  // Plain reduction: no feeding code to account for separately, so
  // CostBeforeReduction keeps its caller-provided value.
  CostAfterReduction = TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecRedTy,
                                                      FMF: std::nullopt, CostKind);
}
1740
/// Try to merge a binop of two identical vector reductions into one reduction
/// of a vector binop:
///   binop (reduce(V0)), (reduce(V1)) --> reduce(binop V0, V1)
bool VectorCombine::foldBinopOfReductions(Instruction &I) {
  Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(Val: &I)->getOpcode();
  Intrinsic::ID ReductionIID = getReductionForBinop(Opc: BinOpOpc);
  // Special case: sub(reduce.add(a), reduce.add(b)) -> reduce.add(sub(a, b)).
  if (BinOpOpc == Instruction::Sub)
    ReductionIID = Intrinsic::vector_reduce_add;
  if (ReductionIID == Intrinsic::not_intrinsic)
    return false;

  // Return the reduced operand if \p V is a one-use call to intrinsic \p IID,
  // nullptr otherwise. One-use is required so the old reductions die.
  auto checkIntrinsicAndGetItsArgument = [](Value *V,
                                            Intrinsic::ID IID) -> Value * {
    auto *II = dyn_cast<IntrinsicInst>(Val: V);
    if (!II)
      return nullptr;
    if (II->getIntrinsicID() == IID && II->hasOneUse())
      return II->getArgOperand(i: 0);
    return nullptr;
  };

  Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 0), ReductionIID);
  if (!V0)
    return false;
  Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 1), ReductionIID);
  if (!V1)
    return false;

  // Both reductions must reduce vectors of the same type.
  auto *VTy = cast<VectorType>(Val: V0->getType());
  if (V1->getType() != VTy)
    return false;
  const auto &II0 = *cast<IntrinsicInst>(Val: I.getOperand(i: 0));
  const auto &II1 = *cast<IntrinsicInst>(Val: I.getOperand(i: 1));
  unsigned ReductionOpc =
      getArithmeticReductionInstruction(RdxID: II0.getIntrinsicID());

  // Old: both reductions (with any fused feeding code) plus the scalar binop.
  // New: the feeding code, one vector binop, and a single reduction.
  InstructionCost OldCost = 0;
  InstructionCost NewCost = 0;
  InstructionCost CostOfRedOperand0 = 0;
  InstructionCost CostOfRed0 = 0;
  InstructionCost CostOfRedOperand1 = 0;
  InstructionCost CostOfRed1 = 0;
  analyzeCostOfVecReduction(II: II0, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand0, CostAfterReduction&: CostOfRed0);
  analyzeCostOfVecReduction(II: II1, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand1, CostAfterReduction&: CostOfRed1);
  OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(U: &I, CostKind);
  NewCost =
      CostOfRedOperand0 + CostOfRedOperand1 +
      TTI.getArithmeticInstrCost(Opcode: BinOpOpc, Ty: VTy, CostKind) +
      TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VTy, FMF: std::nullopt, CostKind);
  if (NewCost >= OldCost || !NewCost.isValid())
    return false;

  LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  Value *VectorBO;
  // Preserve the disjoint flag when merging or-reductions.
  if (BinOpOpc == Instruction::Or)
    VectorBO = Builder.CreateOr(LHS: V0, RHS: V1, Name: "",
                                IsDisjoint: cast<PossiblyDisjointInst>(Val&: I).isDisjoint());
  else
    VectorBO = Builder.CreateBinOp(Opc: BinOpOpc, LHS: V0, RHS: V1);

  Instruction *Rdx = Builder.CreateIntrinsic(ID: ReductionIID, Types: {VTy}, Args: {VectorBO});
  replaceValue(Old&: I, New&: *Rdx);
  return true;
}
1804
1805// Check if memory loc modified between two instrs in the same BB
1806static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1807 BasicBlock::iterator End,
1808 const MemoryLocation &Loc, AAResults &AA) {
1809 unsigned NumScanned = 0;
1810 return std::any_of(first: Begin, last: End, pred: [&](const Instruction &Instr) {
1811 return isModSet(MRI: AA.getModRefInfo(I: &Instr, OptLoc: Loc)) ||
1812 ++NumScanned > MaxInstrsToScan;
1813 });
1814}
1815
namespace {
/// Helper class to indicate whether a vector index can be safely scalarized
/// and whether a freeze needs to be inserted.
class ScalarizationResult {
  enum class StatusTy { Unsafe, Safe, SafeWithFreeze };

  // Current classification of the index.
  StatusTy Status;
  // Value that must be frozen before use; only set for SafeWithFreeze.
  Value *ToFreeze;

  ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
      : Status(Status), ToFreeze(ToFreeze) {}

public:
  ScalarizationResult(const ScalarizationResult &Other) = default;
  // Owners must consume a pending freeze via freeze() or discard() before
  // destruction; the assert catches forgotten handling.
  ~ScalarizationResult() {
    assert(!ToFreeze && "freeze() not called with ToFreeze being set");
  }

  static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
  static ScalarizationResult safe() { return {StatusTy::Safe}; }
  static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
    return {StatusTy::SafeWithFreeze, ToFreeze};
  }

  /// Returns true if the index can be scalarized without requiring a freeze.
  bool isSafe() const { return Status == StatusTy::Safe; }
  /// Returns true if the index cannot be scalarized.
  bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
  /// freeze.
  bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }

  /// Reset the state to Unsafe and clear ToFreeze if set.
  void discard() {
    ToFreeze = nullptr;
    Status = StatusTy::Unsafe;
  }

  /// Freeze the ToFreeze value and update its use in \p UserI to the frozen
  /// value.
  void freeze(IRBuilderBase &Builder, Instruction &UserI) {
    assert(isSafeWithFreeze() &&
           "should only be used when freezing is required");
    assert(is_contained(ToFreeze->users(), &UserI) &&
           "UserI must be a user of ToFreeze");
    IRBuilder<>::InsertPointGuard Guard(Builder);
    Builder.SetInsertPoint(cast<Instruction>(Val: &UserI));
    Value *Frozen =
        Builder.CreateFreeze(V: ToFreeze, Name: ToFreeze->getName() + ".frozen");
    // Rewrite every operand of UserI that referenced the unfrozen value.
    for (Use &U : make_early_inc_range(Range: (UserI.operands())))
      if (U.get() == ToFreeze)
        U.set(Frozen);

    ToFreeze = nullptr;
  }
};
} // namespace
1872
/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
/// Idx. \p Idx must access a valid vector element.
static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
                                              Instruction *CtxI,
                                              AssumptionCache &AC,
                                              const DominatorTree &DT) {
  // We do checks for both fixed vector types and scalable vector types.
  // This is the number of elements of fixed vector types,
  // or the minimum number of elements of scalable vector types.
  uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
  unsigned IntWidth = Idx->getType()->getScalarSizeInBits();

  // A constant index is safe iff it is in bounds.
  if (auto *C = dyn_cast<ConstantInt>(Val: Idx)) {
    if (C->getValue().ult(RHS: NumElements))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  // Always unsafe if the index type can't handle all inbound values.
  if (!llvm::isUIntN(N: IntWidth, x: NumElements))
    return ScalarizationResult::unsafe();

  APInt Zero(IntWidth, 0);
  APInt MaxElts(IntWidth, NumElements);
  ConstantRange ValidIndices(Zero, MaxElts);
  // Start with the full range; only narrowed by the And/URem patterns below.
  ConstantRange IdxRange(IntWidth, true);

  // A provably non-poison index is safe if value-range analysis bounds it.
  if (isGuaranteedNotToBePoison(V: Idx, AC: &AC)) {
    if (ValidIndices.contains(CR: computeConstantRange(V: Idx, /* ForSigned */ false,
                                                   UseInstrInfo: true, AC: &AC, CtxI, DT: &DT)))
      return ScalarizationResult::safe();
    return ScalarizationResult::unsafe();
  }

  // If the index may be poison, check if we can insert a freeze before the
  // range of the index is restricted.
  Value *IdxBase;
  ConstantInt *CI;
  if (match(V: Idx, P: m_And(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
    IdxRange = IdxRange.binaryAnd(Other: CI->getValue());
  } else if (match(V: Idx, P: m_URem(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
    IdxRange = IdxRange.urem(Other: CI->getValue());
  }

  // If neither pattern matched, IdxRange is still the full set and, because
  // NumElements < 2^IntWidth (checked above), the containment test fails, so
  // IdxBase is never read uninitialized.
  if (ValidIndices.contains(CR: IdxRange))
    return ScalarizationResult::safeWithFreeze(ToFreeze: IdxBase);
  return ScalarizationResult::unsafe();
}
1921
1922/// The memory operation on a vector of \p ScalarType had alignment of
1923/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1924/// alignment that will be valid for the memory operation on a single scalar
1925/// element of the same type with index \p Idx.
1926static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1927 Type *ScalarType, Value *Idx,
1928 const DataLayout &DL) {
1929 if (auto *C = dyn_cast<ConstantInt>(Val: Idx))
1930 return commonAlignment(A: VectorAlignment,
1931 Offset: C->getZExtValue() * DL.getTypeStoreSize(Ty: ScalarType));
1932 return commonAlignment(A: VectorAlignment, Offset: DL.getTypeStoreSize(Ty: ScalarType));
1933}
1934
// Combine patterns like:
// %0 = load <4 x i32>, <4 x i32>* %a
// %1 = insertelement <4 x i32> %0, i32 %b, i32 1
// store <4 x i32> %1, <4 x i32>* %a
// to:
// %0 = bitcast <4 x i32>* %a to i32*
// %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
// store i32 %b, i32* %1
bool VectorCombine::foldSingleElementStore(Instruction &I) {
  // Bail on targets where addressing single lanes via GEP is not desirable.
  if (!TTI.allowVectorElementIndexingUsingGEP())
    return false;
  auto *SI = cast<StoreInst>(Val: &I);
  if (!SI->isSimple() || !isa<VectorType>(Val: SI->getValueOperand()->getType()))
    return false;

  // TODO: Combine more complicated patterns (multiple insert) by referencing
  // TargetTransformInfo.
  Instruction *Source;
  Value *NewElement;
  Value *Idx;
  if (!match(V: SI->getValueOperand(),
             P: m_InsertElt(Val: m_Instruction(I&: Source), Elt: m_Value(V&: NewElement),
                          Idx: m_Value(V&: Idx))))
    return false;

  if (auto *Load = dyn_cast<LoadInst>(Val: Source)) {
    auto VecTy = cast<VectorType>(Val: SI->getValueOperand()->getType());
    Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
    // Don't optimize for atomic/volatile load or store. Ensure memory is not
    // modified between, vector type matches store size, and index is inbounds.
    if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
        !DL->typeSizeEqualsStoreSize(Ty: Load->getType()->getScalarType()) ||
        SrcAddr != SI->getPointerOperand()->stripPointerCasts())
      return false;

    auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, CtxI: Load, AC, DT);
    if (ScalarizableIdx.isUnsafe() ||
        isMemModifiedBetween(Begin: Load->getIterator(), End: SI->getIterator(),
                             Loc: MemoryLocation::get(SI), AA))
      return false;

    // Ensure we add the load back to the worklist BEFORE its users so they
    // can be erased in the correct order.
    Worklist.push(I: Load);

    // A possibly-poison index must be frozen before it is used for the GEP.
    if (ScalarizableIdx.isSafeWithFreeze())
      ScalarizableIdx.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
    // Address the stored lane directly: gep <vecty> %a, 0, Idx.
    Value *GEP = Builder.CreateInBoundsGEP(
        Ty: SI->getValueOperand()->getType(), Ptr: SI->getPointerOperand(),
        IdxList: {ConstantInt::get(Ty: Idx->getType(), V: 0), Idx});
    StoreInst *NSI = Builder.CreateStore(Val: NewElement, Ptr: GEP);
    NSI->copyMetadata(SrcInst: *SI);
    // The scalar store only keeps the alignment implied by the lane offset.
    Align ScalarOpAlignment = computeAlignmentAfterScalarization(
        VectorAlignment: std::max(a: SI->getAlign(), b: Load->getAlign()), ScalarType: NewElement->getType(), Idx,
        DL: *DL);
    NSI->setAlignment(ScalarOpAlignment);
    replaceValue(Old&: I, New&: *NSI);
    eraseInstruction(I);
    return true;
  }

  return false;
}
1998
/// Try to scalarize vector loads feeding extractelement or bitcast
/// instructions. Classifies the load's users and dispatches to the matching
/// scalarization strategy.
bool VectorCombine::scalarizeLoad(Instruction &I) {
  Value *Ptr;
  if (!match(V: &I, P: m_Load(Op: m_Value(V&: Ptr))))
    return false;

  auto *LI = cast<LoadInst>(Val: &I);
  auto *VecTy = cast<VectorType>(Val: LI->getType());
  // Volatile loads can't be split; a scalar type whose store size differs
  // from its bit size would change the bytes read per element.
  if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(Ty: VecTy->getScalarType()))
    return false;

  bool AllExtracts = true;
  bool AllBitcasts = true;
  Instruction *LastCheckedInst = LI;
  unsigned NumInstChecked = 0;

  // Check what type of users we have (must either all be extracts or
  // bitcasts) and ensure no memory modifications between the load and
  // its users.
  for (User *U : LI->users()) {
    auto *UI = dyn_cast<Instruction>(Val: U);
    if (!UI || UI->getParent() != LI->getParent())
      return false;

    // If any user is waiting to be erased, then bail out as this will
    // distort the cost calculation and possibly lead to infinite loops.
    if (UI->use_empty())
      return false;

    if (!isa<ExtractElementInst>(Val: UI))
      AllExtracts = false;
    if (!isa<BitCastInst>(Val: UI))
      AllBitcasts = false;

    // Check if any instruction between the load and the user may modify
    // memory. Only rescan when this user lies beyond the furthest user
    // checked so far.
    if (LastCheckedInst->comesBefore(Other: UI)) {
      for (Instruction &I :
           make_range(x: std::next(x: LI->getIterator()), y: UI->getIterator())) {
        // Bail out if we reached the check limit or the instruction may write
        // to memory.
        if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
          return false;
        NumInstChecked++;
      }
      LastCheckedInst = UI;
    }
  }

  // Dispatch based on the (homogeneous) user kind.
  if (AllExtracts)
    return scalarizeLoadExtract(LI, VecTy, Ptr);
  if (AllBitcasts)
    return scalarizeLoadBitcast(LI, VecTy, Ptr);
  return false;
}
2054
/// Try to scalarize vector loads feeding extractelement instructions by
/// replacing each extract with a narrow scalar load of the addressed lane.
bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
                                         Value *Ptr) {
  if (!TTI.allowVectorElementIndexingUsingGEP())
    return false;

  // Extracts whose index needs a freeze before use, keyed by the extract.
  DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
  llvm::scope_exit FailureGuard([&]() {
    // If the transform is aborted, discard the ScalarizationResults.
    for (auto &Pair : NeedFreeze)
      Pair.second.discard();
  });

  InstructionCost OriginalCost =
      TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
                          AddressSpace: LI->getPointerAddressSpace(), CostKind);
  InstructionCost ScalarizedCost = 0;

  for (User *U : LI->users()) {
    auto *UI = cast<ExtractElementInst>(Val: U);

    auto ScalarIdx =
        canScalarizeAccess(VecTy, Idx: UI->getIndexOperand(), CtxI: LI, AC, DT);
    if (ScalarIdx.isUnsafe())
      return false;
    if (ScalarIdx.isSafeWithFreeze()) {
      NeedFreeze.try_emplace(Key: UI, Args&: ScalarIdx);
      ScalarIdx.discard();
    }

    // -1 signals an unknown lane for variable extract indexes.
    auto *Index = dyn_cast<ConstantInt>(Val: UI->getIndexOperand());
    OriginalCost +=
        TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: VecTy, CostKind,
                               Index: Index ? Index->getZExtValue() : -1);
    // Conservatively cost the scalar loads with align 1; the real alignment
    // is recomputed per lane when the loads are materialized below.
    ScalarizedCost +=
        TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy->getElementType(),
                            Alignment: Align(1), AddressSpace: LI->getPointerAddressSpace(), CostKind);
    ScalarizedCost += TTI.getAddressComputationCost(PtrTy: LI->getPointerOperandType(),
                                                    SE: nullptr, Ptr: nullptr, CostKind);
  }

  LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
                    << "\n LoadExtractCost: " << OriginalCost
                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");

  if (ScalarizedCost >= OriginalCost)
    return false;

  // Ensure we add the load back to the worklist BEFORE its users so they can
  // be erased in the correct order.
  Worklist.push(I: LI);

  Type *ElemType = VecTy->getElementType();

  // Replace extracts with narrow scalar loads.
  for (User *U : LI->users()) {
    auto *EI = cast<ExtractElementInst>(Val: U);
    Value *Idx = EI->getIndexOperand();

    // Insert 'freeze' for poison indexes.
    auto It = NeedFreeze.find(Val: EI);
    if (It != NeedFreeze.end())
      It->second.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));

    Builder.SetInsertPoint(EI);
    Value *GEP =
        Builder.CreateInBoundsGEP(Ty: VecTy, Ptr, IdxList: {Builder.getInt32(C: 0), Idx});
    auto *NewLoad = cast<LoadInst>(
        Val: Builder.CreateLoad(Ty: ElemType, Ptr: GEP, Name: EI->getName() + ".scalar"));

    Align ScalarOpAlignment =
        computeAlignmentAfterScalarization(VectorAlignment: LI->getAlign(), ScalarType: ElemType, Idx, DL: *DL);
    NewLoad->setAlignment(ScalarOpAlignment);

    // With a constant lane, rebase the AA metadata to the element's byte
    // offset within the original vector access.
    if (auto *ConstIdx = dyn_cast<ConstantInt>(Val: Idx)) {
      size_t Offset = ConstIdx->getZExtValue() * DL->getTypeStoreSize(Ty: ElemType);
      AAMDNodes OldAAMD = LI->getAAMetadata();
      NewLoad->setAAMetadata(OldAAMD.adjustForAccess(Offset, AccessTy: ElemType, DL: *DL));
    }

    replaceValue(Old&: *EI, New&: *NewLoad, Erase: false);
  }

  // Success: keep the changes; all pending freezes were consumed above.
  FailureGuard.release();
  return true;
}
2141
/// Try to scalarize vector loads feeding bitcast instructions by loading a
/// single scalar of the bitcast destination type instead.
bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
                                         Value *Ptr) {
  InstructionCost OriginalCost =
      TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
                          AddressSpace: LI->getPointerAddressSpace(), CostKind);

  // All bitcast users must cast the vector's full width to one common
  // integer or floating-point scalar type.
  Type *TargetScalarType = nullptr;
  unsigned VecBitWidth = DL->getTypeSizeInBits(Ty: VecTy);

  for (User *U : LI->users()) {
    auto *BC = cast<BitCastInst>(Val: U);

    Type *DestTy = BC->getDestTy();
    if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
      return false;

    // A partial-width cast can't be served by a single scalar load.
    unsigned DestBitWidth = DL->getTypeSizeInBits(Ty: DestTy);
    if (DestBitWidth != VecBitWidth)
      return false;

    // All bitcasts must target the same scalar type.
    if (!TargetScalarType)
      TargetScalarType = DestTy;
    else if (TargetScalarType != DestTy)
      return false;

    OriginalCost +=
        TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: TargetScalarType, Src: VecTy,
                             CCH: TTI.getCastContextHint(I: BC), CostKind, I: BC);
  }

  if (!TargetScalarType)
    return false;

  assert(!LI->user_empty() && "Unexpected load without bitcast users");
  // New form: one scalar load; all bitcasts disappear.
  InstructionCost ScalarizedCost =
      TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: TargetScalarType, Alignment: LI->getAlign(),
                          AddressSpace: LI->getPointerAddressSpace(), CostKind);

  LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
                    << "\n OriginalCost: " << OriginalCost
                    << " vs ScalarizedCost: " << ScalarizedCost << "\n");

  if (ScalarizedCost >= OriginalCost)
    return false;

  // Ensure we add the load back to the worklist BEFORE its users so they can
  // be erased in the correct order.
  Worklist.push(I: LI);

  Builder.SetInsertPoint(LI);
  auto *ScalarLoad =
      Builder.CreateLoad(Ty: TargetScalarType, Ptr, Name: LI->getName() + ".scalar");
  ScalarLoad->setAlignment(LI->getAlign());
  ScalarLoad->copyMetadata(SrcInst: *LI);

  // Replace all bitcast users with the scalar load.
  for (User *U : LI->users()) {
    auto *BC = cast<BitCastInst>(Val: U);
    replaceValue(Old&: *BC, New&: *ScalarLoad, Erase: false);
  }

  return true;
}
2207
/// Try to scalarize a vector zext whose only users are extractelements with
/// constant lanes, by bitcasting the source vector to one wide integer and
/// computing each lane with a shift and mask.
bool VectorCombine::scalarizeExtExtract(Instruction &I) {
  if (!TTI.allowVectorElementIndexingUsingGEP())
    return false;
  auto *Ext = dyn_cast<ZExtInst>(Val: &I);
  if (!Ext)
    return false;

  // Try to convert a vector zext feeding only extracts to a set of scalar
  //   (Src >> (ExtIdx * EltBitWidth)) & ((1 << EltBitWidth) - 1)
  // operations, if profitable.
  auto *SrcTy = dyn_cast<FixedVectorType>(Val: Ext->getOperand(i_nocapture: 0)->getType());
  if (!SrcTy)
    return false;
  auto *DstTy = cast<FixedVectorType>(Val: Ext->getType());

  // The whole source vector must pack exactly into one destination element's
  // worth of bits, so the packed integer is a single legal-width scalar.
  Type *ScalarDstTy = DstTy->getElementType();
  if (DL->getTypeSizeInBits(Ty: SrcTy) != DL->getTypeSizeInBits(Ty: ScalarDstTy))
    return false;

  // Vector form cost: the zext plus every extract that is still live.
  InstructionCost VectorCost =
      TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: DstTy, Src: SrcTy,
                           CCH: TTI::CastContextHint::None, CostKind, I: Ext);
  unsigned ExtCnt = 0;
  bool ExtLane0 = false;
  for (User *U : Ext->users()) {
    uint64_t Idx;
    if (!match(V: U, P: m_ExtractElt(Val: m_Value(), Idx: m_ConstantInt(V&: Idx))))
      return false;
    if (cast<Instruction>(Val: U)->use_empty())
      continue;
    ExtCnt += 1;
    ExtLane0 |= !Idx;
    VectorCost += TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: DstTy,
                                         CostKind, Index: Idx, Op0: U);
  }

  // Scalar form cost: one mask per extract, plus one shift per extract
  // except lane 0 (whose shift amount would be zero).
  InstructionCost ScalarCost =
      ExtCnt * TTI.getArithmeticInstrCost(
                   Opcode: Instruction::And, Ty: ScalarDstTy, CostKind,
                   Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
                   Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None}) +
      (ExtCnt - ExtLane0) *
          TTI.getArithmeticInstrCost(
              Opcode: Instruction::LShr, Ty: ScalarDstTy, CostKind,
              Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
              Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None});
  if (ScalarCost > VectorCost)
    return false;

  Value *ScalarV = Ext->getOperand(i_nocapture: 0);
  if (!isGuaranteedNotToBePoison(V: ScalarV, AC: &AC, CtxI: dyn_cast<Instruction>(Val: ScalarV),
                                 DT: &DT)) {
    // Check whether all lanes are extracted, all extracts trigger UB
    // on poison, and the last extract (and hence all previous ones)
    // are guaranteed to execute if Ext executes. If so, we do not
    // need to insert a freeze.
    SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
    bool AllExtractsTriggerUB = true;
    ExtractElementInst *LastExtract = nullptr;
    BasicBlock *ExtBB = Ext->getParent();
    for (User *U : Ext->users()) {
      auto *Extract = cast<ExtractElementInst>(Val: U);
      if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Inst: Extract)) {
        AllExtractsTriggerUB = false;
        break;
      }
      ExtractedLanes.insert(V: cast<ConstantInt>(Val: Extract->getIndexOperand()));
      if (!LastExtract || LastExtract->comesBefore(Other: Extract))
        LastExtract = Extract;
    }
    if (ExtractedLanes.size() != DstTy->getNumElements() ||
        !AllExtractsTriggerUB ||
        !isGuaranteedToTransferExecutionToSuccessor(Begin: Ext->getIterator(),
                                                    End: LastExtract->getIterator()))
      ScalarV = Builder.CreateFreeze(V: ScalarV);
  }
  // View the source vector as one wide integer of the same total width.
  ScalarV = Builder.CreateBitCast(
      V: ScalarV,
      DestTy: IntegerType::get(C&: SrcTy->getContext(), NumBits: DL->getTypeSizeInBits(Ty: SrcTy)));
  uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(Ty: SrcTy->getElementType());
  uint64_t TotalBits = DL->getTypeSizeInBits(Ty: SrcTy);
  APInt EltBitMask = APInt::getLowBitsSet(numBits: TotalBits, loBitsSet: SrcEltSizeInBits);
  Type *PackedTy = IntegerType::get(C&: SrcTy->getContext(), NumBits: TotalBits);
  Value *Mask = ConstantInt::get(Ty: PackedTy, V: EltBitMask);
  for (User *U : Ext->users()) {
    auto *Extract = cast<ExtractElementInst>(Val: U);
    uint64_t Idx =
        cast<ConstantInt>(Val: Extract->getIndexOperand())->getZExtValue();
    // Big-endian targets keep lane 0 in the most significant bits.
    uint64_t ShiftAmt =
        DL->isBigEndian()
            ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
            : (Idx * SrcEltSizeInBits);
    Value *LShr = Builder.CreateLShr(LHS: ScalarV, RHS: ShiftAmt);
    Value *And = Builder.CreateAnd(LHS: LShr, RHS: Mask);
    U->replaceAllUsesWith(V: And);
  }
  return true;
}
2306
/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
/// to "(bitcast (concat X, Y))"
/// where X/Y are bitcasted from i1 mask vectors.
bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
  Type *Ty = I.getType();
  // The fold only makes sense when the OR result is a scalar integer that
  // packs the two bool masks.
  if (!Ty->isIntegerTy())
    return false;

  // TODO: Add big endian test coverage
  if (DL->isBigEndian())
    return false;

  // Restrict to disjoint cases so the mask vectors aren't overlapping.
  Instruction *X, *Y;
  if (!match(V: &I, P: m_DisjointOr(L: m_Instruction(I&: X), R: m_Instruction(I&: Y))))
    return false;

  // Allow both sources to contain shl, to handle more generic pattern:
  // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
  // ShAmtX/ShAmtY stay 0 when the corresponding side has no shift.
  Value *SrcX;
  uint64_t ShAmtX = 0;
  if (!match(V: X, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX)))))) &&
      !match(V: X, P: m_OneUse(
                SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX))))),
                      R: m_ConstantInt(V&: ShAmtX)))))
    return false;

  Value *SrcY;
  uint64_t ShAmtY = 0;
  if (!match(V: Y, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY)))))) &&
      !match(V: Y, P: m_OneUse(
                SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY))))),
                      R: m_ConstantInt(V&: ShAmtY)))))
    return false;

  // Canonicalize larger shift to the RHS.
  if (ShAmtX > ShAmtY) {
    std::swap(a&: X, b&: Y);
    std::swap(a&: SrcX, b&: SrcY);
    std::swap(a&: ShAmtX, b&: ShAmtY);
  }

  // Ensure both sources are matching vXi1 bool mask types, and that the shift
  // difference is the mask width so they can be easily concatenated together.
  // The NumElements <= BitWidth/2 check guarantees the concatenated mask
  // still fits in the result integer.
  uint64_t ShAmtDiff = ShAmtY - ShAmtX;
  unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
  unsigned BitWidth = Ty->getPrimitiveSizeInBits();
  auto *MaskTy = dyn_cast<FixedVectorType>(Val: SrcX->getType());
  if (!MaskTy || SrcX->getType() != SrcY->getType() ||
      !MaskTy->getElementType()->isIntegerTy(Bitwidth: 1) ||
      MaskTy->getNumElements() != ShAmtDiff ||
      MaskTy->getNumElements() > (BitWidth / 2))
    return false;

  auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(VTy: MaskTy);
  auto *ConcatIntTy =
      Type::getIntNTy(C&: Ty->getContext(), N: ConcatTy->getNumElements());
  auto *MaskIntTy = Type::getIntNTy(C&: Ty->getContext(), N: ShAmtDiff);

  // Identity concat mask: elements 0..2N-1 pick X then Y in order.
  SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
  std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);

  // TODO: Is it worth supporting multi use cases?
  InstructionCost OldCost = 0;
  OldCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Or, Ty, CostKind);
  OldCost +=
      NumSHL * TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
  OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: MaskIntTy,
                                      CCH: TTI::CastContextHint::None, CostKind);
  OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: MaskIntTy, Src: MaskTy,
                                      CCH: TTI::CastContextHint::None, CostKind);

  InstructionCost NewCost = 0;
  NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ConcatTy,
                                SrcTy: MaskTy, Mask: ConcatMask, CostKind);
  NewCost += TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: ConcatIntTy, Src: ConcatTy,
                                  CCH: TTI::CastContextHint::None, CostKind);
  // A residual zext is only needed when the concatenated integer is narrower
  // than the original result type.
  if (Ty != ConcatIntTy)
    NewCost += TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: ConcatIntTy,
                                    CCH: TTI::CastContextHint::None, CostKind);
  // A residual shift remains when both sides were shifted (common ShAmtX).
  if (ShAmtX > 0)
    NewCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);

  LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  // Build bool mask concatenation, bitcast back to scalar integer, and perform
  // any residual zero-extension or shifting.
  Value *Concat = Builder.CreateShuffleVector(V1: SrcX, V2: SrcY, Mask: ConcatMask);
  Worklist.pushValue(V: Concat);

  Value *Result = Builder.CreateBitCast(V: Concat, DestTy: ConcatIntTy);

  if (Ty != ConcatIntTy) {
    Worklist.pushValue(V: Result);
    Result = Builder.CreateZExt(V: Result, DestTy: Ty);
  }

  if (ShAmtX > 0) {
    Worklist.pushValue(V: Result);
    Result = Builder.CreateShl(LHS: Result, RHS: ShAmtX);
  }

  replaceValue(Old&: I, New&: *Result);
  return true;
}
2417
/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
///             --> "binop (shuffle), (shuffle)".
/// i.e. sink an outer permute through a binop by merging it with the binop's
/// inner shuffles, when the target cost model says this is no more expensive.
bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
  BinaryOperator *BinOp;
  ArrayRef<int> OuterMask;
  if (!match(V: &I, P: m_Shuffle(v1: m_BinOp(I&: BinOp), v2: m_Undef(), mask: m_Mask(OuterMask))))
    return false;

  // Don't introduce poison into div/rem.
  if (BinOp->isIntDivRem() && llvm::is_contained(Range&: OuterMask, Element: PoisonMaskElem))
    return false;

  // Each binop operand may itself be a two-source shuffle; at least one must
  // be for the fold to be profitable.
  Value *Op00, *Op01, *Op10, *Op11;
  ArrayRef<int> Mask0, Mask1;
  bool Match0 = match(V: BinOp->getOperand(i_nocapture: 0),
                      P: m_Shuffle(v1: m_Value(V&: Op00), v2: m_Value(V&: Op01), mask: m_Mask(Mask0)));
  bool Match1 = match(V: BinOp->getOperand(i_nocapture: 1),
                      P: m_Shuffle(v1: m_Value(V&: Op10), v2: m_Value(V&: Op11), mask: m_Mask(Mask1)));
  if (!Match0 && !Match1)
    return false;

  // For an unmatched side, treat the raw operand as both shuffle sources so
  // the identity-mask path below handles it uniformly.
  Op00 = Match0 ? Op00 : BinOp->getOperand(i_nocapture: 0);
  Op01 = Match0 ? Op01 : BinOp->getOperand(i_nocapture: 0);
  Op10 = Match1 ? Op10 : BinOp->getOperand(i_nocapture: 1);
  Op11 = Match1 ? Op11 : BinOp->getOperand(i_nocapture: 1);

  Instruction::BinaryOps Opcode = BinOp->getOpcode();
  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *BinOpTy = dyn_cast<FixedVectorType>(Val: BinOp->getType());
  auto *Op0Ty = dyn_cast<FixedVectorType>(Val: Op00->getType());
  auto *Op1Ty = dyn_cast<FixedVectorType>(Val: Op10->getType());
  if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
    return false;

  unsigned NumSrcElts = BinOpTy->getNumElements();

  // Don't accept shuffles that reference the second operand in
  // div/rem or if its an undef arg.
  if ((BinOp->isIntDivRem() || !isa<PoisonValue>(Val: I.getOperand(i: 1))) &&
      any_of(Range&: OuterMask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
    return false;

  // Merge outer / inner (or identity if no match) shuffles.
  SmallVector<int> NewMask0, NewMask1;
  for (int M : OuterMask) {
    if (M < 0 || M >= (int)NumSrcElts) {
      NewMask0.push_back(Elt: PoisonMaskElem);
      NewMask1.push_back(Elt: PoisonMaskElem);
    } else {
      NewMask0.push_back(Elt: Match0 ? Mask0[M] : M);
      NewMask1.push_back(Elt: Match1 ? Mask1[M] : M);
    }
  }

  // Detect merged masks that are identities, so the corresponding new
  // shuffle can be skipped entirely.
  unsigned NumOpElts = Op0Ty->getNumElements();
  bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
      all_of(Range&: NewMask0, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
      ShuffleVectorInst::isIdentityMask(Mask: NewMask0, NumSrcElts: NumOpElts);
  bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
      all_of(Range&: NewMask1, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
      ShuffleVectorInst::isIdentityMask(Mask: NewMask1, NumSrcElts: NumOpElts);

  InstructionCost NewCost = 0;
  // Try to merge shuffles across the binop if the new shuffles are not costly.
  // Multi-use instructions are kept alive, so their cost counts on both sides.
  InstructionCost BinOpCost =
      TTI.getArithmeticInstrCost(Opcode, Ty: BinOpTy, CostKind);
  InstructionCost OldCost =
      BinOpCost + TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                     DstTy: ShuffleDstTy, SrcTy: BinOpTy, Mask: OuterMask, CostKind,
                                     Index: 0, SubTp: nullptr, Args: {BinOp}, CxtI: &I);
  if (!BinOp->hasOneUse())
    NewCost += BinOpCost;

  if (Match0) {
    InstructionCost Shuf0Cost = TTI.getShuffleCost(
        Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op0Ty, Mask: Mask0, CostKind,
        Index: 0, SubTp: nullptr, Args: {Op00, Op01}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 0)));
    OldCost += Shuf0Cost;
    if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 0)->hasOneUse())
      NewCost += Shuf0Cost;
  }
  if (Match1) {
    InstructionCost Shuf1Cost = TTI.getShuffleCost(
        Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op1Ty, Mask: Mask1, CostKind,
        Index: 0, SubTp: nullptr, Args: {Op10, Op11}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 1)));
    OldCost += Shuf1Cost;
    if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 1)->hasOneUse())
      NewCost += Shuf1Cost;
  }

  NewCost += TTI.getArithmeticInstrCost(Opcode, Ty: ShuffleDstTy, CostKind);

  if (!IsIdentity0)
    NewCost +=
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
                           SrcTy: Op0Ty, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {Op00, Op01});
  if (!IsIdentity1)
    NewCost +=
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
                           SrcTy: Op1Ty, Mask: NewMask1, CostKind, Index: 0, SubTp: nullptr, Args: {Op10, Op11});

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  // If costs are equal, still fold as we reduce instruction count.
  if (NewCost > OldCost)
    return false;

  Value *LHS =
      IsIdentity0 ? Op00 : Builder.CreateShuffleVector(V1: Op00, V2: Op01, Mask: NewMask0);
  Value *RHS =
      IsIdentity1 ? Op10 : Builder.CreateShuffleVector(V1: Op10, V2: Op11, Mask: NewMask1);
  Value *NewBO = Builder.CreateBinOp(Opc: Opcode, LHS, RHS);

  // Intersect flags from the old binops.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO))
    NewInst->copyIRFlags(V: BinOp);

  Worklist.pushValue(V: LHS);
  Worklist.pushValue(V: RHS);
  replaceValue(Old&: I, New&: *NewBO);
  return true;
}
2542
/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
  ArrayRef<int> OldMask;
  Instruction *LHS, *RHS;
  if (!match(V: &I, P: m_Shuffle(v1: m_Instruction(I&: LHS), v2: m_Instruction(I&: RHS),
                             mask: m_Mask(OldMask))))
    return false;

  // TODO: Add support for addlike etc.
  if (LHS->getOpcode() != RHS->getOpcode())
    return false;

  // The two shuffled instructions must both be binops (X op Y, Z op W) or
  // both be compares with the same predicate.
  Value *X, *Y, *Z, *W;
  bool IsCommutative = false;
  CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
  CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
  if (match(V: LHS, P: m_BinOp(L: m_Value(V&: X), R: m_Value(V&: Y))) &&
      match(V: RHS, P: m_BinOp(L: m_Value(V&: Z), R: m_Value(V&: W)))) {
    auto *BO = cast<BinaryOperator>(Val: LHS);
    // Don't introduce poison into div/rem.
    if (llvm::is_contained(Range&: OldMask, Element: PoisonMaskElem) && BO->isIntDivRem())
      return false;
    IsCommutative = BinaryOperator::isCommutative(Opcode: BO->getOpcode());
  } else if (match(V: LHS, P: m_Cmp(Pred&: PredLHS, L: m_Value(V&: X), R: m_Value(V&: Y))) &&
             match(V: RHS, P: m_Cmp(Pred&: PredRHS, L: m_Value(V&: Z), R: m_Value(V&: W))) &&
             (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
    IsCommutative = cast<CmpInst>(Val: LHS)->isCommutative();
  } else
    return false;

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *BinResTy = dyn_cast<FixedVectorType>(Val: LHS->getType());
  auto *BinOpTy = dyn_cast<FixedVectorType>(Val: X->getType());
  if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
    return false;

  bool SameBinOp = LHS == RHS;
  unsigned NumSrcElts = BinOpTy->getNumElements();

  // If we have something like "add X, Y" and "add Z, X", swap ops to match.
  if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
    std::swap(a&: X, b&: Y);

  // Rewrites a two-source mask index into the equivalent single-source index.
  auto ConvertToUnary = [NumSrcElts](int &M) {
    if (M >= (int)NumSrcElts)
      M -= NumSrcElts;
  };

  // If both first operands match (X == Z) the LHS shuffle only needs one
  // source; same for the second operands (Y == W) below.
  SmallVector<int> NewMask0(OldMask);
  TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
  TTI::OperandValueInfo Op0Info = TTI.commonOperandInfo(X, Y: Z);
  if (X == Z) {
    llvm::for_each(Range&: NewMask0, F: ConvertToUnary);
    SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
    Z = PoisonValue::get(T: BinOpTy);
  }

  SmallVector<int> NewMask1(OldMask);
  TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
  TTI::OperandValueInfo Op1Info = TTI.commonOperandInfo(X: Y, Y: W);
  if (Y == W) {
    llvm::for_each(Range&: NewMask1, F: ConvertToUnary);
    SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
    W = PoisonValue::get(T: BinOpTy);
  }

  // Try to replace a binop with a shuffle if the shuffle is not costly.
  // When SameBinOp, only count the binop cost once.
  InstructionCost LHSCost = TTI.getInstructionCost(U: LHS, CostKind);
  InstructionCost RHSCost = TTI.getInstructionCost(U: RHS, CostKind);

  InstructionCost OldCost = LHSCost;
  if (!SameBinOp) {
    OldCost += RHSCost;
  }
  OldCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc,
                                DstTy: ShuffleDstTy, SrcTy: BinResTy, Mask: OldMask, CostKind, Index: 0,
                                SubTp: nullptr, Args: {LHS, RHS}, CxtI: &I);

  // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
  // where one use shuffles have gotten split across the binop/cmp. These
  // often allow a major reduction in total cost that wouldn't happen as
  // individual folds.
  // MergeInner folds a one-use inner permute into the new mask (mutating
  // Mask and Op in place) and credits its cost to OldCost.
  auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
                        TTI::TargetCostKind CostKind) -> bool {
    Value *InnerOp;
    ArrayRef<int> InnerMask;
    if (match(V: Op, P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: InnerOp), v2: m_Undef(),
                                  mask: m_Mask(InnerMask)))) &&
        InnerOp->getType() == Op->getType() &&
        all_of(Range&: InnerMask,
               P: [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
      for (int &M : Mask)
        if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
          M = InnerMask[M - Offset];
          M = 0 <= M ? M + Offset : M;
        }
      OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: Op), CostKind);
      Op = InnerOp;
      return true;
    }
    return false;
  };
  bool ReducedInstCount = false;
  ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
  ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
  ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
  ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
  // If both new shuffles would be identical, a single shuffle can feed both
  // binop operands.
  bool SingleSrcBinOp = (X == Y) && (Z == W) && (NewMask0 == NewMask1);
  // SingleSrcBinOp only reduces instruction count if we also eliminate the
  // original binop(s). If binops have multiple uses, they won't be eliminated.
  ReducedInstCount |= SingleSrcBinOp && LHS->hasOneUser() && RHS->hasOneUser();

  auto *ShuffleCmpTy =
      FixedVectorType::get(ElementType: BinOpTy->getElementType(), FVTy: ShuffleDstTy);
  InstructionCost NewCost = TTI.getShuffleCost(
      Kind: SK0, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {X, Z});
  if (!SingleSrcBinOp)
    NewCost += TTI.getShuffleCost(Kind: SK1, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask1,
                                  CostKind, Index: 0, SubTp: nullptr, Args: {Y, W});

  if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
    NewCost += TTI.getArithmeticInstrCost(Opcode: LHS->getOpcode(), Ty: ShuffleDstTy,
                                          CostKind, Opd1Info: Op0Info, Opd2Info: Op1Info);
  } else {
    NewCost +=
        TTI.getCmpSelInstrCost(Opcode: LHS->getOpcode(), ValTy: ShuffleCmpTy, CondTy: ShuffleDstTy,
                               VecPred: PredLHS, CostKind, Op1Info: Op0Info, Op2Info: Op1Info);
  }
  // If LHS/RHS have other uses, we need to account for the cost of keeping
  // the original instructions. When SameBinOp, only add the cost once.
  if (!LHS->hasOneUser())
    NewCost += LHSCost;
  if (!SameBinOp && !RHS->hasOneUser())
    NewCost += RHSCost;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  // If either shuffle will constant fold away, then fold for the same cost as
  // we will reduce the instruction count.
  ReducedInstCount |= (isa<Constant>(Val: X) && isa<Constant>(Val: Z)) ||
                      (isa<Constant>(Val: Y) && isa<Constant>(Val: W));
  if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
    return false;

  Value *Shuf0 = Builder.CreateShuffleVector(V1: X, V2: Z, Mask: NewMask0);
  Value *Shuf1 =
      SingleSrcBinOp ? Shuf0 : Builder.CreateShuffleVector(V1: Y, V2: W, Mask: NewMask1);
  Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
                     ? Builder.CreateBinOp(
                           Opc: cast<BinaryOperator>(Val: LHS)->getOpcode(), LHS: Shuf0, RHS: Shuf1)
                     : Builder.CreateCmp(Pred: PredLHS, LHS: Shuf0, RHS: Shuf1);

  // Intersect flags from the old binops.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
    NewInst->copyIRFlags(V: LHS);
    NewInst->andIRFlags(V: RHS);
  }

  Worklist.pushValue(V: Shuf0);
  Worklist.pushValue(V: Shuf1);
  replaceValue(Old&: I, New&: *NewBO);
  return true;
}
2710
/// Try to convert,
/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
  ArrayRef<int> Mask;
  Value *C1, *T1, *F1, *C2, *T2, *F2;
  if (!match(V: &I, P: m_Shuffle(v1: m_Select(C: m_Value(V&: C1), L: m_Value(V&: T1), R: m_Value(V&: F1)),
                             v2: m_Select(C: m_Value(V&: C2), L: m_Value(V&: T2), R: m_Value(V&: F2)),
                             mask: m_Mask(Mask))))
    return false;

  auto *Sel1 = cast<Instruction>(Val: I.getOperand(i: 0));
  auto *Sel2 = cast<Instruction>(Val: I.getOperand(i: 1));

  // Both conditions must be identical fixed-width vector types so they can
  // be shuffled together (scalar-condition selects are rejected here).
  auto *C1VecTy = dyn_cast<FixedVectorType>(Val: C1->getType());
  auto *C2VecTy = dyn_cast<FixedVectorType>(Val: C2->getType());
  if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
    return false;

  auto *SI0FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 0));
  auto *SI1FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 1));
  // SelectInsts must have the same FMF.
  if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
      ((SI0FOp != nullptr) &&
       (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
    return false;

  auto *SrcVecTy = cast<FixedVectorType>(Val: T1->getType());
  auto *DstVecTy = cast<FixedVectorType>(Val: I.getType());
  auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
  auto SelOp = Instruction::Select;

  InstructionCost CostSel1 = TTI.getCmpSelInstrCost(
      Opcode: SelOp, ValTy: SrcVecTy, CondTy: C1VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
  InstructionCost CostSel2 = TTI.getCmpSelInstrCost(
      Opcode: SelOp, ValTy: SrcVecTy, CondTy: C2VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);

  // Old: two selects plus the outer shuffle.
  InstructionCost OldCost =
      CostSel1 + CostSel2 +
      TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
                         Args: {I.getOperand(i: 0), I.getOperand(i: 1)}, CxtI: &I);

  // New: three shuffles (cond/true/false) plus one select.
  InstructionCost NewCost = TTI.getShuffleCost(
      Kind: SK, DstTy: FixedVectorType::get(ElementType: C1VecTy->getScalarType(), NumElts: Mask.size()), SrcTy: C1VecTy,
      Mask, CostKind, Index: 0, SubTp: nullptr, Args: {C1, C2});
  NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
                                SubTp: nullptr, Args: {T1, T2});
  NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
                                SubTp: nullptr, Args: {F1, F2});
  auto *C1C2ShuffledVecTy = FixedVectorType::get(
      ElementType: Type::getInt1Ty(C&: I.getContext()), NumElts: DstVecTy->getNumElements());
  NewCost += TTI.getCmpSelInstrCost(Opcode: SelOp, ValTy: DstVecTy, CondTy: C1C2ShuffledVecTy,
                                    VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);

  // Multi-use selects stay alive, so their cost counts against the fold.
  if (!Sel1->hasOneUse())
    NewCost += CostSel1;
  if (!Sel2->hasOneUse())
    NewCost += CostSel2;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  Value *ShuffleCmp = Builder.CreateShuffleVector(V1: C1, V2: C2, Mask);
  Value *ShuffleTrue = Builder.CreateShuffleVector(V1: T1, V2: T2, Mask);
  Value *ShuffleFalse = Builder.CreateShuffleVector(V1: F1, V2: F2, Mask);
  Value *NewSel;
  // We presuppose that the SelectInsts have the same FMF.
  if (SI0FOp)
    NewSel = Builder.CreateSelectFMF(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse,
                                     FMFSource: SI0FOp->getFastMathFlags());
  else
    NewSel = Builder.CreateSelect(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse);

  Worklist.pushValue(V: ShuffleCmp);
  Worklist.pushValue(V: ShuffleTrue);
  Worklist.pushValue(V: ShuffleFalse);
  replaceValue(Old&: I, New&: *NewSel);
  return true;
}
2793
/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
/// into "castop (shuffle)". Also handles the unary case
/// "shuffle (castop), undef" --> "castop (shuffle)".
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> OldMask;
  if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
    return false;

  // Check whether this is a binary shuffle.
  bool IsBinaryShuffle = !isa<UndefValue>(Val: V1);

  auto *C0 = dyn_cast<CastInst>(Val: V0);
  auto *C1 = dyn_cast<CastInst>(Val: V1);
  if (!C0 || (IsBinaryShuffle && !C1))
    return false;

  Instruction::CastOps Opcode = C0->getOpcode();

  // If this is allowed, foldShuffleOfCastops can get stuck in a loop
  // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
  if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
    return false;

  if (IsBinaryShuffle) {
    if (C0->getSrcTy() != C1->getSrcTy())
      return false;
    // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
    if (Opcode != C1->getOpcode()) {
      if (match(V: C0, P: m_SExtLike(Op: m_Value())) && match(V: C1, P: m_SExtLike(Op: m_Value())))
        Opcode = Instruction::SExt;
      else
        return false;
    }
  }

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *CastDstTy = dyn_cast<FixedVectorType>(Val: C0->getDestTy());
  auto *CastSrcTy = dyn_cast<FixedVectorType>(Val: C0->getSrcTy());
  if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
    return false;

  unsigned NumSrcElts = CastSrcTy->getNumElements();
  unsigned NumDstElts = CastDstTy->getNumElements();
  assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
         "Only bitcasts expected to alter src/dst element counts");

  // Check for bitcasting of unscalable vector types.
  // e.g. <32 x i40> -> <40 x i32>
  if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
      (NumDstElts % NumSrcElts) != 0)
    return false;

  // Remap the shuffle mask to the source element granularity.
  SmallVector<int, 16> NewMask;
  if (NumSrcElts >= NumDstElts) {
    // The bitcast is from wide to narrow/equal elements. The shuffle mask can
    // always be expanded to the equivalent form choosing narrower elements.
    assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumSrcElts / NumDstElts;
    narrowShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask);
  } else {
    // The bitcast is from narrow elements to wide elements. The shuffle mask
    // must choose consecutive elements to allow casting first.
    assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
    unsigned ScaleFactor = NumDstElts / NumSrcElts;
    if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask))
      return false;
  }

  auto *NewShuffleDstTy =
      FixedVectorType::get(ElementType: CastSrcTy->getScalarType(), NumElts: NewMask.size());

  // Try to replace a castop with a shuffle if the shuffle is not costly.
  InstructionCost CostC0 =
      TTI.getCastInstrCost(Opcode: C0->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
                           CCH: TTI::CastContextHint::None, CostKind, I: C0);

  TargetTransformInfo::ShuffleKind ShuffleKind;
  if (IsBinaryShuffle)
    ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
  else
    ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;

  InstructionCost OldCost = CostC0;
  OldCost += TTI.getShuffleCost(Kind: ShuffleKind, DstTy: ShuffleDstTy, SrcTy: CastDstTy, Mask: OldMask,
                                CostKind, Index: 0, SubTp: nullptr, Args: {}, CxtI: &I);

  InstructionCost NewCost = TTI.getShuffleCost(Kind: ShuffleKind, DstTy: NewShuffleDstTy,
                                               SrcTy: CastSrcTy, Mask: NewMask, CostKind);
  NewCost += TTI.getCastInstrCost(Opcode, Dst: ShuffleDstTy, Src: NewShuffleDstTy,
                                  CCH: TTI::CastContextHint::None, CostKind);
  // Multi-use casts must be kept alive, so their cost counts against the fold.
  if (!C0->hasOneUse())
    NewCost += CostC0;
  if (IsBinaryShuffle) {
    InstructionCost CostC1 =
        TTI.getCastInstrCost(Opcode: C1->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
                             CCH: TTI::CastContextHint::None, CostKind, I: C1);
    OldCost += CostC1;
    if (!C1->hasOneUse())
      NewCost += CostC1;
  }

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  if (NewCost > OldCost)
    return false;

  Value *Shuf;
  if (IsBinaryShuffle)
    Shuf = Builder.CreateShuffleVector(V1: C0->getOperand(i_nocapture: 0), V2: C1->getOperand(i_nocapture: 0),
                                       Mask: NewMask);
  else
    Shuf = Builder.CreateShuffleVector(V: C0->getOperand(i_nocapture: 0), Mask: NewMask);

  Value *Cast = Builder.CreateCast(Op: Opcode, V: Shuf, DestTy: ShuffleDstTy);

  // Intersect flags from the old casts.
  if (auto *NewInst = dyn_cast<Instruction>(Val: Cast)) {
    NewInst->copyIRFlags(V: C0);
    if (IsBinaryShuffle)
      NewInst->andIRFlags(V: C1);
  }

  Worklist.pushValue(V: Shuf);
  replaceValue(Old&: I, New&: *Cast);
  return true;
}
2921
2922/// Try to convert any of:
2923/// "shuffle (shuffle x, y), (shuffle y, x)"
2924/// "shuffle (shuffle x, undef), (shuffle y, undef)"
2925/// "shuffle (shuffle x, undef), y"
2926/// "shuffle x, (shuffle y, undef)"
2927/// into "shuffle x, y".
2928bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2929 ArrayRef<int> OuterMask;
2930 Value *OuterV0, *OuterV1;
2931 if (!match(V: &I,
2932 P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1), mask: m_Mask(OuterMask))))
2933 return false;
2934
2935 ArrayRef<int> InnerMask0, InnerMask1;
2936 Value *X0, *X1, *Y0, *Y1;
2937 bool Match0 =
2938 match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: X0), v2: m_Value(V&: Y0), mask: m_Mask(InnerMask0)));
2939 bool Match1 =
2940 match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: X1), v2: m_Value(V&: Y1), mask: m_Mask(InnerMask1)));
2941 if (!Match0 && !Match1)
2942 return false;
2943
2944 // If the outer shuffle is a permute, then create a fake inner all-poison
2945 // shuffle. This is easier than accounting for length-changing shuffles below.
2946 SmallVector<int, 16> PoisonMask1;
2947 if (!Match1 && isa<PoisonValue>(Val: OuterV1)) {
2948 X1 = X0;
2949 Y1 = Y0;
2950 PoisonMask1.append(NumInputs: InnerMask0.size(), Elt: PoisonMaskElem);
2951 InnerMask1 = PoisonMask1;
2952 Match1 = true; // fake match
2953 }
2954
2955 X0 = Match0 ? X0 : OuterV0;
2956 Y0 = Match0 ? Y0 : OuterV0;
2957 X1 = Match1 ? X1 : OuterV1;
2958 Y1 = Match1 ? Y1 : OuterV1;
2959 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2960 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(Val: X0->getType());
2961 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(Val: OuterV0->getType());
2962 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2963 X0->getType() != X1->getType())
2964 return false;
2965
2966 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2967 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2968
2969 // Attempt to merge shuffles, matching upto 2 source operands.
2970 // Replace index to a poison arg with PoisonMaskElem.
2971 // Bail if either inner masks reference an undef arg.
2972 SmallVector<int, 16> NewMask(OuterMask);
2973 Value *NewX = nullptr, *NewY = nullptr;
2974 for (int &M : NewMask) {
2975 Value *Src = nullptr;
2976 if (0 <= M && M < (int)NumImmElts) {
2977 Src = OuterV0;
2978 if (Match0) {
2979 M = InnerMask0[M];
2980 Src = M >= (int)NumSrcElts ? Y0 : X0;
2981 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2982 }
2983 } else if (M >= (int)NumImmElts) {
2984 Src = OuterV1;
2985 M -= NumImmElts;
2986 if (Match1) {
2987 M = InnerMask1[M];
2988 Src = M >= (int)NumSrcElts ? Y1 : X1;
2989 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2990 }
2991 }
2992 if (Src && M != PoisonMaskElem) {
2993 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
2994 if (isa<UndefValue>(Val: Src)) {
2995 // We've referenced an undef element - if its poison, update the shuffle
2996 // mask, else bail.
2997 if (!isa<PoisonValue>(Val: Src))
2998 return false;
2999 M = PoisonMaskElem;
3000 continue;
3001 }
3002 if (!NewX || NewX == Src) {
3003 NewX = Src;
3004 continue;
3005 }
3006 if (!NewY || NewY == Src) {
3007 M += NumSrcElts;
3008 NewY = Src;
3009 continue;
3010 }
3011 return false;
3012 }
3013 }
3014
3015 if (!NewX)
3016 return PoisonValue::get(T: ShuffleDstTy);
3017 if (!NewY)
3018 NewY = PoisonValue::get(T: ShuffleSrcTy);
3019
3020 // Have we folded to an Identity shuffle?
3021 if (ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts)) {
3022 replaceValue(Old&: I, New&: *NewX);
3023 return true;
3024 }
3025
3026 // Try to merge the shuffles if the new shuffle is not costly.
3027 InstructionCost InnerCost0 = 0;
3028 if (Match0)
3029 InnerCost0 = TTI.getInstructionCost(U: cast<User>(Val: OuterV0), CostKind);
3030
3031 InstructionCost InnerCost1 = 0;
3032 if (Match1)
3033 InnerCost1 = TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);
3034
3035 InstructionCost OuterCost = TTI.getInstructionCost(U: &I, CostKind);
3036
3037 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
3038
3039 bool IsUnary = all_of(Range&: NewMask, P: [&](int M) { return M < (int)NumSrcElts; });
3040 TargetTransformInfo::ShuffleKind SK =
3041 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
3042 : TargetTransformInfo::SK_PermuteTwoSrc;
3043 InstructionCost NewCost =
3044 TTI.getShuffleCost(Kind: SK, DstTy: ShuffleDstTy, SrcTy: ShuffleSrcTy, Mask: NewMask, CostKind, Index: 0,
3045 SubTp: nullptr, Args: {NewX, NewY});
3046 if (!OuterV0->hasOneUse())
3047 NewCost += InnerCost0;
3048 if (!OuterV1->hasOneUse())
3049 NewCost += InnerCost1;
3050
3051 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
3052 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3053 << "\n");
3054 if (NewCost > OldCost)
3055 return false;
3056
3057 Value *Shuf = Builder.CreateShuffleVector(V1: NewX, V2: NewY, Mask: NewMask);
3058 replaceValue(Old&: I, New&: *Shuf);
3059 return true;
3060}
3061
3062/// Try to convert a chain of length-preserving shuffles that are fed by
3063/// length-changing shuffles from the same source, e.g. a chain of length 3:
3064///
3065/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
3066/// (shuffle y, undef)),
///          (shuffle y, undef)"
3068///
3069/// into a single shuffle fed by a length-changing shuffle:
3070///
3071/// "shuffle x, (shuffle y, undef)"
3072///
3073/// Such chains arise e.g. from folding extract/insert sequences.
bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
  FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(Val: I.getType());
  if (!TrunkType)
    return false;

  // State accumulated while walking up the chain of length-preserving
  // "trunk" shuffles:
  //  - Mask:  combined trunk mask; indices >= NumTrunkElts select from the
  //           (single) merged leaf shuffle of Y.
  //  - YMask: combined mask applied to the common leaf source Y.
  //  - OldCost/NewCost: running cost of the links consumed so far vs the cost
  //           of the merged replacement shuffles.
  unsigned ChainLength = 0;
  SmallVector<int> Mask;
  SmallVector<int> YMask;
  InstructionCost OldCost = 0;
  InstructionCost NewCost = 0;
  Value *Trunk = &I;
  unsigned NumTrunkElts = TrunkType->getNumElements();
  Value *Y = nullptr;

  for (;;) {
    // Match the current trunk against (commutations of) the pattern
    // "shuffle trunk', (shuffle y, undef)"
    ArrayRef<int> OuterMask;
    Value *OuterV0, *OuterV1;
    // Interior links must be single-use so they become dead after the fold;
    // only the chain root (ChainLength == 0, i.e. I itself) may have other
    // users.
    if (ChainLength != 0 && !Trunk->hasOneUse())
      break;
    if (!match(V: Trunk, P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1),
                                 mask: m_Mask(OuterMask))))
      break;
    if (OuterV0->getType() != TrunkType) {
      // This shuffle is not length-preserving, so it cannot be part of the
      // chain.
      break;
    }

    ArrayRef<int> InnerMask0, InnerMask1;
    Value *A0, *A1, *B0, *B1;
    bool Match0 =
        match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: A0), v2: m_Value(V&: B0), mask: m_Mask(InnerMask0)));
    bool Match1 =
        match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: A1), v2: m_Value(V&: B1), mask: m_Mask(InnerMask1)));
    bool Match0Leaf = Match0 && A0->getType() != I.getType();
    bool Match1Leaf = Match1 && A1->getType() != I.getType();
    if (Match0Leaf == Match1Leaf) {
      // Only handle the case of exactly one leaf in each step. The "two leaves"
      // case is handled by foldShuffleOfShuffles.
      break;
    }

    // Canonicalize so the leaf (length-changing) shuffle is always operand 1,
    // remapping outer mask indices to follow the swapped operands.
    SmallVector<int> CommutedOuterMask;
    if (Match0Leaf) {
      std::swap(a&: OuterV0, b&: OuterV1);
      std::swap(a&: InnerMask0, b&: InnerMask1);
      std::swap(a&: A0, b&: A1);
      std::swap(a&: B0, b&: B1);
      llvm::append_range(C&: CommutedOuterMask, R&: OuterMask);
      for (int &M : CommutedOuterMask) {
        if (M == PoisonMaskElem)
          continue;
        if (M < (int)NumTrunkElts)
          M += NumTrunkElts;
        else
          M -= NumTrunkElts;
      }
      OuterMask = CommutedOuterMask;
    }
    if (!OuterV1->hasOneUse())
      break;

    // All leaf shuffles must read from one common non-undef source vector Y.
    if (!isa<UndefValue>(Val: A1)) {
      if (!Y)
        Y = A1;
      else if (Y != A1)
        break;
    }
    if (!isa<UndefValue>(Val: B1)) {
      if (!Y)
        Y = B1;
      else if (Y != B1)
        break;
    }

    // Normalize the leaf mask so all indices refer to a single source operand.
    auto *YType = cast<FixedVectorType>(Val: A1->getType());
    int NumLeafElts = YType->getNumElements();
    SmallVector<int> LocalYMask(InnerMask1);
    for (int &M : LocalYMask) {
      if (M >= NumLeafElts)
        M -= NumLeafElts;
    }

    InstructionCost LocalOldCost =
        TTI.getInstructionCost(U: cast<User>(Val: Trunk), CostKind) +
        TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);

    // Handle the initial (start of chain) case.
    if (!ChainLength) {
      Mask.assign(AR: OuterMask);
      YMask.assign(RHS: LocalYMask);
      OldCost = NewCost = LocalOldCost;
      Trunk = OuterV0;
      ChainLength++;
      continue;
    }

    // For the non-root case, first attempt to combine masks.
    // Each Y lane may only be demanded with one index across the whole chain;
    // conflicting demands make the merge invalid.
    SmallVector<int> NewYMask(YMask);
    bool Valid = true;
    for (auto [CombinedM, LeafM] : llvm::zip(t&: NewYMask, u&: LocalYMask)) {
      if (LeafM == -1 || CombinedM == LeafM)
        continue;
      if (CombinedM == -1) {
        CombinedM = LeafM;
      } else {
        Valid = false;
        break;
      }
    }
    if (!Valid)
      break;

    // Compose the accumulated trunk mask with this link's outer mask; indices
    // that already select from the merged leaf (>= NumTrunkElts) or are poison
    // pass through unchanged.
    SmallVector<int> NewMask;
    NewMask.reserve(N: NumTrunkElts);
    for (int M : Mask) {
      if (M < 0 || M >= static_cast<int>(NumTrunkElts))
        NewMask.push_back(Elt: M);
      else
        NewMask.push_back(Elt: OuterMask[M]);
    }

    // Break the chain if adding this new step complicates the shuffles such
    // that it would increase the new cost by more than the old cost of this
    // step.
    InstructionCost LocalNewCost =
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: TrunkType,
                           SrcTy: YType, Mask: NewYMask, CostKind) +
        TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: TrunkType,
                           SrcTy: TrunkType, Mask: NewMask, CostKind);

    if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
      break;

    LLVM_DEBUG({
      if (ChainLength == 1) {
        dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
               << I << '\n';
      }
      dbgs() << "  next chain link: " << *Trunk << '\n'
             << "  old cost: " << (OldCost + LocalOldCost)
             << "  new cost: " << LocalNewCost << '\n';
    });

    // Commit this link and continue walking up the chain.
    Mask = NewMask;
    YMask = NewYMask;
    OldCost += LocalOldCost;
    NewCost = LocalNewCost;
    Trunk = OuterV0;
    ChainLength++;
  }
  // A chain of a single link has nothing to merge.
  if (ChainLength <= 1)
    return false;

  if (llvm::all_of(Range&: Mask, P: [&](int M) {
        return M < 0 || M >= static_cast<int>(NumTrunkElts);
      })) {
    // Produce a canonical simplified form if all elements are sourced from Y.
    for (int &M : Mask) {
      if (M >= static_cast<int>(NumTrunkElts))
        M = YMask[M - NumTrunkElts];
    }
    Value *Root =
        Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask);
    replaceValue(Old&: I, New&: *Root);
    return true;
  }

  // Otherwise emit one length-changing shuffle of Y (the merged leaf) feeding
  // a single two-source trunk shuffle.
  Value *Leaf =
      Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask: YMask);
  Value *Root = Builder.CreateShuffleVector(V1: Trunk, V2: Leaf, Mask);
  replaceValue(Old&: I, New&: *Root);
  return true;
}
3250
/// Try to convert
/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
  Value *V0, *V1;
  ArrayRef<int> OldMask;
  if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
    return false;

  auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
  auto *II1 = dyn_cast<IntrinsicInst>(Val: V1);
  if (!II0 || !II1)
    return false;

  // Both shuffle operands must be calls to the same intrinsic.
  Intrinsic::ID IID = II0->getIntrinsicID();
  if (IID != II1->getIntrinsicID())
    return false;
  InstructionCost CostII0 =
      TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
  InstructionCost CostII1 =
      TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II1), CostKind);

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *II0Ty = dyn_cast<FixedVectorType>(Val: II0->getType());
  if (!ShuffleDstTy || !II0Ty)
    return false;

  // Only intrinsics known to behave elementwise can be moved below a shuffle.
  if (!isTriviallyVectorizable(ID: IID))
    return false;

  // Scalar operands are not shuffled, so both calls must agree on them.
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI) &&
        II0->getArgOperand(i: I) != II1->getArgOperand(i: I))
      return false;

  InstructionCost OldCost =
      CostII0 + CostII1 +
      TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
                         SrcTy: II0Ty, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {II0, II1}, CxtI: &I);

  SmallVector<Type *> NewArgsTy;
  InstructionCost NewCost = 0;
  // Identical operand pairs will share one shuffle, so only cost them once.
  SmallDenseSet<std::pair<Value *, Value *>> SeenOperandPairs;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
    } else {
      auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
      auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
                                         NumElts: ShuffleDstTy->getNumElements());
      NewArgsTy.push_back(Elt: ArgTy);
      std::pair<Value *, Value *> OperandPair =
          std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
      if (!SeenOperandPairs.insert(V: OperandPair).second) {
        // We've already computed the cost for this operand pair.
        continue;
      }
      NewCost += TTI.getShuffleCost(
          Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ArgTy, SrcTy: VecTy, Mask: OldMask,
          CostKind, Index: 0, SubTp: nullptr, Args: {II0->getArgOperand(i: I), II1->getArgOperand(i: I)});
    }
  }
  IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);

  NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);
  // If the original intrinsics have other users they stay live, so their cost
  // must still be paid in the new form.
  if (!II0->hasOneUse())
    NewCost += CostII0;
  if (II1 != II0 && !II1->hasOneUse())
    NewCost += CostII1;

  LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  SmallVector<Value *> NewArgs;
  // Reuse one shuffle per distinct operand pair (mirrors the costing above).
  SmallDenseMap<std::pair<Value *, Value *>, Value *> ShuffleCache;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      NewArgs.push_back(Elt: II0->getArgOperand(i: I));
    } else {
      std::pair<Value *, Value *> OperandPair =
          std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
      auto It = ShuffleCache.find(Val: OperandPair);
      if (It != ShuffleCache.end()) {
        // Reuse previously created shuffle for this operand pair.
        NewArgs.push_back(Elt: It->second);
        continue;
      }
      Value *Shuf = Builder.CreateShuffleVector(V1: II0->getArgOperand(i: I),
                                                V2: II1->getArgOperand(i: I), Mask: OldMask);
      ShuffleCache[OperandPair] = Shuf;
      NewArgs.push_back(Elt: Shuf);
      Worklist.pushValue(V: Shuf);
    }
  Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);

  // Intersect flags from the old intrinsics.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic)) {
    NewInst->copyIRFlags(V: II0);
    NewInst->andIRFlags(V: II1);
  }

  replaceValue(Old&: I, New&: *NewIntrinsic);
  return true;
}
3358
/// Try to convert
/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
  Value *V0;
  ArrayRef<int> Mask;
  if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Undef(), mask: m_Mask(Mask))))
    return false;

  auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
  if (!II0)
    return false;

  auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(Val: II0->getType());
  if (!ShuffleDstTy || !IntrinsicSrcTy)
    return false;

  // Validate it's a pure permute, mask should only reference the first vector
  unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
  if (any_of(Range&: Mask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
    return false;

  // Only intrinsics known to behave elementwise can be moved below a permute.
  Intrinsic::ID IID = II0->getIntrinsicID();
  if (!isTriviallyVectorizable(ID: IID))
    return false;

  // Cost analysis
  InstructionCost IntrinsicCost =
      TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
  InstructionCost OldCost =
      IntrinsicCost +
      TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: ShuffleDstTy,
                         SrcTy: IntrinsicSrcTy, Mask, CostKind, Index: 0, SubTp: nullptr, Args: {V0}, CxtI: &I);

  // New cost: one single-source permute per vector argument, plus the
  // intrinsic recreated at the (possibly different) destination width.
  SmallVector<Type *> NewArgsTy;
  InstructionCost NewCost = 0;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
    } else {
      auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
      auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
                                         NumElts: ShuffleDstTy->getNumElements());
      NewArgsTy.push_back(Elt: ArgTy);
      NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                    DstTy: ArgTy, SrcTy: VecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
                                    Args: {II0->getArgOperand(i: I)});
    }
  }
  IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
  NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);

  // If the intrinsic has multiple uses, we need to account for the cost of
  // keeping the original intrinsic around.
  if (!II0->hasOneUse())
    NewCost += IntrinsicCost;

  LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n  OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  if (NewCost > OldCost)
    return false;

  // Transform
  SmallVector<Value *> NewArgs;
  for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
    if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
      // Scalar operands are reused unshuffled.
      NewArgs.push_back(Elt: II0->getArgOperand(i: I));
    } else {
      Value *Shuf = Builder.CreateShuffleVector(V: II0->getArgOperand(i: I), Mask);
      NewArgs.push_back(Elt: Shuf);
      Worklist.pushValue(V: Shuf);
    }
  }

  Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);

  // Single source, so a plain flag copy (no intersection needed) suffices.
  if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic))
    NewInst->copyIRFlags(V: II0);

  replaceValue(Old&: I, New&: *NewIntrinsic);
  return true;
}
3442
3443using InstLane = std::pair<Use *, int>;
3444
3445static InstLane lookThroughShuffles(Use *U, int Lane) {
3446 while (auto *SV = dyn_cast<ShuffleVectorInst>(Val: U->get())) {
3447 unsigned NumElts =
3448 cast<FixedVectorType>(Val: SV->getOperand(i_nocapture: 0)->getType())->getNumElements();
3449 int M = SV->getMaskValue(Elt: Lane);
3450 if (M < 0)
3451 return {nullptr, PoisonMaskElem};
3452 if (static_cast<unsigned>(M) < NumElts) {
3453 U = &SV->getOperandUse(i: 0);
3454 Lane = M;
3455 } else {
3456 U = &SV->getOperandUse(i: 1);
3457 Lane = M - NumElts;
3458 }
3459 }
3460 return InstLane{U, Lane};
3461}
3462
3463static SmallVector<InstLane>
3464generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
3465 SmallVector<InstLane> NItem;
3466 for (InstLane IL : Item) {
3467 auto [U, Lane] = IL;
3468 InstLane OpLane =
3469 U ? lookThroughShuffles(U: &cast<Instruction>(Val: U->get())->getOperandUse(i: Op),
3470 Lane)
3471 : InstLane{nullptr, PoisonMaskElem};
3472 NItem.emplace_back(Args&: OpLane);
3473 }
3474 return NItem;
3475}
3476
3477/// Detect concat of multiple values into a vector
3478static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
3479 const TargetTransformInfo &TTI) {
3480 auto *Ty = cast<FixedVectorType>(Val: Item.front().first->get()->getType());
3481 unsigned NumElts = Ty->getNumElements();
3482 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
3483 return false;
3484
3485 // Check that the concat is free, usually meaning that the type will be split
3486 // during legalization.
3487 SmallVector<int, 16> ConcatMask(NumElts * 2);
3488 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
3489 if (TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc,
3490 DstTy: FixedVectorType::get(ElementType: Ty->getScalarType(), NumElts: NumElts * 2),
3491 SrcTy: Ty, Mask: ConcatMask, CostKind) != 0)
3492 return false;
3493
3494 unsigned NumSlices = Item.size() / NumElts;
3495 // Currently we generate a tree of shuffles for the concats, which limits us
3496 // to a power2.
3497 if (!isPowerOf2_32(Value: NumSlices))
3498 return false;
3499 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
3500 Use *SliceV = Item[Slice * NumElts].first;
3501 if (!SliceV || SliceV->get()->getType() != Ty)
3502 return false;
3503 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
3504 auto [V, Lane] = Item[Slice * NumElts + Elt];
3505 if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
3506 return false;
3507 }
3508 }
3509 return true;
3510}
3511
/// Rebuild the instruction tree described by \p Item at full width \p Ty, now
/// that the intermediate shuffles have been proven superfluous. Leaves are
/// emitted according to how foldShuffleToIdentity classified them (identity,
/// splat or concat); interior nodes are recreated by recursing into their
/// operands.
static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
                                  const SmallPtrSet<Use *, 4> &IdentityLeafs,
                                  const SmallPtrSet<Use *, 4> &SplatLeafs,
                                  const SmallPtrSet<Use *, 4> &ConcatLeafs,
                                  IRBuilderBase &Builder,
                                  const TargetTransformInfo *TTI) {
  auto [FrontU, FrontLane] = Item.front();

  // Identity leaf: the value already has the required lane order.
  if (IdentityLeafs.contains(Ptr: FrontU)) {
    return FrontU->get();
  }
  // Splat leaf: broadcast the single demanded lane.
  if (SplatLeafs.contains(Ptr: FrontU)) {
    SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
    return Builder.CreateShuffleVector(V: FrontU->get(), Mask);
  }
  // Concat leaf: join the slice sources with a binary tree of
  // identity-concat shuffles, doubling the width at every level.
  if (ConcatLeafs.contains(Ptr: FrontU)) {
    unsigned NumElts =
        cast<FixedVectorType>(Val: FrontU->get()->getType())->getNumElements();
    SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
    for (unsigned S = 0; S < Values.size(); ++S)
      Values[S] = Item[S * NumElts].first->get();

    while (Values.size() > 1) {
      NumElts *= 2;
      SmallVector<int, 16> Mask(NumElts, 0);
      std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
      SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
      for (unsigned S = 0; S < NewValues.size(); ++S)
        NewValues[S] =
            Builder.CreateShuffleVector(V1: Values[S * 2], V2: Values[S * 2 + 1], Mask);
      Values = NewValues;
    }
    return Values[0];
  }

  // Interior node: recreate the instruction with full-width operands,
  // recursing into each operand. For intrinsic calls the trailing operand is
  // the callee, hence the "- 1".
  auto *I = cast<Instruction>(Val: FrontU->get());
  auto *II = dyn_cast<IntrinsicInst>(Val: I);
  unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
  SmallVector<Value *> Ops(NumOps);
  for (unsigned Idx = 0; Idx < NumOps; Idx++) {
    // Scalar intrinsic operands must stay scalar and are reused as-is.
    if (II &&
        isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI)) {
      Ops[Idx] = II->getOperand(i_nocapture: Idx);
      continue;
    }
    Ops[Idx] = generateNewInstTree(Item: generateInstLaneVectorFromOperand(Item, Op: Idx),
                                   Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
                                   Builder, TTI);
  }

  // Collect the per-lane source values so IR flags can be conservatively
  // intersected across all of them.
  SmallVector<Value *, 8> ValueList;
  for (const auto &Lane : Item)
    if (Lane.first)
      ValueList.push_back(Elt: Lane.first->get());

  Type *DstTy =
      FixedVectorType::get(ElementType: I->getType()->getScalarType(), NumElts: Ty->getNumElements());
  if (auto *BI = dyn_cast<BinaryOperator>(Val: I)) {
    auto *Value = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)BI->getOpcode(),
                                      LHS: Ops[0], RHS: Ops[1]);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (auto *CI = dyn_cast<CmpInst>(Val: I)) {
    auto *Value = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: Ops[0], RHS: Ops[1]);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
    auto *Value = Builder.CreateSelect(C: Ops[0], True: Ops[1], False: Ops[2], Name: "", MDFrom: SI);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (auto *CI = dyn_cast<CastInst>(Val: I)) {
    auto *Value = Builder.CreateCast(Op: CI->getOpcode(), V: Ops[0], DestTy: DstTy);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  if (II) {
    auto *Value = Builder.CreateIntrinsic(RetTy: DstTy, ID: II->getIntrinsicID(), Args: Ops);
    propagateIRFlags(I: Value, VL: ValueList);
    return Value;
  }
  assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
  auto *Value =
      Builder.CreateUnOp(Opc: (Instruction::UnaryOps)I->getOpcode(), V: Ops[0]);
  propagateIRFlags(I: Value, VL: ValueList);
  return Value;
}
3601
// Starting from a shuffle, look up through operands tracking the shuffled index
// of each lane. If we can simplify away the shuffles to identities then
// do so.
bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
  auto *Ty = dyn_cast<FixedVectorType>(Val: I.getType());
  if (!Ty || I.use_empty())
    return false;

  // Seed with one InstLane per result lane; a Use of I stands in for I itself
  // since lookThroughShuffles starts from a Use.
  SmallVector<InstLane> Start(Ty->getNumElements());
  for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
    Start[M] = lookThroughShuffles(U: &*I.use_begin(), Lane: M);

  SmallVector<SmallVector<InstLane>> Worklist;
  Worklist.push_back(Elt: Start);
  // Leaves are classified into one of these sets; generateNewInstTree uses the
  // same sets to decide how to materialize each leaf.
  SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
  unsigned NumVisited = 0;

  while (!Worklist.empty()) {
    if (++NumVisited > MaxInstrsToScan)
      return false;

    SmallVector<InstLane> Item = Worklist.pop_back_val();
    auto [FrontU, FrontLane] = Item.front();

    // If we found an undef first lane then bail out to keep things simple.
    if (!FrontU)
      return false;

    // Helper to peek through bitcasts to the same value.
    auto IsEquiv = [&](Value *X, Value *Y) {
      return X->getType() == Y->getType() &&
             peekThroughBitcasts(V: X) == peekThroughBitcasts(V: Y);
    };

    // Look for an identity value.
    if (FrontLane == 0 &&
        cast<FixedVectorType>(Val: FrontU->get()->getType())->getNumElements() ==
            Ty->getNumElements() &&
        all_of(Range: drop_begin(RangeOrContainer: enumerate(First&: Item)), P: [IsEquiv, Item](const auto &E) {
          Value *FrontV = Item.front().first->get();
          return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
                                      E.value().second == (int)E.index());
        })) {
      IdentityLeafs.insert(Ptr: FrontU);
      continue;
    }
    // Look for constants, for the moment only supporting constant splats.
    if (auto *C = dyn_cast<Constant>(Val: FrontU);
        C && C->getSplatValue() &&
        all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
          Value *FrontV = Item.front().first->get();
          Use *U = IL.first;
          return !U || (isa<Constant>(Val: U->get()) &&
                        cast<Constant>(Val: U->get())->getSplatValue() ==
                            cast<Constant>(Val: FrontV)->getSplatValue());
        })) {
      SplatLeafs.insert(Ptr: FrontU);
      continue;
    }
    // Look for a splat value.
    if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
          auto [FrontU, FrontLane] = Item.front();
          auto [U, Lane] = IL;
          return !U || (U->get() == FrontU->get() && Lane == FrontLane);
        })) {
      SplatLeafs.insert(Ptr: FrontU);
      continue;
    }

    // We need each element to be the same type of value, and check that each
    // element has a single use.
    auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
      Value *FrontV = Item.front().first->get();
      if (!IL.first)
        return true;
      Value *V = IL.first->get();
      if (auto *I = dyn_cast<Instruction>(Val: V); I && !I->hasOneUser())
        return false;
      if (V->getValueID() != FrontV->getValueID())
        return false;
      // Compares must use the same predicate in every lane.
      if (auto *CI = dyn_cast<CmpInst>(Val: V))
        if (CI->getPredicate() != cast<CmpInst>(Val: FrontV)->getPredicate())
          return false;
      // Casts must convert from the same scalar source type.
      if (auto *CI = dyn_cast<CastInst>(Val: V))
        if (CI->getSrcTy()->getScalarType() !=
            cast<CastInst>(Val: FrontV)->getSrcTy()->getScalarType())
          return false;
      // Selects must have vector conditions of matching type.
      if (auto *SI = dyn_cast<SelectInst>(Val: V))
        if (!isa<VectorType>(Val: SI->getOperand(i_nocapture: 0)->getType()) ||
            SI->getOperand(i_nocapture: 0)->getType() !=
                cast<SelectInst>(Val: FrontV)->getOperand(i_nocapture: 0)->getType())
          return false;
      if (isa<CallInst>(Val: V) && !isa<IntrinsicInst>(Val: V))
        return false;
      auto *II = dyn_cast<IntrinsicInst>(Val: V);
      return !II || (isa<IntrinsicInst>(Val: FrontV) &&
                     II->getIntrinsicID() ==
                         cast<IntrinsicInst>(Val: FrontV)->getIntrinsicID() &&
                     !II->hasOperandBundles());
    };
    if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: CheckLaneIsEquivalentToFirst)) {
      // Check the operator is one that we support.
      if (isa<BinaryOperator, CmpInst>(Val: FrontU)) {
        // We exclude div/rem in case they hit UB from poison lanes.
        if (auto *BO = dyn_cast<BinaryOperator>(Val: FrontU);
            BO && BO->isIntDivRem())
          return false;
        Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
        Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 1));
        continue;
      } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
                     FPToUIInst, SIToFPInst, UIToFPInst>(Val: FrontU)) {
        Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
        continue;
      } else if (auto *BitCast = dyn_cast<BitCastInst>(Val: FrontU)) {
        // TODO: Handle vector widening/narrowing bitcasts.
        auto *DstTy = dyn_cast<FixedVectorType>(Val: BitCast->getDestTy());
        auto *SrcTy = dyn_cast<FixedVectorType>(Val: BitCast->getSrcTy());
        if (DstTy && SrcTy &&
            SrcTy->getNumElements() == DstTy->getNumElements()) {
          Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
          continue;
        }
      } else if (isa<SelectInst>(Val: FrontU)) {
        Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
        Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 1));
        Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 2));
        continue;
      } else if (auto *II = dyn_cast<IntrinsicInst>(Val: FrontU);
                 II && isTriviallyVectorizable(ID: II->getIntrinsicID()) &&
                 !II->hasOperandBundles()) {
        for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
          // Scalar operands are not lane-mapped; they must be identical in
          // every lane's instruction.
          if (isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Op,
                                                 TTI: &TTI)) {
            if (!all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item, Op](InstLane &IL) {
                  Value *FrontV = Item.front().first->get();
                  Use *U = IL.first;
                  return !U || (cast<Instruction>(Val: U->get())->getOperand(i: Op) ==
                                cast<Instruction>(Val: FrontV)->getOperand(i: Op));
                }))
              return false;
            continue;
          }
          Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op));
        }
        continue;
      }
    }

    if (isFreeConcat(Item, CostKind, TTI)) {
      ConcatLeafs.insert(Ptr: FrontU);
      continue;
    }

    return false;
  }

  // Bail if only the root item was visited; there is nothing to simplify.
  if (NumVisited <= 1)
    return false;

  LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");

  // If we got this far, we know the shuffles are superfluous and can be
  // removed. Scan through again and generate the new tree of instructions.
  Builder.SetInsertPoint(&I);
  Value *V = generateNewInstTree(Item: Start, Ty, IdentityLeafs, SplatLeafs,
                                 ConcatLeafs, Builder, TTI: &TTI);
  replaceValue(Old&: I, New&: *V);
  return true;
}
3772
/// Given a commutative reduction, the order of the input lanes does not alter
/// the results. We can use this to remove certain shuffles feeding the
/// reduction, removing the need to shuffle at all.
bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
  auto *II = dyn_cast<IntrinsicInst>(Val: &I);
  if (!II)
    return false;
  // Only reductions whose result is invariant to lane order are handled.
  switch (II->getIntrinsicID()) {
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_mul:
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_xor:
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax:
  case Intrinsic::vector_reduce_umin:
  case Intrinsic::vector_reduce_umax:
    break;
  default:
    return false;
  }

  // Find all the inputs when looking through operations that do not alter the
  // lane order (binops, for example). Currently we look for a single shuffle,
  // and can ignore splat values.
  std::queue<Value *> Worklist;
  SmallPtrSet<Value *, 4> Visited;
  ShuffleVectorInst *Shuffle = nullptr;
  if (auto *Op = dyn_cast<Instruction>(Val: I.getOperand(i: 0)))
    Worklist.push(x: Op);

  while (!Worklist.empty()) {
    Value *CV = Worklist.front();
    Worklist.pop();
    if (Visited.contains(Ptr: CV))
      continue;

    // Splats don't change the order, so can be safely ignored.
    if (isSplatValue(V: CV))
      continue;

    Visited.insert(Ptr: CV);

    if (auto *CI = dyn_cast<Instruction>(Val: CV)) {
      if (CI->isBinaryOp()) {
        for (auto *Op : CI->operand_values())
          Worklist.push(x: Op);
        continue;
      } else if (auto *SV = dyn_cast<ShuffleVectorInst>(Val: CI)) {
        // Only a single (unique) shuffle may feed the reduction.
        if (Shuffle && Shuffle != SV)
          return false;
        Shuffle = SV;
        continue;
      }
    }

    // Anything else is currently an unknown node.
    return false;
  }

  if (!Shuffle)
    return false;

  // Check all uses of the binary ops and shuffles are also included in the
  // lane-invariant operations (Visited should be the list of lanewise
  // instructions, including the shuffle that we found).
  for (auto *V : Visited)
    for (auto *U : V->users())
      if (!Visited.contains(Ptr: U) && U != &I)
        return false;

  FixedVectorType *VecType =
      dyn_cast<FixedVectorType>(Val: II->getOperand(i_nocapture: 0)->getType());
  if (!VecType)
    return false;
  FixedVectorType *ShuffleInputType =
      dyn_cast<FixedVectorType>(Val: Shuffle->getOperand(i_nocapture: 0)->getType());
  if (!ShuffleInputType)
    return false;
  unsigned NumInputElts = ShuffleInputType->getNumElements();

  // Find the mask from sorting the lanes into order. This is most likely to
  // become an identity or concat mask. Undef elements are pushed to the end.
  SmallVector<int> ConcatMask;
  Shuffle->getShuffleMask(Result&: ConcatMask);
  sort(C&: ConcatMask, Comp: [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
  // After sorting, any index >= NumInputElts means the second shuffle operand
  // is still referenced.
  bool UsesSecondVec =
      any_of(Range&: ConcatMask, P: [&](int M) { return M >= (int)NumInputElts; });

  InstructionCost OldCost = TTI.getShuffleCost(
      Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
      SrcTy: ShuffleInputType, Mask: Shuffle->getShuffleMask(), CostKind);
  InstructionCost NewCost = TTI.getShuffleCost(
      Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
      SrcTy: ShuffleInputType, Mask: ConcatMask, CostKind);

  LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
                    << "\n");
  LLVM_DEBUG(dbgs() << "  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");
  bool MadeChanges = false;
  if (NewCost < OldCost) {
    Builder.SetInsertPoint(Shuffle);
    Value *NewShuffle = Builder.CreateShuffleVector(
        V1: Shuffle->getOperand(i_nocapture: 0), V2: Shuffle->getOperand(i_nocapture: 1), Mask: ConcatMask);
    LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
    replaceValue(Old&: *Shuffle, New&: *NewShuffle);
    return true;
  }

  // See if we can re-use foldSelectShuffle, getting it to reduce the size of
  // the shuffle into a nicer order, as it can ignore the order of the shuffles.
  MadeChanges |= foldSelectShuffle(I&: *Shuffle, FromReduction: true);
  return MadeChanges;
}
3888
3889/// For a given chain of patterns of the following form:
3890///
3891/// ```
3892/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
3893///
3894/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
3895/// ty1> %1)
3896/// OR
3897/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
3898///
3899/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3900/// ...
3901/// ...
3902/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
/// 3), <n x ty1> %(i - 2))
3904/// OR
3905/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3906///
3907/// %(i) = extractelement <n x ty1> %(i - 1), 0
3908/// ```
3909///
3910/// Where:
3911/// `mask` follows a partition pattern:
3912///
3913/// Ex:
3914/// [n = 8, p = poison]
3915///
3916/// 4 5 6 7 | p p p p
3917/// 2 3 | p p p p p p
3918/// 1 | p p p p p p p
3919///
3920/// For powers of 2, there's a consistent pattern, but for other cases
3921/// the parity of the current half value at each step decides the
3922/// next partition half (see `ExpectedParityMask` for more logical details
3923/// in generalising this).
3924///
3925/// Ex:
3926/// [n = 6]
3927///
3928/// 3 4 5 | p p p
3929/// 1 2 | p p p p
3930/// 1 | p p p p p
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
  // Going bottom-up for the pattern.
  std::queue<Value *> InstWorklist;
  InstructionCost OrigCost = 0;

  // Common instruction operation after each shuffle op.
  std::optional<unsigned int> CommonCallOp = std::nullopt;
  std::optional<Instruction::BinaryOps> CommonBinOp = std::nullopt;

  // The chain strictly alternates shuffle <-> call/binop; these two flags
  // track where we are in that alternation.
  bool IsFirstCallOrBinInst = true;
  bool ShouldBeCallOrBinInst = true;

  // This stores the last used instructions for shuffle/common op.
  //
  // PrevVecV[0] / PrevVecV[1] store the last two simultaneous
  // instructions from either shuffle/common op.
  SmallVector<Value *, 2> PrevVecV(2, nullptr);

  // The pattern must terminate in `extractelement <vec>, 0`.
  Value *VecOpEE;
  if (!match(V: &I, P: m_ExtractElt(Val: m_Value(V&: VecOpEE), Idx: m_Zero())))
    return false;

  auto *FVT = dyn_cast<FixedVectorType>(Val: VecOpEE->getType());
  if (!FVT)
    return false;

  int64_t VecSize = FVT->getNumElements();
  if (VecSize < 2)
    return false;

  // Number of levels would be ~log2(n), considering we always partition
  // by half for this fold pattern.
  unsigned int NumLevels = Log2_64_Ceil(Value: VecSize), VisitedCnt = 0;
  int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;

  // This is how we generalise for all element sizes.
  // At each step, if vector size is odd, we need non-poison
  // values to cover the dominant half so we don't miss out on any element.
  //
  // This mask will help us retrieve this as we go from bottom to top:
  //
  // Mask Set -> N = N * 2 - 1
  // Mask Unset -> N = N * 2
  for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
       Cur = (Cur + 1) / 2, --Mask) {
    if (Cur & 1)
      ExpectedParityMask |= (1ll << Mask);
  }

  // Seed the bottom-up walk with the vector feeding the extractelement.
  InstWorklist.push(x: VecOpEE);

  while (!InstWorklist.empty()) {
    Value *CI = InstWorklist.front();
    InstWorklist.pop();

    if (auto *II = dyn_cast<IntrinsicInst>(Val: CI)) {
      if (!ShouldBeCallOrBinInst)
        return false;

      if (!IsFirstCallOrBinInst && any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
        return false;

      // For the first found call/bin op, the vector has to come from the
      // extract element op.
      if (II != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
        return false;
      IsFirstCallOrBinInst = false;

      // Every call op in the chain must be the same min/max intrinsic.
      if (!CommonCallOp)
        CommonCallOp = II->getIntrinsicID();
      if (II->getIntrinsicID() != *CommonCallOp)
        return false;

      switch (II->getIntrinsicID()) {
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::smin:
      case Intrinsic::smax: {
        auto *Op0 = II->getOperand(i_nocapture: 0);
        auto *Op1 = II->getOperand(i_nocapture: 1);
        PrevVecV[0] = Op0;
        PrevVecV[1] = Op1;
        break;
      }
      default:
        return false;
      }
      // Next node up the chain must be a shuffle.
      ShouldBeCallOrBinInst ^= 1;

      IntrinsicCostAttributes ICA(
          *CommonCallOp, II->getType(),
          {PrevVecV[0]->getType(), PrevVecV[1]->getType()});
      OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);

      // We may need a swap here since it can be (a, b) or (b, a)
      // and accordingly change as we go up.
      if (!isa<ShuffleVectorInst>(Val: PrevVecV[1]))
        std::swap(a&: PrevVecV[0], b&: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[0]);
    } else if (auto *BinOp = dyn_cast<BinaryOperator>(Val: CI)) {
      // Similar logic for bin ops.

      if (!ShouldBeCallOrBinInst)
        return false;

      if (!IsFirstCallOrBinInst && any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
        return false;

      if (BinOp != (IsFirstCallOrBinInst ? VecOpEE : PrevVecV[0]))
        return false;
      IsFirstCallOrBinInst = false;

      // Every bin op in the chain must share one opcode.
      if (!CommonBinOp)
        CommonBinOp = BinOp->getOpcode();

      if (BinOp->getOpcode() != *CommonBinOp)
        return false;

      switch (*CommonBinOp) {
      case BinaryOperator::Add:
      case BinaryOperator::Mul:
      case BinaryOperator::Or:
      case BinaryOperator::And:
      case BinaryOperator::Xor: {
        auto *Op0 = BinOp->getOperand(i_nocapture: 0);
        auto *Op1 = BinOp->getOperand(i_nocapture: 1);
        PrevVecV[0] = Op0;
        PrevVecV[1] = Op1;
        break;
      }
      default:
        return false;
      }
      ShouldBeCallOrBinInst ^= 1;

      OrigCost +=
          TTI.getArithmeticInstrCost(Opcode: *CommonBinOp, Ty: BinOp->getType(), CostKind);

      if (!isa<ShuffleVectorInst>(Val: PrevVecV[1]))
        std::swap(a&: PrevVecV[0], b&: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[1]);
      InstWorklist.push(x: PrevVecV[0]);
    } else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(Val: CI)) {
      // We shouldn't have any null values in the previous vectors,
      // if so, there was a mismatch in pattern.
      if (ShouldBeCallOrBinInst || any_of(Range&: PrevVecV, P: equal_to(Arg: nullptr)))
        return false;

      if (SVInst != PrevVecV[1])
        return false;

      // The shuffle must take PrevVecV[0] and poison as inputs.
      ArrayRef<int> CurMask;
      if (!match(V: SVInst, P: m_Shuffle(v1: m_Specific(V: PrevVecV[0]), v2: m_Poison(),
                                    mask: m_Mask(CurMask))))
        return false;

      // Subtract the parity mask when checking the condition.
      for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
        if (Mask < ShuffleMaskHalf &&
            CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
          return false;
        if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
          return false;
      }

      // Update mask values.
      ShuffleMaskHalf *= 2;
      ShuffleMaskHalf -= (ExpectedParityMask & 1);
      ExpectedParityMask >>= 1;

      OrigCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                     DstTy: SVInst->getType(), SrcTy: SVInst->getType(),
                                     Mask: CurMask, CostKind);

      // Chain is complete once every level was visited and the parity
      // mask is exhausted.
      VisitedCnt += 1;
      if (!ExpectedParityMask && VisitedCnt == NumLevels)
        break;

      ShouldBeCallOrBinInst ^= 1;
    } else {
      return false;
    }
  }

  // Pattern should end with a shuffle op.
  if (ShouldBeCallOrBinInst)
    return false;

  assert(VecSize != -1 && "Expected Match for Vector Size");

  Value *FinalVecV = PrevVecV[0];
  if (!FinalVecV)
    return false;

  auto *FinalVecVTy = cast<FixedVectorType>(Val: FinalVecV->getType());

  // Map the common op onto the equivalent vector reduction intrinsic.
  Intrinsic::ID ReducedOp =
      (CommonCallOp ? getMinMaxReductionIntrinsicID(IID: *CommonCallOp)
                    : getReductionForBinop(Opc: *CommonBinOp));
  if (!ReducedOp)
    return false;

  IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
  InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

  // Only fold when the single reduction is strictly cheaper than the
  // whole accumulated shuffle/op chain.
  if (NewCost >= OrigCost)
    return false;

  auto *ReducedResult =
      Builder.CreateIntrinsic(ID: ReducedOp, Types: {FinalVecV->getType()}, Args: {FinalVecV});
  replaceValue(Old&: I, New&: *ReducedResult);

  return true;
}
4146
4147/// Determine if its more efficient to fold:
4148/// reduce(trunc(x)) -> trunc(reduce(x)).
4149/// reduce(sext(x)) -> sext(reduce(x)).
4150/// reduce(zext(x)) -> zext(reduce(x)).
4151bool VectorCombine::foldCastFromReductions(Instruction &I) {
4152 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
4153 if (!II)
4154 return false;
4155
4156 bool TruncOnly = false;
4157 Intrinsic::ID IID = II->getIntrinsicID();
4158 switch (IID) {
4159 case Intrinsic::vector_reduce_add:
4160 case Intrinsic::vector_reduce_mul:
4161 TruncOnly = true;
4162 break;
4163 case Intrinsic::vector_reduce_and:
4164 case Intrinsic::vector_reduce_or:
4165 case Intrinsic::vector_reduce_xor:
4166 break;
4167 default:
4168 return false;
4169 }
4170
4171 unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
4172 Value *ReductionSrc = I.getOperand(i: 0);
4173
4174 Value *Src;
4175 if (!match(V: ReductionSrc, P: m_OneUse(SubPattern: m_Trunc(Op: m_Value(V&: Src)))) &&
4176 (TruncOnly || !match(V: ReductionSrc, P: m_OneUse(SubPattern: m_ZExtOrSExt(Op: m_Value(V&: Src))))))
4177 return false;
4178
4179 auto CastOpc =
4180 (Instruction::CastOps)cast<Instruction>(Val: ReductionSrc)->getOpcode();
4181
4182 auto *SrcTy = cast<VectorType>(Val: Src->getType());
4183 auto *ReductionSrcTy = cast<VectorType>(Val: ReductionSrc->getType());
4184 Type *ResultTy = I.getType();
4185
4186 InstructionCost OldCost = TTI.getArithmeticReductionCost(
4187 Opcode: ReductionOpc, Ty: ReductionSrcTy, FMF: std::nullopt, CostKind);
4188 OldCost += TTI.getCastInstrCost(Opcode: CastOpc, Dst: ReductionSrcTy, Src: SrcTy,
4189 CCH: TTI::CastContextHint::None, CostKind,
4190 I: cast<CastInst>(Val: ReductionSrc));
4191 InstructionCost NewCost =
4192 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: SrcTy, FMF: std::nullopt,
4193 CostKind) +
4194 TTI.getCastInstrCost(Opcode: CastOpc, Dst: ResultTy, Src: ReductionSrcTy->getScalarType(),
4195 CCH: TTI::CastContextHint::None, CostKind);
4196
4197 if (OldCost <= NewCost || !NewCost.isValid())
4198 return false;
4199
4200 Value *NewReduction = Builder.CreateIntrinsic(RetTy: SrcTy->getScalarType(),
4201 ID: II->getIntrinsicID(), Args: {Src});
4202 Value *NewCast = Builder.CreateCast(Op: CastOpc, V: NewReduction, DestTy: ResultTy);
4203 replaceValue(Old&: I, New&: *NewCast);
4204 return true;
4205}
4206
4207/// Fold:
4208/// icmp pred (reduce.{add,or,and,umax,umin}(signbit_extract(x))), C
4209/// into:
4210/// icmp sgt/slt (reduce.{or,umax,and,umin}(x)), -1/0
4211///
4212/// Sign-bit reductions produce values with known semantics:
4213/// - reduce.{or,umax}: 0 if no element is negative, 1 if any is
4214/// - reduce.{and,umin}: 1 if all elements are negative, 0 if any isn't
4215/// - reduce.add: count of negative elements (0 to NumElts)
4216///
4217/// Both lshr and ashr are supported:
4218/// - lshr produces 0 or 1, so reduce.add range is [0, N]
4219/// - ashr produces 0 or -1, so reduce.add range is [-N, 0]
4220///
4221/// The fold generalizes to multiple source vectors combined with the same
4222/// operation as the reduction. For example:
4223/// reduce.or(or(shr A, shr B)) conceptually extends the vector
4224/// For reduce.add, this changes the count to M*N where M is the number of
4225/// source vectors.
4226///
4227/// We transform to a direct sign check on the original vector using
4228/// reduce.{or,umax} or reduce.{and,umin}.
4229///
4230/// In spirit, it's similar to foldSignBitCheck in InstCombine.
bool VectorCombine::foldSignBitReductionCmp(Instruction &I) {
  CmpPredicate Pred;
  IntrinsicInst *ReduceOp;
  const APInt *CmpVal;
  // Match `icmp pred (reduce(...)), C` with a single-use reduction.
  if (!match(V: &I,
             P: m_ICmp(Pred, L: m_OneUse(SubPattern: m_AnyIntrinsic(I&: ReduceOp)), R: m_APInt(Res&: CmpVal))))
    return false;

  Intrinsic::ID OrigIID = ReduceOp->getIntrinsicID();
  switch (OrigIID) {
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_umin:
  case Intrinsic::vector_reduce_add:
    break;
  default:
    return false;
  }

  Value *ReductionSrc = ReduceOp->getArgOperand(i: 0);
  auto *VecTy = dyn_cast<FixedVectorType>(Val: ReductionSrc->getType());
  if (!VecTy)
    return false;

  // For i1 a sign-bit extract would be a shift by 0; the pattern
  // degenerates, so bail out.
  unsigned BitWidth = VecTy->getScalarSizeInBits();
  if (BitWidth == 1)
    return false;

  unsigned NumElts = VecTy->getNumElements();

  // Determine the expected tree opcode for multi-vector patterns.
  // The tree opcode must match the reduction's underlying operation.
  //
  // TODO: for pairs of equivalent operators, we should match both,
  // not only the most common.
  Instruction::BinaryOps TreeOpcode;
  switch (OrigIID) {
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umax:
    TreeOpcode = Instruction::Or;
    break;
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_umin:
    TreeOpcode = Instruction::And;
    break;
  case Intrinsic::vector_reduce_add:
    TreeOpcode = Instruction::Add;
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  // Collect sign-bit extraction leaves from an associative tree of TreeOpcode.
  // The tree conceptually extends the vector being reduced.
  SmallVector<Value *, 8> Worklist;
  SmallVector<Value *, 8> Sources; // Original vectors (X in shr X, BW-1)
  Worklist.push_back(Elt: ReductionSrc);
  std::optional<bool> IsAShr;
  constexpr unsigned MaxSources = 8;

  // Calculate old cost: all shifts + tree ops + reduction
  InstructionCost OldCost = TTI.getInstructionCost(U: ReduceOp, CostKind);

  while (!Worklist.empty() && Worklist.size() <= MaxSources &&
         Sources.size() <= MaxSources) {
    Value *V = Worklist.pop_back_val();

    // Try to match sign-bit extraction: shr X, (bitwidth-1)
    Value *X;
    if (match(V, P: m_OneUse(SubPattern: m_Shr(L: m_Value(V&: X), R: m_SpecificInt(V: BitWidth - 1))))) {
      auto *Shr = cast<Instruction>(Val: V);

      // All shifts must be the same type (all lshr or all ashr)
      bool ThisIsAShr = Shr->getOpcode() == Instruction::AShr;
      if (!IsAShr)
        IsAShr = ThisIsAShr;
      else if (*IsAShr != ThisIsAShr)
        return false;

      Sources.push_back(Elt: X);

      // As part of the fold, we remove all of the shifts, so we need to keep
      // track of their costs.
      OldCost += TTI.getInstructionCost(U: Shr, CostKind);

      continue;
    }

    // Try to extend through a tree node of the expected opcode
    Value *A, *B;
    if (!match(V, P: m_OneUse(SubPattern: m_BinOp(Opcode: TreeOpcode, L: m_Value(V&: A), R: m_Value(V&: B)))))
      return false;

    // We are potentially replacing these operations as well, so we add them
    // to the costs.
    OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: V), CostKind);

    Worklist.push_back(Elt: A);
    Worklist.push_back(Elt: B);
  }

  // Must have at least one source and not exceed limit
  if (Sources.empty() || Sources.size() > MaxSources ||
      Worklist.size() > MaxSources || !IsAShr)
    return false;

  unsigned NumSources = Sources.size();

  // For reduce.add, the total count must fit as a signed integer.
  // Range is [0, M*N] for lshr or [-M*N, 0] for ashr.
  if (OrigIID == Intrinsic::vector_reduce_add &&
      !isIntN(N: BitWidth, x: NumSources * NumElts))
    return false;

  // Compute the boundary value when all elements are negative:
  // - Per-element contribution: 1 for lshr, -1 for ashr
  // - For add: M*N (total elements across all sources); for others: just 1
  unsigned Count =
      (OrigIID == Intrinsic::vector_reduce_add) ? NumSources * NumElts : 1;
  APInt NegativeVal(CmpVal->getBitWidth(), Count);
  if (*IsAShr)
    NegativeVal.negate();

  // Range is [min(0, AllNegVal), max(0, AllNegVal)]
  APInt Zero = APInt::getZero(numBits: CmpVal->getBitWidth());
  APInt RangeLow = APIntOps::smin(A: Zero, B: NegativeVal);
  APInt RangeHigh = APIntOps::smax(A: Zero, B: NegativeVal);

  // Determine comparison semantics:
  // - IsEq: true for equality test, false for inequality
  // - TestsNegative: true if testing against AllNegVal, false for zero
  //
  // In addition to EQ/NE against 0 or AllNegVal, we support inequalities
  // that fold to boundary tests given the narrow value range:
  //     < RangeHigh   -> != RangeHigh
  //     > RangeHigh-1 -> == RangeHigh
  //     > RangeLow    -> != RangeLow
  //     < RangeLow+1  -> == RangeLow
  //
  // For inequalities, we work with signed predicates only. Unsigned predicates
  // are canonicalized to signed when the range is non-negative (where they are
  // equivalent). When the range includes negative values, unsigned predicates
  // would have different semantics due to wrap-around, so we reject them.
  if (!ICmpInst::isEquality(P: Pred) && !ICmpInst::isSigned(Pred)) {
    if (RangeLow.isNegative())
      return false;
    Pred = ICmpInst::getSignedPredicate(Pred);
  }

  bool IsEq;
  bool TestsNegative;
  if (ICmpInst::isEquality(P: Pred)) {
    if (CmpVal->isZero()) {
      TestsNegative = false;
    } else if (*CmpVal == NegativeVal) {
      TestsNegative = true;
    } else {
      return false;
    }
    IsEq = Pred == ICmpInst::ICMP_EQ;
  } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeHigh) {
    IsEq = false;
    TestsNegative = (RangeHigh == NegativeVal);
  } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeHigh - 1) {
    IsEq = true;
    TestsNegative = (RangeHigh == NegativeVal);
  } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeLow) {
    IsEq = false;
    TestsNegative = (RangeLow == NegativeVal);
  } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeLow + 1) {
    IsEq = true;
    TestsNegative = (RangeLow == NegativeVal);
  } else {
    return false;
  }

  // For this fold we support four types of checks:
  //
  //  1. All lanes are negative         - AllNeg
  //  2. All lanes are non-negative     - AllNonNeg
  //  3. At least one negative lane     - AnyNeg
  //  4. At least one non-negative lane - AnyNonNeg
  //
  // For each case, we can generate the following code:
  //
  //  1. AllNeg    - reduce.and/umin(X) < 0
  //  2. AllNonNeg - reduce.or/umax(X) > -1
  //  3. AnyNeg    - reduce.or/umax(X) < 0
  //  4. AnyNonNeg - reduce.and/umin(X) > -1
  //
  // The table below shows the aggregation of all supported cases
  // using these four cases.
  //
  //  Reduction   |   == 0    |   != 0    |  == MAX   |  != MAX
  //  ------------+-----------+-----------+-----------+-----------
  //  or/umax     | AllNonNeg |  AnyNeg   |  AnyNeg   | AllNonNeg
  //  and/umin    | AnyNonNeg |  AllNeg   |  AllNeg   | AnyNonNeg
  //  add         | AllNonNeg |  AnyNeg   |  AllNeg   | AnyNonNeg
  //
  // NOTE: MAX = 1 for or/and/umax/umin, and the vector size N for add
  //
  // For easier codegen and check inversion, we use the following encoding:
  //
  //  1. Bit-3 === requires or/umax (1) or and/umin (0) check
  //  2. Bit-2 === checks < 0 (1) or > -1 (0)
  //  3. Bit-1 === universal (1) or existential (0) check
  //
  //  AnyNeg    = 0b110: uses or/umax, checks negative, any-check
  //  AllNonNeg = 0b101: uses or/umax, checks non-neg, all-check
  //  AnyNonNeg = 0b000: uses and/umin, checks non-neg, any-check
  //  AllNeg    = 0b011: uses and/umin, checks negative, all-check
  //
  // XOR with 0b011 inverts the check (swaps all/any and neg/non-neg).
  //
  enum CheckKind : unsigned {
    AnyNonNeg = 0b000,
    AllNeg = 0b011,
    AllNonNeg = 0b101,
    AnyNeg = 0b110,
  };
  // Return true if we fold this check into or/umax and false for and/umin
  auto RequiresOr = [](CheckKind C) -> bool { return C & 0b100; };
  // Return true if we should check if result is negative and false otherwise
  auto IsNegativeCheck = [](CheckKind C) -> bool { return C & 0b010; };
  // Logically invert the check
  auto Invert = [](CheckKind C) { return CheckKind(C ^ 0b011); };

  CheckKind Base;
  switch (OrigIID) {
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umax:
    Base = TestsNegative ? AnyNeg : AllNonNeg;
    break;
  case Intrinsic::vector_reduce_and:
  case Intrinsic::vector_reduce_umin:
    Base = TestsNegative ? AllNeg : AnyNonNeg;
    break;
  case Intrinsic::vector_reduce_add:
    Base = TestsNegative ? AllNeg : AllNonNeg;
    break;
  default:
    llvm_unreachable("Unexpected intrinsic");
  }

  CheckKind Check = IsEq ? Base : Invert(Base);

  // Let the cost model pick between the arithmetic and min/max form of
  // the equivalent reduction.
  auto PickCheaper = [&](Intrinsic::ID Arith, Intrinsic::ID MinMax) {
    InstructionCost ArithCost =
        TTI.getArithmeticReductionCost(Opcode: getArithmeticReductionInstruction(RdxID: Arith),
                                       Ty: VecTy, FMF: std::nullopt, CostKind);
    InstructionCost MinMaxCost =
        TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: MinMax), Ty: VecTy,
                                   FMF: FastMathFlags(), CostKind);
    return ArithCost <= MinMaxCost ? std::make_pair(x&: Arith, y&: ArithCost)
                                   : std::make_pair(x&: MinMax, y&: MinMaxCost);
  };

  // Choose output reduction based on encoding's MSB
  auto [NewIID, NewCost] = RequiresOr(Check)
                               ? PickCheaper(Intrinsic::vector_reduce_or,
                                             Intrinsic::vector_reduce_umax)
                               : PickCheaper(Intrinsic::vector_reduce_and,
                                             Intrinsic::vector_reduce_umin);

  // Add cost of combining multiple sources with or/and
  if (NumSources > 1) {
    unsigned CombineOpc =
        RequiresOr(Check) ? Instruction::Or : Instruction::And;
    NewCost += TTI.getArithmeticInstrCost(Opcode: CombineOpc, Ty: VecTy, CostKind) *
               (NumSources - 1);
  }

  LLVM_DEBUG(dbgs() << "Found sign-bit reduction cmp: " << I << "\n OldCost: "
                    << OldCost << " vs NewCost: " << NewCost << "\n");

  // Equal cost is accepted: the fold still removes the shift instructions.
  if (NewCost > OldCost)
    return false;

  // Generate the combined input and reduction
  Builder.SetInsertPoint(&I);
  Type *ScalarTy = VecTy->getScalarType();

  Value *Input;
  if (NumSources == 1) {
    Input = Sources[0];
  } else {
    // Combine sources with or/and based on check type
    Input = RequiresOr(Check) ? Builder.CreateOr(Ops: Sources)
                              : Builder.CreateAnd(Ops: Sources);
  }

  Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: NewIID, Args: {Input});
  Value *NewCmp = IsNegativeCheck(Check) ? Builder.CreateIsNeg(Arg: NewReduce)
                                         : Builder.CreateIsNotNeg(Arg: NewReduce);
  replaceValue(Old&: I, New&: *NewCmp);
  return true;
}
4529
4530/// vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
4531///
4532/// We can prove it for cases when:
4533///
4534/// 1. OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
4535/// 1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
4536/// 2. f(x) == 0 <=> x == 0
4537///
4538/// From 1 and 2 (or 1' and 2), we can infer that
4539///
4540/// OP f(X_i) == 0 <=> OP X_i == 0.
4541///
4542/// (1)
4543/// OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0
4544/// (2)
4545/// <=> \forall i \in [1, N] X_i == 0
4546/// (1)
4547/// <=> OP(X_i) == 0
4548///
4549/// For some of the OP's and f's, we need to have domain constraints on X
4550/// to ensure properties 1 (or 1') and 2.
bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
  CmpPredicate Pred;
  Value *Op;
  // Match `icmp eq/ne (reduce(...)), 0`.
  if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: Op), R: m_Zero())) ||
      !ICmpInst::isEquality(P: Pred))
    return false;

  auto *II = dyn_cast<IntrinsicInst>(Val: Op);
  if (!II)
    return false;

  switch (II->getIntrinsicID()) {
  case Intrinsic::vector_reduce_add:
  case Intrinsic::vector_reduce_or:
  case Intrinsic::vector_reduce_umin:
  case Intrinsic::vector_reduce_umax:
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax:
    break;
  default:
    return false;
  }

  Value *InnerOp = II->getArgOperand(i: 0);

  // TODO: fixed vector type might be too restrictive
  if (!II->hasOneUse() || !isa<FixedVectorType>(Val: InnerOp->getType()))
    return false;

  Value *X = nullptr;

  // Check for zero-preserving operations where f(x) = 0 <=> x = 0
  //
  // 1. f(x) = shl nuw x, y for arbitrary y
  // 2. f(x) = mul nuw x, c for defined c != 0
  // 3. f(x) = zext x
  // 4. f(x) = sext x
  // 5. f(x) = neg x
  //
  if (!(match(V: InnerOp, P: m_NUWShl(L: m_Value(V&: X), R: m_Value())) ||      // Case 1
        match(V: InnerOp, P: m_NUWMul(L: m_Value(V&: X), R: m_NonZeroInt())) || // Case 2
        match(V: InnerOp, P: m_ZExt(Op: m_Value(V&: X))) ||                   // Case 3
        match(V: InnerOp, P: m_SExt(Op: m_Value(V&: X))) ||                   // Case 4
        match(V: InnerOp, P: m_Neg(V: m_Value(V&: X)))                       // Case 5
        ))
    return false;

  SimplifyQuery S = SQ.getWithInstruction(I: &I);
  auto *XTy = cast<FixedVectorType>(Val: X->getType());

  // Check for domain constraints for all supported reductions.
  //
  // a. OR X_i   - has property 1 for every X
  // b. UMAX X_i - has property 1 for every X
  // c. UMIN X_i - has property 1' for every X
  // d. SMAX X_i - has property 1 for X >= 0
  // e. SMIN X_i - has property 1' for X >= 0
  // f. ADD X_i  - has property 1 for X >= 0 && ADD X_i doesn't sign wrap
  //
  // In order for the proof to work, we need 1 (or 1') to be true for both
  // OP f(X_i) and OP X_i and that's why below we check constraints twice.
  //
  // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
  //       X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
  //       of known bits, we can't reasonably hold knowledge of "either 0
  //       or negative".
  switch (II->getIntrinsicID()) {
  case Intrinsic::vector_reduce_add: {
    // We need to check that both X_i and f(X_i) have enough leading
    // zeros to not overflow.
    KnownBits KnownX = computeKnownBits(V: X, Q: S);
    KnownBits KnownFX = computeKnownBits(V: InnerOp, Q: S);
    unsigned NumElems = XTy->getNumElements();
    // Adding N elements loses at most ceil(log2(N)) leading bits.
    unsigned LostBits = Log2_32_Ceil(Value: NumElems);
    unsigned LeadingZerosX = KnownX.countMinLeadingZeros();
    unsigned LeadingZerosFX = KnownFX.countMinLeadingZeros();
    // Need at least one leading zero left after summation to ensure no overflow
    if (LeadingZerosX <= LostBits || LeadingZerosFX <= LostBits)
      return false;

    // We are not checking whether X or f(X) are positive explicitly because
    // we implicitly checked for it when we checked if both cases have enough
    // leading zeros to not wrap addition.
    break;
  }
  case Intrinsic::vector_reduce_smin:
  case Intrinsic::vector_reduce_smax:
    // Check whether X >= 0 and f(X) >= 0
    if (!isKnownNonNegative(V: InnerOp, SQ: S) || !isKnownNonNegative(V: X, SQ: S))
      return false;

    break;
  default:
    // or/umin/umax need no domain constraints (cases a-c above).
    break;
  };

  LLVM_DEBUG(dbgs() << "Found a reduction to 0 comparison with removable op: "
                    << *II << "\n");

  // For zext/sext, check if the transform is profitable using cost model.
  // For other operations (shl, mul, neg), we're removing an instruction
  // while keeping the same reduction type, so it's always profitable.
  if (isa<ZExtInst>(Val: InnerOp) || isa<SExtInst>(Val: InnerOp)) {
    auto *FXTy = cast<FixedVectorType>(Val: InnerOp->getType());
    Intrinsic::ID IID = II->getIntrinsicID();

    InstructionCost ExtCost = TTI.getCastInstrCost(
        Opcode: cast<CastInst>(Val: InnerOp)->getOpcode(), Dst: FXTy, Src: XTy,
        CCH: TTI::CastContextHint::None, CostKind, I: cast<CastInst>(Val: InnerOp));

    // Compare reducing the wide (cast) type against reducing the narrow
    // source type.
    InstructionCost OldReduceCost, NewReduceCost;
    switch (IID) {
    case Intrinsic::vector_reduce_add:
    case Intrinsic::vector_reduce_or:
      OldReduceCost = TTI.getArithmeticReductionCost(
          Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: FXTy, FMF: std::nullopt, CostKind);
      NewReduceCost = TTI.getArithmeticReductionCost(
          Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: XTy, FMF: std::nullopt, CostKind);
      break;
    case Intrinsic::vector_reduce_umin:
    case Intrinsic::vector_reduce_umax:
    case Intrinsic::vector_reduce_smin:
    case Intrinsic::vector_reduce_smax:
      OldReduceCost = TTI.getMinMaxReductionCost(
          IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: FXTy, FMF: FastMathFlags(), CostKind);
      NewReduceCost = TTI.getMinMaxReductionCost(
          IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: XTy, FMF: FastMathFlags(), CostKind);
      break;
    default:
      llvm_unreachable("Unexpected reduction");
    }

    // If the extension has other users it must stay, so only the one-use
    // case credits its removal.
    InstructionCost OldCost = OldReduceCost + ExtCost;
    InstructionCost NewCost =
        NewReduceCost + (InnerOp->hasOneUse() ? 0 : ExtCost);

    LLVM_DEBUG(dbgs() << "Found a removable extension before reduction: "
                      << *InnerOp << "\n OldCost: " << OldCost
                      << " vs NewCost: " << NewCost << "\n");

    // We consider transformation to still be potentially beneficial even
    // when the costs are the same because we might remove a use from f(X)
    // and unlock other optimizations. Equal costs would just mean that we
    // didn't make it worse in the worst case.
    if (NewCost > OldCost)
      return false;
  }

  // Since we support zext and sext as f, we might change the scalar type
  // of the intrinsic.
  Type *Ty = XTy->getScalarType();
  Value *NewReduce = Builder.CreateIntrinsic(RetTy: Ty, ID: II->getIntrinsicID(), Args: {X});
  Value *NewCmp =
      Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::getNullValue(Ty));
  replaceValue(Old&: I, New&: *NewCmp);
  return true;
}
4709
4710/// Fold comparisons of reduce.or/reduce.and with reduce.umax/reduce.umin
4711/// based on cost, preserving the comparison semantics.
4712///
4713/// We use two fundamental properties for each pair:
4714///
4715/// 1. or(X) == 0 <=> umax(X) == 0
4716/// 2. or(X) == 1 <=> umax(X) == 1
4717/// 3. sign(or(X)) == sign(umax(X))
4718///
4719/// 1. and(X) == -1 <=> umin(X) == -1
4720/// 2. and(X) == -2 <=> umin(X) == -2
4721/// 3. sign(and(X)) == sign(umin(X))
4722///
4723/// From these we can infer the following transformations:
4724/// a. or(X) ==/!= 0 <-> umax(X) ==/!= 0
4725/// b. or(X) s< 0 <-> umax(X) s< 0
4726/// c. or(X) s> -1 <-> umax(X) s> -1
4727/// d. or(X) s< 1 <-> umax(X) s< 1
4728/// e. or(X) ==/!= 1 <-> umax(X) ==/!= 1
4729/// f. or(X) s< 2 <-> umax(X) s< 2
4730/// g. and(X) ==/!= -1 <-> umin(X) ==/!= -1
4731/// h. and(X) s< 0 <-> umin(X) s< 0
4732/// i. and(X) s> -1 <-> umin(X) s> -1
4733/// j. and(X) s> -2 <-> umin(X) s> -2
4734/// k. and(X) ==/!= -2 <-> umin(X) ==/!= -2
4735/// l. and(X) s> -3 <-> umin(X) s> -3
4736///
bool VectorCombine::foldEquivalentReductionCmp(Instruction &I) {
  CmpPredicate Pred;
  Value *ReduceOp;
  const APInt *CmpVal;
  // Match 'icmp Pred (reduction), C' with a constant integer RHS.
  if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: ReduceOp), R: m_APInt(Res&: CmpVal))))
    return false;

  // The reduction must be an intrinsic call whose only use is this compare,
  // so replacing it with the alternative reduction cannot affect other users.
  auto *II = dyn_cast<IntrinsicInst>(Val: ReduceOp);
  if (!II || !II->hasOneUse())
    return false;

  // Returns true if 'icmp Pred or(X), CmpVal' and 'icmp Pred umax(X), CmpVal'
  // are equivalent (cases a-f in the header comment).
  const auto IsValidOrUmaxCmp = [&]() {
    // or === umax for i1
    if (CmpVal->getBitWidth() == 1)
      return true;

    // Cases a and e
    bool IsEquality =
        (CmpVal->isZero() || CmpVal->isOne()) && ICmpInst::isEquality(P: Pred);
    // Case c
    bool IsPositive = CmpVal->isAllOnes() && Pred == ICmpInst::ICMP_SGT;
    // Cases b, d, and f
    bool IsNegative = (CmpVal->isZero() || CmpVal->isOne() || *CmpVal == 2) &&
                      Pred == ICmpInst::ICMP_SLT;
    return IsEquality || IsPositive || IsNegative;
  };

  // Returns true if 'icmp Pred and(X), CmpVal' and 'icmp Pred umin(X), CmpVal'
  // are equivalent (cases g-l in the header comment).
  const auto IsValidAndUminCmp = [&]() {
    // and === umin for i1
    if (CmpVal->getBitWidth() == 1)
      return true;

    const auto LeadingOnes = CmpVal->countl_one();

    // Cases g and k (-1, and values of the form 111...10 i.e. -2)
    bool IsEquality =
        (CmpVal->isAllOnes() || LeadingOnes + 1 == CmpVal->getBitWidth()) &&
        ICmpInst::isEquality(P: Pred);
    // Case h
    bool IsNegative = CmpVal->isZero() && Pred == ICmpInst::ICMP_SLT;
    // Cases i, j, and l
    bool IsPositive =
        // if the number has at least N - 2 leading ones
        // and the two LSBs are:
        // - 1 x 1 -> -1
        // - 1 x 0 -> -2
        // - 0 x 1 -> -3
        LeadingOnes + 2 >= CmpVal->getBitWidth() &&
        ((*CmpVal)[0] || (*CmpVal)[1]) && Pred == ICmpInst::ICMP_SGT;
    return IsEquality || IsNegative || IsPositive;
  };

  Intrinsic::ID OriginalIID = II->getIntrinsicID();
  Intrinsic::ID AlternativeIID;

  // Check if this is a valid comparison pattern and determine the alternate
  // reduction intrinsic.
  switch (OriginalIID) {
  case Intrinsic::vector_reduce_or:
    if (!IsValidOrUmaxCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_umax;
    break;
  case Intrinsic::vector_reduce_umax:
    if (!IsValidOrUmaxCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_or;
    break;
  case Intrinsic::vector_reduce_and:
    if (!IsValidAndUminCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_umin;
    break;
  case Intrinsic::vector_reduce_umin:
    if (!IsValidAndUminCmp())
      return false;
    AlternativeIID = Intrinsic::vector_reduce_and;
    break;
  default:
    return false;
  }

  // Only fixed-width vector reductions are costed/rewritten here.
  Value *X = II->getArgOperand(i: 0);
  auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
  if (!VecTy)
    return false;

  // Query the target cost of a reduction: and/or map to arithmetic
  // reductions, umin/umax to min/max reductions.
  const auto GetReductionCost = [&](Intrinsic::ID IID) -> InstructionCost {
    unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
    if (ReductionOpc != Instruction::ICmp)
      return TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecTy, FMF: std::nullopt,
                                            CostKind);
    return TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: VecTy,
                                      FMF: FastMathFlags(), CostKind);
  };

  InstructionCost OrigCost = GetReductionCost(OriginalIID);
  InstructionCost AltCost = GetReductionCost(AlternativeIID);

  LLVM_DEBUG(dbgs() << "Found equivalent reduction cmp: " << I
                    << "\n OrigCost: " << OrigCost
                    << " vs AltCost: " << AltCost << "\n");

  // Only rewrite when the alternative is strictly cheaper; equal costs would
  // churn the IR and could ping-pong with the inverse rewrite.
  if (AltCost >= OrigCost)
    return false;

  Builder.SetInsertPoint(&I);
  Type *ScalarTy = VecTy->getScalarType();
  Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: AlternativeIID, Args: {X});
  Value *NewCmp =
      Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::get(Ty: ScalarTy, V: *CmpVal));

  replaceValue(Old&: I, New&: *NewCmp);
  return true;
}
4852
4853/// Returns true if this ShuffleVectorInst eventually feeds into a
4854/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
4855/// chains of shuffles and binary operators (in any combination/order).
4856/// The search does not go deeper than the given Depth.
4857static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI) {
4858 constexpr unsigned MaxVisited = 32;
4859 SmallPtrSet<Instruction *, 8> Visited;
4860 SmallVector<Instruction *, 4> WorkList;
4861 bool FoundReduction = false;
4862
4863 WorkList.push_back(Elt: SVI);
4864 while (!WorkList.empty()) {
4865 Instruction *I = WorkList.pop_back_val();
4866 for (User *U : I->users()) {
4867 auto *UI = cast<Instruction>(Val: U);
4868 if (!UI || !Visited.insert(Ptr: UI).second)
4869 continue;
4870 if (Visited.size() > MaxVisited)
4871 return false;
4872 if (auto *II = dyn_cast<IntrinsicInst>(Val: UI)) {
4873 // More than one reduction reached
4874 if (FoundReduction)
4875 return false;
4876 switch (II->getIntrinsicID()) {
4877 case Intrinsic::vector_reduce_add:
4878 case Intrinsic::vector_reduce_mul:
4879 case Intrinsic::vector_reduce_and:
4880 case Intrinsic::vector_reduce_or:
4881 case Intrinsic::vector_reduce_xor:
4882 case Intrinsic::vector_reduce_smin:
4883 case Intrinsic::vector_reduce_smax:
4884 case Intrinsic::vector_reduce_umin:
4885 case Intrinsic::vector_reduce_umax:
4886 FoundReduction = true;
4887 continue;
4888 default:
4889 return false;
4890 }
4891 }
4892
4893 if (!isa<BinaryOperator>(Val: UI) && !isa<ShuffleVectorInst>(Val: UI))
4894 return false;
4895
4896 WorkList.emplace_back(Args&: UI);
4897 }
4898 }
4899 return FoundReduction;
4900}
4901
4902/// This method looks for groups of shuffles acting on binops, of the form:
4903/// %x = shuffle ...
4904/// %y = shuffle ...
4905/// %a = binop %x, %y
4906/// %b = binop %x, %y
4907/// shuffle %a, %b, selectmask
4908/// We may, especially if the shuffle is wider than legal, be able to convert
4909/// the shuffle to a form where only parts of a and b need to be computed. On
4910/// architectures with no obvious "select" shuffle, this can reduce the total
4911/// number of operations if the target reports them as cheaper.
bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
  auto *SVI = cast<ShuffleVectorInst>(Val: &I);
  auto *VT = cast<FixedVectorType>(Val: I.getType());
  auto *Op0 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 0));
  auto *Op1 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 1));
  // The select shuffle must act on two distinct binops of the same vector
  // type as the shuffle itself.
  if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
      VT != Op0->getType())
    return false;

  auto *SVI0A = dyn_cast<Instruction>(Val: Op0->getOperand(i: 0));
  auto *SVI0B = dyn_cast<Instruction>(Val: Op0->getOperand(i: 1));
  auto *SVI1A = dyn_cast<Instruction>(Val: Op1->getOperand(i: 0));
  auto *SVI1B = dyn_cast<Instruction>(Val: Op1->getOperand(i: 1));
  SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
  // Returns true (i.e. reject) if the input is not a VT-typed instruction, or
  // has uses other than Op0/Op1 and the other input shuffles (dead shuffles
  // are tolerated).
  auto checkSVNonOpUses = [&](Instruction *I) {
    if (!I || I->getOperand(i: 0)->getType() != VT)
      return true;
    return any_of(Range: I->users(), P: [&](User *U) {
      return U != Op0 && U != Op1 &&
             !(isa<ShuffleVectorInst>(Val: U) &&
               (InputShuffles.contains(Ptr: cast<Instruction>(Val: U)) ||
                isInstructionTriviallyDead(I: cast<Instruction>(Val: U))));
    });
  };
  if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
      checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
    return false;

  // Collect all the uses that are shuffles that we can transform together. We
  // may not have a single shuffle, but a group that can all be transformed
  // together profitably.
  SmallVector<ShuffleVectorInst *> Shuffles;
  auto collectShuffles = [&](Instruction *I) {
    for (auto *U : I->users()) {
      auto *SV = dyn_cast<ShuffleVectorInst>(Val: U);
      if (!SV || SV->getType() != VT)
        return false;
      if ((SV->getOperand(i_nocapture: 0) != Op0 && SV->getOperand(i_nocapture: 0) != Op1) ||
          (SV->getOperand(i_nocapture: 1) != Op0 && SV->getOperand(i_nocapture: 1) != Op1))
        return false;
      if (!llvm::is_contained(Range&: Shuffles, Element: SV))
        Shuffles.push_back(Elt: SV);
    }
    return true;
  };
  if (!collectShuffles(Op0) || !collectShuffles(Op1))
    return false;
  // From a reduction, we need to be processing a single shuffle, otherwise the
  // other uses will not be lane-invariant.
  if (FromReduction && Shuffles.size() > 1)
    return false;

  // Add any shuffle uses for the shuffles we have found, to include them in our
  // cost calculations.
  if (!FromReduction) {
    for (ShuffleVectorInst *SV : Shuffles) {
      for (auto *U : SV->users()) {
        ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(Val: U);
        if (SSV && isa<UndefValue>(Val: SSV->getOperand(i_nocapture: 1)) && SSV->getType() == VT)
          Shuffles.push_back(Elt: SSV);
      }
    }
  }

  // For each of the output shuffles, we try to sort all the first vector
  // elements to the beginning, followed by the second array elements at the
  // end. If the binops are legalized to smaller vectors, this may reduce total
  // number of binops. We compute the ReconstructMask mask needed to convert
  // back to the original lane order.
  SmallVector<std::pair<int, int>> V1, V2;
  SmallVector<SmallVector<int>> OrigReconstructMasks;
  int MaxV1Elt = 0, MaxV2Elt = 0;
  unsigned NumElts = VT->getNumElements();
  for (ShuffleVectorInst *SVN : Shuffles) {
    SmallVector<int> Mask;
    SVN->getShuffleMask(Result&: Mask);

    // Check the operands are the same as the original, or reversed (in which
    // case we need to commute the mask).
    Value *SVOp0 = SVN->getOperand(i_nocapture: 0);
    Value *SVOp1 = SVN->getOperand(i_nocapture: 1);
    if (isa<UndefValue>(Val: SVOp1)) {
      // Single-source shuffle-of-shuffle: fold the outer mask through the
      // inner one so we can reason about Op0/Op1 directly.
      auto *SSV = cast<ShuffleVectorInst>(Val: SVOp0);
      SVOp0 = SSV->getOperand(i_nocapture: 0);
      SVOp1 = SSV->getOperand(i_nocapture: 1);
      for (int &Elem : Mask) {
        if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
          return false;
        Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elt: Elem);
      }
    }
    if (SVOp0 == Op1 && SVOp1 == Op0) {
      std::swap(a&: SVOp0, b&: SVOp1);
      ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumElts);
    }
    if (SVOp0 != Op0 || SVOp1 != Op1)
      return false;

    // Calculate the reconstruction mask for this shuffle, as the mask needed to
    // take the packed values from Op0/Op1 and reconstructing to the original
    // order.
    SmallVector<int> ReconstructMask;
    for (unsigned I = 0; I < Mask.size(); I++) {
      if (Mask[I] < 0) {
        ReconstructMask.push_back(Elt: -1);
      } else if (Mask[I] < static_cast<int>(NumElts)) {
        MaxV1Elt = std::max(a: MaxV1Elt, b: Mask[I]);
        auto It = find_if(Range&: V1, P: [&](const std::pair<int, int> &A) {
          return Mask[I] == A.first;
        });
        if (It != V1.end())
          ReconstructMask.push_back(Elt: It - V1.begin());
        else {
          ReconstructMask.push_back(Elt: V1.size());
          V1.emplace_back(Args&: Mask[I], Args: V1.size());
        }
      } else {
        MaxV2Elt = std::max<int>(a: MaxV2Elt, b: Mask[I] - NumElts);
        auto It = find_if(Range&: V2, P: [&](const std::pair<int, int> &A) {
          return Mask[I] - static_cast<int>(NumElts) == A.first;
        });
        if (It != V2.end())
          ReconstructMask.push_back(Elt: NumElts + It - V2.begin());
        else {
          ReconstructMask.push_back(Elt: NumElts + V2.size());
          V2.emplace_back(Args: Mask[I] - NumElts, Args: NumElts + V2.size());
        }
      }
    }

    // For reductions, we know that the lane ordering out doesn't alter the
    // result. In-order can help simplify the shuffle away.
    if (FromReduction)
      sort(C&: ReconstructMask);
    OrigReconstructMasks.push_back(Elt: std::move(ReconstructMask));
  }

  // If the Maximum element used from V1 and V2 are not larger than the new
  // vectors, the vectors are already packed and performing the optimization
  // again will likely not help any further. This also prevents us from getting
  // stuck in a cycle in case the costs do not also rule it out.
  if (V1.empty() || V2.empty() ||
      (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
       MaxV2Elt == static_cast<int>(V2.size()) - 1))
    return false;

  // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
  // shuffle of another shuffle, or not a shuffle (that is treated like an
  // identity shuffle).
  auto GetBaseMaskValue = [&](Instruction *I, int M) {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return M;
    if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
      if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
        if (InputShuffles.contains(Ptr: SSV))
          return SSV->getMaskValue(Elt: SV->getMaskValue(Elt: M));
    return SV->getMaskValue(Elt: M);
  };

  // Attempt to sort the inputs by ascending mask values to make simpler input
  // shuffles and push complex shuffles down to the uses. We sort on the first
  // of the two input shuffle orders, to try and get at least one input into a
  // nice order.
  auto SortBase = [&](Instruction *A, std::pair<int, int> X,
                      std::pair<int, int> Y) {
    int MXA = GetBaseMaskValue(A, X.first);
    int MYA = GetBaseMaskValue(A, Y.first);
    return MXA < MYA;
  };
  stable_sort(Range&: V1, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
    return SortBase(SVI0A, A, B);
  });
  stable_sort(Range&: V2, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
    return SortBase(SVI1A, A, B);
  });
  // Calculate our ReconstructMasks from the OrigReconstructMasks and the
  // modified order of the input shuffles.
  SmallVector<SmallVector<int>> ReconstructMasks;
  for (const auto &Mask : OrigReconstructMasks) {
    SmallVector<int> ReconstructMask;
    for (int M : Mask) {
      auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
        auto It = find_if(Range: V, P: [M](auto A) { return A.second == M; });
        assert(It != V.end() && "Expected all entries in Mask");
        return std::distance(first: V.begin(), last: It);
      };
      if (M < 0)
        ReconstructMask.push_back(Elt: -1);
      else if (M < static_cast<int>(NumElts)) {
        ReconstructMask.push_back(Elt: FindIndex(V1, M));
      } else {
        ReconstructMask.push_back(Elt: NumElts + FindIndex(V2, M));
      }
    }
    ReconstructMasks.push_back(Elt: std::move(ReconstructMask));
  }

  // Calculate the masks needed for the new input shuffles, which get padded
  // with undef
  SmallVector<int> V1A, V1B, V2A, V2B;
  for (unsigned I = 0; I < V1.size(); I++) {
    V1A.push_back(Elt: GetBaseMaskValue(SVI0A, V1[I].first));
    V1B.push_back(Elt: GetBaseMaskValue(SVI0B, V1[I].first));
  }
  for (unsigned I = 0; I < V2.size(); I++) {
    V2A.push_back(Elt: GetBaseMaskValue(SVI1A, V2[I].first));
    V2B.push_back(Elt: GetBaseMaskValue(SVI1B, V2[I].first));
  }
  while (V1A.size() < NumElts) {
    V1A.push_back(Elt: PoisonMaskElem);
    V1B.push_back(Elt: PoisonMaskElem);
  }
  while (V2A.size() < NumElts) {
    V2A.push_back(Elt: PoisonMaskElem);
    V2B.push_back(Elt: PoisonMaskElem);
  }

  // Cost of an existing shuffle instruction (non-shuffles contribute 0).
  auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return C;
    return C + TTI.getShuffleCost(Kind: isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1))
                                      ? TTI::SK_PermuteSingleSrc
                                      : TTI::SK_PermuteTwoSrc,
                                  DstTy: VT, SrcTy: VT, Mask: SV->getShuffleMask(), CostKind);
  };
  // Cost of a new two-source shuffle described only by its mask.
  auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
    return C +
           TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
  };

  unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
  unsigned MaxVectorSize =
      TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector);
  unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
  if (MaxElementsInVector == 0)
    return false;
  // When there are multiple shufflevector operations on the same input,
  // especially when the vector length is larger than the register size,
  // identical shuffle patterns may occur across different groups of elements.
  // To avoid overestimating the cost by counting these repeated shuffles more
  // than once, we only account for unique shuffle patterns. This adjustment
  // prevents inflated costs in the cost model for wide vectors split into
  // several register-sized groups.
  std::set<SmallVector<int, 4>> UniqueShuffles;
  auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
    // Compute the cost for performing the shuffle over the full vector.
    auto ShuffleCost =
        TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
    unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
    if (NumFullVectors < 2)
      return C + ShuffleCost;
    SmallVector<int, 4> SubShuffle(MaxElementsInVector);
    unsigned NumUniqueGroups = 0;
    unsigned NumGroups = Mask.size() / MaxElementsInVector;
    // For each group of MaxElementsInVector contiguous elements,
    // collect their shuffle pattern and insert into the set of unique patterns.
    for (unsigned I = 0; I < NumFullVectors; ++I) {
      for (unsigned J = 0; J < MaxElementsInVector; ++J)
        SubShuffle[J] = Mask[MaxElementsInVector * I + J];
      if (UniqueShuffles.insert(x: SubShuffle).second)
        NumUniqueGroups += 1;
    }
    // Scale the full-vector cost by the fraction of unique groups.
    return C + ShuffleCost * NumUniqueGroups / NumGroups;
  };
  auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return C;
    SmallVector<int, 16> Mask;
    SV->getShuffleMask(Result&: Mask);
    return AddShuffleMaskAdjustedCost(C, Mask);
  };
  // Check that input consists of ShuffleVectors applied to the same input
  auto AllShufflesHaveSameOperands =
      [](SmallPtrSetImpl<Instruction *> &InputShuffles) {
        if (InputShuffles.size() < 2)
          return false;
        ShuffleVectorInst *FirstSV =
            dyn_cast<ShuffleVectorInst>(Val: *InputShuffles.begin());
        if (!FirstSV)
          return false;

        Value *In0 = FirstSV->getOperand(i_nocapture: 0), *In1 = FirstSV->getOperand(i_nocapture: 1);
        return std::all_of(
            first: std::next(x: InputShuffles.begin()), last: InputShuffles.end(),
            pred: [&](Instruction *I) {
              ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(Val: I);
              return SV && SV->getOperand(i_nocapture: 0) == In0 && SV->getOperand(i_nocapture: 1) == In1;
            });
      };

  // Get the costs of the shuffles + binops before and after with the new
  // shuffle masks.
  InstructionCost CostBefore =
      TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: VT, CostKind) +
      TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: VT, CostKind);
  CostBefore += std::accumulate(first: Shuffles.begin(), last: Shuffles.end(),
                                init: InstructionCost(0), binary_op: AddShuffleCost);
  if (AllShufflesHaveSameOperands(InputShuffles)) {
    UniqueShuffles.clear();
    CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
                                  init: InstructionCost(0), binary_op: AddShuffleAdjustedCost);
  } else {
    CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
                                  init: InstructionCost(0), binary_op: AddShuffleCost);
  }

  // The new binops will be unused for lanes past the used shuffle lengths.
  // These types attempt to get the correct cost for that from the target.
  FixedVectorType *Op0SmallVT =
      FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V1.size());
  FixedVectorType *Op1SmallVT =
      FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V2.size());
  InstructionCost CostAfter =
      TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: Op0SmallVT, CostKind) +
      TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: Op1SmallVT, CostKind);
  UniqueShuffles.clear();
  CostAfter += std::accumulate(first: ReconstructMasks.begin(), last: ReconstructMasks.end(),
                               init: InstructionCost(0), binary_op: AddShuffleMaskAdjustedCost);
  std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
  CostAfter +=
      std::accumulate(first: OutputShuffleMasks.begin(), last: OutputShuffleMasks.end(),
                      init: InstructionCost(0), binary_op: AddShuffleMaskCost);

  LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
  LLVM_DEBUG(dbgs() << "  CostBefore: " << CostBefore
                    << " vs CostAfter: " << CostAfter << "\n");
  // At equal cost, only proceed when the pattern feeds a reduction, where
  // the re-association is still expected to unlock further simplification.
  if (CostBefore < CostAfter ||
      (CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
    return false;

  // The cost model has passed, create the new instructions.
  auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
    auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
    if (!SV)
      return I;
    if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
      if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
        if (InputShuffles.contains(Ptr: SSV))
          return SSV->getOperand(i_nocapture: Op);
    return SV->getOperand(i_nocapture: Op);
  };
  Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
  Value *NSV0A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0A, 0),
                                             V2: GetShuffleOperand(SVI0A, 1), Mask: V1A);
  Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
  Value *NSV0B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0B, 0),
                                             V2: GetShuffleOperand(SVI0B, 1), Mask: V1B);
  Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
  Value *NSV1A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1A, 0),
                                             V2: GetShuffleOperand(SVI1A, 1), Mask: V2A);
  Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
  Value *NSV1B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1B, 0),
                                             V2: GetShuffleOperand(SVI1B, 1), Mask: V2B);
  Builder.SetInsertPoint(Op0);
  Value *NOp0 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op0->getOpcode(),
                                    LHS: NSV0A, RHS: NSV0B);
  if (auto *I = dyn_cast<Instruction>(Val: NOp0))
    I->copyIRFlags(V: Op0, IncludeWrapFlags: true);
  Builder.SetInsertPoint(Op1);
  Value *NOp1 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op1->getOpcode(),
                                    LHS: NSV1A, RHS: NSV1B);
  if (auto *I = dyn_cast<Instruction>(Val: NOp1))
    I->copyIRFlags(V: Op1, IncludeWrapFlags: true);

  for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
    Builder.SetInsertPoint(Shuffles[S]);
    Value *NSV = Builder.CreateShuffleVector(V1: NOp0, V2: NOp1, Mask: ReconstructMasks[S]);
    replaceValue(Old&: *Shuffles[S], New&: *NSV, Erase: false);
  }

  Worklist.pushValue(V: NSV0A);
  Worklist.pushValue(V: NSV0B);
  Worklist.pushValue(V: NSV1A);
  Worklist.pushValue(V: NSV1B);
  return true;
}
5291
/// Check if instruction depends on ZExt and this ZExt can be moved after the
/// instruction. Move ZExt if it is profitable. For example:
/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
/// lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
/// Cost model calculations take into account if zext(x) has other users and
/// whether it can be propagated through them too.
bool VectorCombine::shrinkType(Instruction &I) {
  Value *ZExted, *OtherOperand;
  // Match 'logic(zext(x), y)' (commutative) or 'lshr(zext(x), y)'.
  if (!match(V: &I, P: m_c_BitwiseLogic(L: m_ZExt(Op: m_Value(V&: ZExted)),
                                   R: m_Value(V&: OtherOperand))) &&
      !match(V: &I, P: m_LShr(L: m_ZExt(Op: m_Value(V&: ZExted)), R: m_Value(V&: OtherOperand))))
    return false;

  // The operand of I that is the zext (the other one is OtherOperand).
  Value *ZExtOperand = I.getOperand(i: I.getOperand(i: 0) == OtherOperand ? 1 : 0);

  auto *BigTy = cast<FixedVectorType>(Val: I.getType());
  auto *SmallTy = cast<FixedVectorType>(Val: ZExted->getType());
  unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();

  if (I.getOpcode() == Instruction::LShr) {
    // Check that the shift amount is less than the number of bits in the
    // smaller type. Otherwise, the smaller lshr will return a poison value.
    KnownBits ShAmtKB = computeKnownBits(V: I.getOperand(i: 1), DL: *DL);
    if (ShAmtKB.getMaxValue().uge(RHS: BW))
      return false;
  } else {
    // Check that the expression overall uses at most the same number of bits as
    // ZExted
    KnownBits KB = computeKnownBits(V: &I, DL: *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;
  }

  // Calculate costs of leaving current IR as it is and moving ZExt operation
  // later, along with adding truncates if needed
  InstructionCost ZExtCost = TTI.getCastInstrCost(
      Opcode: Instruction::ZExt, Dst: BigTy, Src: SmallTy,
      CCH: TargetTransformInfo::CastContextHint::None, CostKind);
  InstructionCost CurrentCost = ZExtCost;
  InstructionCost ShrinkCost = 0;

  // Calculate total cost and check that we can propagate through all ZExt users
  for (User *U : ZExtOperand->users()) {
    auto *UI = cast<Instruction>(Val: U);
    if (UI == &I) {
      // I itself: narrow binop + a (possibly removable) zext of the result.
      CurrentCost +=
          TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
      ShrinkCost +=
          TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
      ShrinkCost += ZExtCost;
      continue;
    }

    // Only binary-op users can have the zext hoisted past them.
    if (!Instruction::isBinaryOp(Opcode: UI->getOpcode()))
      return false;

    // Check if we can propagate ZExt through its other users
    KnownBits KB = computeKnownBits(V: UI, DL: *DL);
    if (KB.countMaxActiveBits() > BW)
      return false;

    CurrentCost += TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
    ShrinkCost +=
        TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
    ShrinkCost += ZExtCost;
  }

  // If the other instruction operand is not a constant, we'll need to
  // generate a truncate instruction. So we have to adjust cost
  if (!isa<Constant>(Val: OtherOperand))
    ShrinkCost += TTI.getCastInstrCost(
        Opcode: Instruction::Trunc, Dst: SmallTy, Src: BigTy,
        CCH: TargetTransformInfo::CastContextHint::None, CostKind);

  // If the cost of shrinking types and leaving the IR is the same, we'll lean
  // towards modifying the IR because shrinking opens opportunities for other
  // shrinking optimisations.
  if (ShrinkCost > CurrentCost)
    return false;

  Builder.SetInsertPoint(&I);
  Value *Op0 = ZExted;
  Value *Op1 = Builder.CreateTrunc(V: OtherOperand, DestTy: SmallTy);
  // Keep the order of operands the same
  if (I.getOperand(i: 0) == OtherOperand)
    std::swap(a&: Op0, b&: Op1);
  // Build the narrow binop, preserving flags/metadata, then widen the result.
  Value *NewBinOp =
      Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(), LHS: Op0, RHS: Op1);
  cast<Instruction>(Val: NewBinOp)->copyIRFlags(V: &I);
  cast<Instruction>(Val: NewBinOp)->copyMetadata(SrcInst: I);
  Value *NewZExtr = Builder.CreateZExt(V: NewBinOp, DestTy: BigTy);
  replaceValue(Old&: I, New&: *NewZExtr);
  return true;
}
5386
5387/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
5388/// shuffle (DstVec, SrcVec, Mask)
bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
  Value *DstVec, *SrcVec;
  uint64_t ExtIdx, InsIdx;
  // Match 'insertelement DstVec, (extractelement SrcVec, ExtIdx), InsIdx'
  // with constant indices.
  if (!match(V: &I,
             P: m_InsertElt(Val: m_Value(V&: DstVec),
                         Elt: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx)),
                         Idx: m_ConstantInt(V&: InsIdx))))
    return false;

  auto *DstVecTy = dyn_cast<FixedVectorType>(Val: I.getType());
  auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
  // We can try combining vectors with different element counts, but the
  // element types must match.
  if (!DstVecTy || !SrcVecTy ||
      SrcVecTy->getElementType() != DstVecTy->getElementType())
    return false;

  // Reject out-of-range indices and degenerate single-element destinations.
  unsigned NumDstElts = DstVecTy->getNumElements();
  unsigned NumSrcElts = SrcVecTy->getNumElements();
  if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
    return false;

  // Insertion into poison is a cheaper single operand shuffle.
  TargetTransformInfo::ShuffleKind SK;
  SmallVector<int> Mask(NumDstElts, PoisonMaskElem);

  bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
  bool NeedDstSrcSwap = isa<PoisonValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec);
  if (NeedDstSrcSwap) {
    SK = TargetTransformInfo::SK_PermuteSingleSrc;
    Mask[InsIdx] = ExtIdx % NumDstElts;
    std::swap(a&: DstVec, b&: SrcVec);
  } else {
    // Identity mask for DstVec lanes, with the inserted lane taken from the
    // second source (SrcVec, offset by NumDstElts).
    SK = TargetTransformInfo::SK_PermuteTwoSrc;
    std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
    Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
  }

  // Cost
  auto *Ins = cast<InsertElementInst>(Val: &I);
  auto *Ext = cast<ExtractElementInst>(Val: I.getOperand(i: 1));
  InstructionCost InsCost =
      TTI.getVectorInstrCost(I: *Ins, Val: DstVecTy, CostKind, Index: InsIdx);
  // NOTE(review): Ext extracts from SrcVec, yet its cost is queried with
  // DstVecTy — confirm whether SrcVecTy is intended when the types differ.
  InstructionCost ExtCost =
      TTI.getVectorInstrCost(I: *Ext, Val: DstVecTy, CostKind, Index: ExtIdx);
  InstructionCost OldCost = ExtCost + InsCost;

  InstructionCost NewCost = 0;
  SmallVector<int> ExtToVecMask;
  if (!NeedExpOrNarrow) {
    // Ignore 'free' identity insertion shuffle.
    // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
    if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
      NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind, Index: 0,
                                    SubTp: nullptr, Args: {DstVec, SrcVec});
  } else {
    // When creating a length-changing-vector, always try to keep the relevant
    // element in an equivalent position, so that bulk shuffles are more likely
    // to be useful.
    ExtToVecMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
    ExtToVecMask[ExtIdx % NumDstElts] = ExtIdx;
    // Add cost for expanding or narrowing
    NewCost = TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
                                 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: ExtToVecMask, CostKind);
    NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind);
  }

  // If the extract has other users it must stay, so charge it to the new
  // sequence as well.
  if (!Ext->hasOneUse())
    NewCost += ExtCost;

  LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (OldCost < NewCost)
    return false;

  if (NeedExpOrNarrow) {
    // Resize the extracted-from vector to the destination element count first.
    if (!NeedDstSrcSwap)
      SrcVec = Builder.CreateShuffleVector(V: SrcVec, Mask: ExtToVecMask);
    else
      DstVec = Builder.CreateShuffleVector(V: DstVec, Mask: ExtToVecMask);
  }

  // Canonicalize undef param to RHS to help further folds.
  if (isa<UndefValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec)) {
    ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumDstElts);
    std::swap(a&: DstVec, b&: SrcVec);
  }

  Value *Shuf = Builder.CreateShuffleVector(V1: DstVec, V2: SrcVec, Mask);
  replaceValue(Old&: I, New&: *Shuf);

  return true;
}
5483
5484/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
5485/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
5486/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
5487/// before casting it back into `<vscale x 16 x i32>`.
5488bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
5489 const APInt *SplatVal0, *SplatVal1;
5490 if (!match(V: &I, P: m_Intrinsic<Intrinsic::vector_interleave2>(
5491 Op0: m_APInt(Res&: SplatVal0), Op1: m_APInt(Res&: SplatVal1))))
5492 return false;
5493
5494 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
5495 << "\n");
5496
5497 auto *VTy =
5498 cast<VectorType>(Val: cast<IntrinsicInst>(Val&: I).getArgOperand(i: 0)->getType());
5499 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
5500 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
5501
5502 // Just in case the cost of interleave2 intrinsic and bitcast are both
5503 // invalid, in which case we want to bail out, we use <= rather
5504 // than < here. Even they both have valid and equal costs, it's probably
5505 // not a good idea to emit a high-cost constant splat.
5506 if (TTI.getInstructionCost(U: &I, CostKind) <=
5507 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: I.getType(), Src: ExtVTy,
5508 CCH: TTI::CastContextHint::None, CostKind)) {
5509 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
5510 << *I.getType() << " is too high.\n");
5511 return false;
5512 }
5513
5514 APInt NewSplatVal = SplatVal1->zext(width: Width * 2);
5515 NewSplatVal <<= Width;
5516 NewSplatVal |= SplatVal0->zext(width: Width * 2);
5517 auto *NewSplat = ConstantVector::getSplat(
5518 EC: ExtVTy->getElementCount(), Elt: ConstantInt::get(Context&: F.getContext(), V: NewSplatVal));
5519
5520 IRBuilder<> Builder(&I);
5521 replaceValue(Old&: I, New&: *Builder.CreateBitCast(V: NewSplat, DestTy: I.getType()));
5522 return true;
5523}
5524
// Attempt to shrink loads that are only used by shufflevector instructions.
// If every user is a single-source shuffle of this load (second operand
// poison/undef) that only reads a leading prefix of the loaded elements, we
// can load just that narrower prefix and rewrite the shuffles against it,
// provided the target cost model agrees.
bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
  // Only simple (non-volatile, non-atomic) loads of fixed-width vectors
  // are candidates.
  auto *OldLoad = dyn_cast<LoadInst>(Val: &I);
  if (!OldLoad || !OldLoad->isSimple())
    return false;

  auto *OldLoadTy = dyn_cast<FixedVectorType>(Val: OldLoad->getType());
  if (!OldLoadTy)
    return false;

  unsigned const OldNumElements = OldLoadTy->getNumElements();

  // Search all uses of load. If all uses are shufflevector instructions, and
  // the second operands are all poison values, find the minimum and maximum
  // indices of the vector elements referenced by all shuffle masks.
  // Otherwise return `std::nullopt`.
  using IndexRange = std::pair<int, int>;
  auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
    // Start with an inverted (empty) range; it stays inverted if no mask
    // ever references an in-bounds element of the load.
    IndexRange OutputRange = IndexRange(OldNumElements, -1);
    for (llvm::Use &Use : I.uses()) {
      // Ensure all uses match the required pattern.
      User *Shuffle = Use.getUser();
      ArrayRef<int> Mask;

      if (!match(V: Shuffle,
                 P: m_Shuffle(v1: m_Specific(V: OldLoad), v2: m_Undef(), mask: m_Mask(Mask))))
        return std::nullopt;

      // Ignore shufflevector instructions that have no uses.
      if (Shuffle->use_empty())
        continue;

      // Find the min and max indices used by the shufflevector instruction.
      // Out-of-bounds / poison mask entries are ignored here.
      for (int Index : Mask) {
        if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
          OutputRange.first = std::min(a: Index, b: OutputRange.first);
          OutputRange.second = std::max(a: Index, b: OutputRange.second);
        }
      }
    }

    // Range never populated: no live element of the load is referenced.
    if (OutputRange.second < OutputRange.first)
      return std::nullopt;

    return OutputRange;
  };

  // Get the range of vector elements used by shufflevector instructions.
  if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
    // Elements [0, max index] must be kept; everything above can be dropped.
    unsigned const NewNumElements = Indices->second + 1u;

    // If the range of vector elements is smaller than the full load, attempt
    // to create a smaller load.
    if (NewNumElements < OldNumElements) {
      IRBuilder Builder(&I);
      Builder.SetCurrentDebugLocation(I.getDebugLoc());

      // Calculate costs of old and new ops.
      Type *ElemTy = OldLoadTy->getElementType();
      FixedVectorType *NewLoadTy = FixedVectorType::get(ElementType: ElemTy, NumElts: NewNumElements);
      Value *PtrOp = OldLoad->getPointerOperand();

      InstructionCost OldCost = TTI.getMemoryOpCost(
          Opcode: Instruction::Load, Src: OldLoad->getType(), Alignment: OldLoad->getAlign(),
          AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);
      InstructionCost NewCost =
          TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: NewLoadTy, Alignment: OldLoad->getAlign(),
                              AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);

      using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
      SmallVector<UseEntry, 4u> NewUses;
      // Mask indices must stay within the two-operand index space of the
      // narrowed shuffle (both operands have NewNumElements elements).
      unsigned const MaxIndex = NewNumElements * 2u;

      for (llvm::Use &Use : I.uses()) {
        // Safe: GetIndexRangeInShuffles() proved every user is a shuffle.
        auto *Shuffle = cast<ShuffleVectorInst>(Val: Use.getUser());

        // Ignore shufflevector instructions that have no uses.
        if (Shuffle->use_empty())
          continue;

        ArrayRef<int> OldMask = Shuffle->getShuffleMask();

        // Create entry for new use.
        NewUses.push_back(Elt: {Shuffle, OldMask});

        // Validate mask indices.
        for (int Index : OldMask) {
          if (Index >= static_cast<int>(MaxIndex))
            return false;
        }

        // Update costs.
        OldCost +=
            TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
                               SrcTy: OldLoadTy, Mask: OldMask, CostKind);
        NewCost +=
            TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
                               SrcTy: NewLoadTy, Mask: OldMask, CostKind);
      }

      LLVM_DEBUG(
          dbgs() << "Found a load used only by shufflevector instructions: "
                 << I << "\n  OldCost: " << OldCost
                 << " vs NewCost: " << NewCost << "\n");

      // Ties go to the new (narrower) load; invalid new cost bails out.
      if (OldCost < NewCost || !NewCost.isValid())
        return false;

      // Create new load of smaller vector.
      // NOTE(review): the original alignment is reused for the narrower
      // load of the same base pointer, which remains conservative/correct.
      auto *NewLoad = cast<LoadInst>(
          Val: Builder.CreateAlignedLoad(Ty: NewLoadTy, Ptr: PtrOp, Align: OldLoad->getAlign()));
      NewLoad->copyMetadata(SrcInst: I);

      // Replace all uses.
      for (UseEntry &Use : NewUses) {
        ShuffleVectorInst *Shuffle = Use.first;
        std::vector<int> &NewMask = Use.second;

        Builder.SetInsertPoint(Shuffle);
        Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
        Value *NewShuffle = Builder.CreateShuffleVector(
            V1: NewLoad, V2: PoisonValue::get(T: NewLoadTy), Mask: NewMask);

        // Erase=false: the dead shuffles/load are cleaned up via the
        // worklist rather than eagerly here.
        replaceValue(Old&: *Shuffle, New&: *NewShuffle, Erase: false);
      }

      return true;
    }
  }
  return false;
}
5656
// Attempt to narrow a phi of shufflevector instructions where the two incoming
// values have the same operands but different masks. If the two shuffle masks
// are offsets of one another we can use one branch to rotate the incoming
// vector and perform one larger shuffle after the phi.
bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
  // Only two-predecessor phis are handled.
  auto *Phi = dyn_cast<PHINode>(Val: &I);
  if (!Phi || Phi->getNumIncomingValues() != 2u)
    return false;

  Value *Op = nullptr;
  ArrayRef<int> Mask0;
  ArrayRef<int> Mask1;

  // Both incoming values must be one-use, single-source shuffles of the
  // same operand (second shuffle operand poison).
  if (!match(V: Phi->getOperand(i_nocapture: 0u),
             P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: Op), v2: m_Poison(), mask: m_Mask(Mask0)))) ||
      !match(V: Phi->getOperand(i_nocapture: 1u),
             P: m_OneUse(SubPattern: m_Shuffle(v1: m_Specific(V: Op), v2: m_Poison(), mask: m_Mask(Mask1)))))
    return false;

  auto *Shuf = cast<ShuffleVectorInst>(Val: Phi->getOperand(i_nocapture: 0u));

  // Ensure result vectors are wider than the argument vector.
  auto *InputVT = cast<FixedVectorType>(Val: Op->getType());
  auto *ResultVT = cast<FixedVectorType>(Val: Shuf->getType());
  auto const InputNumElements = InputVT->getNumElements();

  if (InputNumElements >= ResultVT->getNumElements())
    return false;

  // Take the difference of the two shuffle masks at each index. Ignore poison
  // values at the same index in both masks.
  SmallVector<int, 16> NewMask;
  NewMask.reserve(N: Mask0.size());

  for (auto [M0, M1] : zip(t&: Mask0, u&: Mask1)) {
    if (M0 >= 0 && M1 >= 0)
      NewMask.push_back(Elt: M0 - M1);
    else if (M0 == -1 && M1 == -1)
      continue;
    else
      // Poison on one side only: the masks cannot be offsets of each other.
      return false;
  }

  // Ensure all elements of the new mask are equal. If the difference between
  // the incoming mask elements is the same, the two must be constant offsets
  // of one another.
  if (NewMask.empty() || !all_equal(Range&: NewMask))
    return false;

  // Create new mask using difference of the two incoming masks.
  // The rotation amount is normalized into [0, InputNumElements) by adding
  // InputNumElements before the modulo, so negative offsets work too.
  int MaskOffset = NewMask[0u];
  unsigned Index = (InputNumElements + MaskOffset) % InputNumElements;
  NewMask.clear();

  // Build an input-width rotate-by-Index mask.
  for (unsigned I = 0u; I < InputNumElements; ++I) {
    NewMask.push_back(Elt: Index);
    Index = (Index + 1u) % InputNumElements;
  }

  // Calculate costs for worst cases and compare.
  // Old: the costlier of the two incoming wide shuffles (one executes per
  // trip). New: the narrow rotate plus the single wide shuffle after the phi.
  auto const Kind = TTI::SK_PermuteSingleSrc;
  auto OldCost =
      std::max(a: TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask0, CostKind),
               b: TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask1, CostKind));
  auto NewCost = TTI.getShuffleCost(Kind, DstTy: InputVT, SrcTy: InputVT, Mask: NewMask, CostKind) +
                 TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask1, CostKind);

  LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
                    << "\n  OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  // Create new shuffles and narrowed phi.
  // The rotate replaces the operand-0 shuffle; operand 1 feeds the phi raw.
  auto Builder = IRBuilder(Shuf);
  Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
  auto *PoisonVal = PoisonValue::get(T: InputVT);
  auto *NewShuf0 = Builder.CreateShuffleVector(V1: Op, V2: PoisonVal, Mask: NewMask);
  Worklist.push(I: cast<Instruction>(Val: NewShuf0));

  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
  auto *NewPhi = Builder.CreatePHI(Ty: NewShuf0->getType(), NumReservedValues: 2u);
  NewPhi->addIncoming(V: NewShuf0, BB: Phi->getIncomingBlock(i: 0u));
  NewPhi->addIncoming(V: Op, BB: Phi->getIncomingBlock(i: 1u));

  // Apply the remaining wide shuffle (Mask1) once, after the narrow phi.
  Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
  PoisonVal = PoisonValue::get(T: NewPhi->getType());
  auto *NewShuf1 = Builder.CreateShuffleVector(V1: NewPhi, V2: PoisonVal, Mask: Mask1);

  replaceValue(Old&: *Phi, New&: *NewShuf1);
  return true;
}
5751
/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
/// Returns true if any instruction in \p F was changed.
bool VectorCombine::run() {
  if (DisableVectorCombine)
    return false;

  // Don't attempt vectorization if the target does not support vectors.
  if (!TTI.getNumberOfRegisters(ClassID: TTI.getRegisterClassForType(/*Vector*/ true)))
    return false;

  LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");

  // Dispatch a single instruction to the fold routines; returns true as soon
  // as one fold fires (the folds themselves queue follow-up work on the
  // worklist). The order of the fold attempts below is deliberate.
  auto FoldInst = [this](Instruction &I) {
    Builder.SetInsertPoint(&I);
    bool IsVectorType = isa<VectorType>(Val: I.getType());
    bool IsFixedVectorType = isa<FixedVectorType>(Val: I.getType());
    auto Opcode = I.getOpcode();

    LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');

    // These folds should be beneficial regardless of when this pass is run
    // in the optimization pipeline.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        if (vectorizeLoadInsert(I))
          return true;
        break;
      case Instruction::ShuffleVector:
        if (widenSubvectorLoad(I))
          return true;
        break;
      default:
        break;
      }
    }

    // This transform works with scalable and fixed vectors
    // TODO: Identify and allow other scalable transforms
    if (IsVectorType) {
      if (scalarizeOpOrCmp(I))
        return true;
      if (scalarizeLoad(I))
        return true;
      if (scalarizeExtExtract(I))
        return true;
      if (scalarizeVPIntrinsic(I))
        return true;
      if (foldInterleaveIntrinsics(I))
        return true;
    }

    if (Opcode == Instruction::Store)
      if (foldSingleElementStore(I))
        return true;

    // If this is an early pipeline invocation of this pass, we are done.
    if (TryEarlyFoldsOnly)
      return false;

    // Otherwise, try folds that improve codegen but may interfere with
    // early IR canonicalizations.
    // The type checking is for run-time efficiency. We can avoid wasting time
    // dispatching to folding functions if there's no chance of matching.
    if (IsFixedVectorType) {
      switch (Opcode) {
      case Instruction::InsertElement:
        if (foldInsExtFNeg(I))
          return true;
        if (foldInsExtBinop(I))
          return true;
        if (foldInsExtVectorToShuffle(I))
          return true;
        break;
      case Instruction::ShuffleVector:
        if (foldPermuteOfBinops(I))
          return true;
        if (foldShuffleOfBinops(I))
          return true;
        if (foldShuffleOfSelects(I))
          return true;
        if (foldShuffleOfCastops(I))
          return true;
        if (foldShuffleOfShuffles(I))
          return true;
        if (foldPermuteOfIntrinsic(I))
          return true;
        if (foldShufflesOfLengthChangingShuffles(I))
          return true;
        if (foldShuffleOfIntrinsics(I))
          return true;
        if (foldSelectShuffle(I))
          return true;
        if (foldShuffleToIdentity(I))
          return true;
        break;
      case Instruction::Load:
        if (shrinkLoadForShuffles(I))
          return true;
        break;
      case Instruction::BitCast:
        if (foldBitcastShuffle(I))
          return true;
        if (foldSelectsFromBitcast(I))
          return true;
        break;
      case Instruction::And:
      case Instruction::Or:
      case Instruction::Xor:
        if (foldBitOpOfCastops(I))
          return true;
        if (foldBitOpOfCastConstant(I))
          return true;
        break;
      case Instruction::PHI:
        if (shrinkPhiOfShuffles(I))
          return true;
        break;
      default:
        if (shrinkType(I))
          return true;
        break;
      }
    } else {
      // Scalar-result instructions: folds that collapse vector computations
      // feeding reductions, extracts, and compares.
      switch (Opcode) {
      case Instruction::Call:
        if (foldShuffleFromReductions(I))
          return true;
        if (foldCastFromReductions(I))
          return true;
        break;
      case Instruction::ExtractElement:
        if (foldShuffleChainsToReduce(I))
          return true;
        break;
      case Instruction::ICmp:
        if (foldSignBitReductionCmp(I))
          return true;
        if (foldICmpEqZeroVectorReduce(I))
          return true;
        if (foldEquivalentReductionCmp(I))
          return true;
        [[fallthrough]];
      case Instruction::FCmp:
        if (foldExtractExtract(I))
          return true;
        break;
      case Instruction::Or:
        if (foldConcatOfBoolMasks(I))
          return true;
        [[fallthrough]];
      default:
        if (Instruction::isBinaryOp(Opcode)) {
          if (foldExtractExtract(I))
            return true;
          if (foldExtractedCmps(I))
            return true;
          if (foldBinopOfReductions(I))
            return true;
        }
        break;
      }
    }
    return false;
  };

  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(A: &BB))
      continue;
    // Use early increment range so that we can erase instructions in loop.
    // make_early_inc_range is not applicable here, as the next iterator may
    // be invalidated by RecursivelyDeleteTriviallyDeadInstructions.
    // We manually maintain the next instruction and update it when it is about
    // to be deleted.
    Instruction *I = &BB.front();
    while (I) {
      NextInst = I->getNextNode();
      if (!I->isDebugOrPseudoInst())
        MadeChange |= FoldInst(*I);
      I = NextInst;
    }
  }

  // The initial sweep is done; clear the saved cursor before draining the
  // worklist.
  NextInst = nullptr;

  // Revisit instructions queued by the folds above until no work remains.
  while (!Worklist.isEmpty()) {
    Instruction *I = Worklist.removeOne();
    if (!I)
      continue;

    // Remove dead instructions instead of re-folding them.
    if (isInstructionTriviallyDead(I)) {
      eraseInstruction(I&: *I);
      continue;
    }

    MadeChange |= FoldInst(*I);
  }

  return MadeChange;
}
5956
5957PreservedAnalyses VectorCombinePass::run(Function &F,
5958 FunctionAnalysisManager &FAM) {
5959 auto &AC = FAM.getResult<AssumptionAnalysis>(IR&: F);
5960 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
5961 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
5962 AAResults &AA = FAM.getResult<AAManager>(IR&: F);
5963 const DataLayout *DL = &F.getDataLayout();
5964 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
5965 TryEarlyFoldsOnly);
5966 if (!Combiner.run())
5967 return PreservedAnalyses::all();
5968 PreservedAnalyses PA;
5969 PA.preserveSet<CFGAnalyses>();
5970 return PA;
5971}
5972