1//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass optimizes scalar/vector interactions using target cost models. The
10// transforms implemented here may not fit in traditional loop-based or SLP
11// vectorization passes.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/Transforms/Vectorize/VectorCombine.h"
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/ScopeExit.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/Statistic.h"
21#include "llvm/Analysis/AssumptionCache.h"
22#include "llvm/Analysis/BasicAliasAnalysis.h"
23#include "llvm/Analysis/GlobalsModRef.h"
24#include "llvm/Analysis/InstSimplifyFolder.h"
25#include "llvm/Analysis/Loads.h"
26#include "llvm/Analysis/TargetFolder.h"
27#include "llvm/Analysis/TargetTransformInfo.h"
28#include "llvm/Analysis/ValueTracking.h"
29#include "llvm/Analysis/VectorUtils.h"
30#include "llvm/IR/Dominators.h"
31#include "llvm/IR/Function.h"
32#include "llvm/IR/IRBuilder.h"
33#include "llvm/IR/Instructions.h"
34#include "llvm/IR/PatternMatch.h"
35#include "llvm/Support/CommandLine.h"
36#include "llvm/Support/KnownBits.h"
37#include "llvm/Support/MathExtras.h"
38#include "llvm/Transforms/Utils/Local.h"
39#include "llvm/Transforms/Utils/LoopUtils.h"
40#include <numeric>
41#include <optional>
42#include <queue>
43#include <set>
44
45#define DEBUG_TYPE "vector-combine"
46#include "llvm/Transforms/Utils/InstructionWorklist.h"
47
48using namespace llvm;
49using namespace llvm::PatternMatch;
50
51STATISTIC(NumVecLoad, "Number of vector loads formed");
52STATISTIC(NumVecCmp, "Number of vector compares formed");
53STATISTIC(NumVecBO, "Number of vector binops formed");
54STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
55STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
56STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
57STATISTIC(NumScalarCmp, "Number of scalar compares formed");
58STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");
59
60static cl::opt<bool> DisableVectorCombine(
61 "disable-vector-combine", cl::init(Val: false), cl::Hidden,
62 cl::desc("Disable all vector combine transforms"));
63
64static cl::opt<bool> DisableBinopExtractShuffle(
65 "disable-binop-extract-shuffle", cl::init(Val: false), cl::Hidden,
66 cl::desc("Disable binop extract to shuffle transforms"));
67
68static cl::opt<unsigned> MaxInstrsToScan(
69 "vector-combine-max-scan-instrs", cl::init(Val: 30), cl::Hidden,
70 cl::desc("Max number of instructions to scan for vector combining."));
71
72static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();
73
74namespace {
75class VectorCombine {
76public:
77 VectorCombine(Function &F, const TargetTransformInfo &TTI,
78 const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
79 const DataLayout *DL, TTI::TargetCostKind CostKind,
80 bool TryEarlyFoldsOnly)
81 : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
82 DT(DT), AA(AA), DL(DL), CostKind(CostKind),
83 SQ(*DL, /*TLI=*/nullptr, &DT, &AC),
84 TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}
85
86 bool run();
87
88private:
89 Function &F;
90 IRBuilder<InstSimplifyFolder> Builder;
91 const TargetTransformInfo &TTI;
92 const DominatorTree &DT;
93 AAResults &AA;
94 const DataLayout *DL;
95 TTI::TargetCostKind CostKind;
96 const SimplifyQuery SQ;
97
98 /// If true, only perform beneficial early IR transforms. Do not introduce new
99 /// vector operations.
100 bool TryEarlyFoldsOnly;
101
102 InstructionWorklist Worklist;
103
104 /// Next instruction to iterate. It will be updated when it is erased by
105 /// RecursivelyDeleteTriviallyDeadInstructions.
106 Instruction *NextInst;
107
108 // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
109 // parameter. That should be updated to specific sub-classes because the
110 // run loop was changed to dispatch on opcode.
111 bool vectorizeLoadInsert(Instruction &I);
112 bool widenSubvectorLoad(Instruction &I);
113 ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
114 ExtractElementInst *Ext1,
115 unsigned PreferredExtractIndex) const;
116 bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
117 const Instruction &I,
118 ExtractElementInst *&ConvertToShuffle,
119 unsigned PreferredExtractIndex);
120 Value *foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
121 Value *foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex, Instruction &I);
122 bool foldExtractExtract(Instruction &I);
123 bool foldInsExtFNeg(Instruction &I);
124 bool foldInsExtBinop(Instruction &I);
125 bool foldInsExtVectorToShuffle(Instruction &I);
126 bool foldBitOpOfCastops(Instruction &I);
127 bool foldBitOpOfCastConstant(Instruction &I);
128 bool foldBitcastShuffle(Instruction &I);
129 bool scalarizeOpOrCmp(Instruction &I);
130 bool scalarizeVPIntrinsic(Instruction &I);
131 bool foldExtractedCmps(Instruction &I);
132 bool foldSelectsFromBitcast(Instruction &I);
133 bool foldBinopOfReductions(Instruction &I);
134 bool foldSingleElementStore(Instruction &I);
135 bool scalarizeLoad(Instruction &I);
136 bool scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy, Value *Ptr);
137 bool scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy, Value *Ptr);
138 bool scalarizeExtExtract(Instruction &I);
139 bool foldConcatOfBoolMasks(Instruction &I);
140 bool foldPermuteOfBinops(Instruction &I);
141 bool foldShuffleOfBinops(Instruction &I);
142 bool foldShuffleOfSelects(Instruction &I);
143 bool foldShuffleOfCastops(Instruction &I);
144 bool foldShuffleOfShuffles(Instruction &I);
145 bool foldPermuteOfIntrinsic(Instruction &I);
146 bool foldShufflesOfLengthChangingShuffles(Instruction &I);
147 bool foldShuffleOfIntrinsics(Instruction &I);
148 bool foldShuffleToIdentity(Instruction &I);
149 bool foldShuffleFromReductions(Instruction &I);
150 bool foldShuffleChainsToReduce(Instruction &I);
151 bool foldCastFromReductions(Instruction &I);
152 bool foldSignBitReductionCmp(Instruction &I);
153 bool foldReductionZeroTest(Instruction &I);
154 bool foldICmpEqZeroVectorReduce(Instruction &I);
155 bool foldEquivalentReductionCmp(Instruction &I);
156 bool foldReduceAddCmpZero(Instruction &I);
157 bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
158 bool foldInterleaveIntrinsics(Instruction &I);
159 bool foldDeinterleaveIntrinsics(Instruction &I);
160 bool foldBitcastOfVPLoad(Instruction &I);
161 bool foldBitOrderReverseAndSwap(Instruction &I);
162 bool shrinkType(Instruction &I);
163 bool shrinkLoadForShuffles(Instruction &I);
164 bool shrinkPhiOfShuffles(Instruction &I);
165
166 void replaceValue(Instruction &Old, Value &New, bool Erase = true) {
167 LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
168 LLVM_DEBUG(dbgs() << " With: " << New << '\n');
169 Old.replaceAllUsesWith(V: &New);
170 if (auto *NewI = dyn_cast<Instruction>(Val: &New)) {
171 New.takeName(V: &Old);
172 Worklist.pushUsersToWorkList(I&: *NewI);
173 Worklist.pushValue(V: NewI);
174 }
175 if (Erase && isInstructionTriviallyDead(I: &Old)) {
176 eraseInstruction(I&: Old);
177 } else {
178 Worklist.push(I: &Old);
179 }
180 }
181
182 void eraseInstruction(Instruction &I) {
183 LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
184 SmallVector<Value *> Ops(I.operands());
185 Worklist.remove(I: &I);
186 I.eraseFromParent();
187
188 // Push remaining users of the operands and then the operand itself - allows
189 // further folds that were hindered by OneUse limits.
190 SmallPtrSet<Value *, 4> Visited;
191 for (Value *Op : Ops) {
192 if (!Visited.contains(Ptr: Op)) {
193 if (auto *OpI = dyn_cast<Instruction>(Val: Op)) {
194 if (RecursivelyDeleteTriviallyDeadInstructions(
195 V: OpI, TLI: nullptr, MSSAU: nullptr, AboutToDeleteCallback: [&](Value *V) {
196 if (auto *I = dyn_cast<Instruction>(Val: V)) {
197 LLVM_DEBUG(dbgs() << "VC: Erased: " << *I << '\n');
198 Worklist.remove(I);
199 if (I == NextInst)
200 NextInst = NextInst->getNextNode();
201 Visited.insert(Ptr: I);
202 }
203 }))
204 continue;
205 Worklist.pushUsersToWorkList(I&: *OpI);
206 Worklist.pushValue(V: OpI);
207 }
208 }
209 }
210 }
211};
212} // namespace
213
214/// Return the source operand of a potentially bitcasted value. If there is no
215/// bitcast, return the input value itself.
216static Value *peekThroughBitcasts(Value *V) {
217 while (auto *BitCast = dyn_cast<BitCastInst>(Val: V))
218 V = BitCast->getOperand(i_nocapture: 0);
219 return V;
220}
221
222static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
223 // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
224 // The widened load may load data from dirty regions or create data races
225 // non-existent in the source.
226 if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
227 Load->getFunction()->hasFnAttribute(Kind: Attribute::SanitizeMemTag) ||
228 mustSuppressSpeculation(LI: *Load))
229 return false;
230
231 // We are potentially transforming byte-sized (8-bit) memory accesses, so make
232 // sure we have all of our type-based constraints in place for this target.
233 Type *ScalarTy = Load->getType()->getScalarType();
234 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
235 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
236 if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
237 ScalarSize % 8 != 0)
238 return false;
239
240 return true;
241}
242
243bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
244 // Match insert into fixed vector of scalar value.
245 // TODO: Handle non-zero insert index.
246 Value *Scalar;
247 if (!match(V: &I,
248 P: m_InsertElt(Val: m_Poison(), Elt: m_OneUse(SubPattern: m_Value(V&: Scalar)), Idx: m_ZeroInt())))
249 return false;
250
251 // Optionally match an extract from another vector.
252 Value *X;
253 bool HasExtract = match(V: Scalar, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ZeroInt()));
254 if (!HasExtract)
255 X = Scalar;
256
257 auto *Load = dyn_cast<LoadInst>(Val: X);
258 if (!canWidenLoad(Load, TTI))
259 return false;
260
261 Type *ScalarTy = Scalar->getType();
262 uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
263 unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
264
265 // Check safety of replacing the scalar load with a larger vector load.
266 // We use minimal alignment (maximum flexibility) because we only care about
267 // the dereferenceable region. When calculating cost and creating a new op,
268 // we may use a larger value based on alignment attributes.
269 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
270 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
271
272 unsigned MinVecNumElts = MinVectorSize / ScalarSize;
273 auto *MinVecTy = VectorType::get(ElementType: ScalarTy, NumElements: MinVecNumElts, Scalable: false);
274 unsigned OffsetEltIndex = 0;
275 Align Alignment = Load->getAlign();
276 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty: MinVecTy, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: SQ.AC,
277 DT: SQ.DT)) {
278 // It is not safe to load directly from the pointer, but we can still peek
279 // through gep offsets and check if it safe to load from a base address with
280 // updated alignment. If it is, we can shuffle the element(s) into place
281 // after loading.
282 unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(Ty: SrcPtr->getType());
283 APInt Offset(OffsetBitWidth, 0);
284 SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(DL: *DL, Offset);
285
286 // We want to shuffle the result down from a high element of a vector, so
287 // the offset must be positive.
288 if (Offset.isNegative())
289 return false;
290
291 // The offset must be a multiple of the scalar element to shuffle cleanly
292 // in the element's size.
293 uint64_t ScalarSizeInBytes = ScalarSize / 8;
294 if (Offset.urem(RHS: ScalarSizeInBytes) != 0)
295 return false;
296
297 // If we load MinVecNumElts, will our target element still be loaded?
298 OffsetEltIndex = Offset.udiv(RHS: ScalarSizeInBytes).getZExtValue();
299 if (OffsetEltIndex >= MinVecNumElts)
300 return false;
301
302 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty: MinVecTy, Alignment: Align(1), DL: *DL, ScanFrom: Load,
303 AC: SQ.AC, DT: SQ.DT))
304 return false;
305
306 // Update alignment with offset value. Note that the offset could be negated
307 // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
308 // negation does not change the result of the alignment calculation.
309 Alignment = commonAlignment(A: Alignment, Offset: Offset.getZExtValue());
310 }
311
312 // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
313 // Use the greater of the alignment on the load or its source pointer.
314 Alignment = std::max(a: SrcPtr->getPointerAlignment(DL: *DL), b: Alignment);
315 Type *LoadTy = Load->getType();
316 unsigned AS = Load->getPointerAddressSpace();
317 InstructionCost OldCost =
318 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: LoadTy, Alignment, AddressSpace: AS, CostKind);
319 APInt DemandedElts = APInt::getOneBitSet(numBits: MinVecNumElts, BitNo: 0);
320 OldCost +=
321 TTI.getScalarizationOverhead(Ty: MinVecTy, DemandedElts,
322 /* Insert */ true, Extract: HasExtract, CostKind);
323
324 // New pattern: load VecPtr
325 InstructionCost NewCost =
326 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: MinVecTy, Alignment, AddressSpace: AS, CostKind);
327 // Optionally, we are shuffling the loaded vector element(s) into place.
328 // For the mask set everything but element 0 to undef to prevent poison from
329 // propagating from the extra loaded memory. This will also optionally
330 // shrink/grow the vector from the loaded size to the output size.
331 // We assume this operation has no cost in codegen if there was no offset.
332 // Note that we could use freeze to avoid poison problems, but then we might
333 // still need a shuffle to change the vector size.
334 auto *Ty = cast<FixedVectorType>(Val: I.getType());
335 unsigned OutputNumElts = Ty->getNumElements();
336 SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
337 assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
338 Mask[0] = OffsetEltIndex;
339 if (OffsetEltIndex)
340 NewCost += TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Ty, SrcTy: MinVecTy, Mask,
341 CostKind);
342
343 // We can aggressively convert to the vector form because the backend can
344 // invert this transform if it does not result in a performance win.
345 if (OldCost < NewCost || !NewCost.isValid())
346 return false;
347
348 // It is safe and potentially profitable to load a vector directly:
349 // inselt undef, load Scalar, 0 --> load VecPtr
350 IRBuilder<> Builder(Load);
351 Value *CastedPtr =
352 Builder.CreatePointerBitCastOrAddrSpaceCast(V: SrcPtr, DestTy: Builder.getPtrTy(AddrSpace: AS));
353 Value *VecLd = Builder.CreateAlignedLoad(Ty: MinVecTy, Ptr: CastedPtr, Align: Alignment);
354 VecLd = Builder.CreateShuffleVector(V: VecLd, Mask);
355
356 replaceValue(Old&: I, New&: *VecLd);
357 ++NumVecLoad;
358 return true;
359}
360
361/// If we are loading a vector and then inserting it into a larger vector with
362/// undefined elements, try to load the larger vector and eliminate the insert.
363/// This removes a shuffle in IR and may allow combining of other loaded values.
364bool VectorCombine::widenSubvectorLoad(Instruction &I) {
365 // Match subvector insert of fixed vector.
366 auto *Shuf = cast<ShuffleVectorInst>(Val: &I);
367 if (!Shuf->isIdentityWithPadding())
368 return false;
369
370 // Allow a non-canonical shuffle mask that is choosing elements from op1.
371 unsigned NumOpElts =
372 cast<FixedVectorType>(Val: Shuf->getOperand(i_nocapture: 0)->getType())->getNumElements();
373 unsigned OpIndex = any_of(Range: Shuf->getShuffleMask(), P: [&NumOpElts](int M) {
374 return M >= (int)(NumOpElts);
375 });
376
377 auto *Load = dyn_cast<LoadInst>(Val: Shuf->getOperand(i_nocapture: OpIndex));
378 if (!canWidenLoad(Load, TTI))
379 return false;
380
381 // We use minimal alignment (maximum flexibility) because we only care about
382 // the dereferenceable region. When calculating cost and creating a new op,
383 // we may use a larger value based on alignment attributes.
384 auto *Ty = cast<FixedVectorType>(Val: I.getType());
385 Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
386 assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
387 Align Alignment = Load->getAlign();
388 if (!isSafeToLoadUnconditionally(V: SrcPtr, Ty, Alignment: Align(1), DL: *DL, ScanFrom: Load, AC: SQ.AC,
389 DT: SQ.DT))
390 return false;
391
392 Alignment = std::max(a: SrcPtr->getPointerAlignment(DL: *DL), b: Alignment);
393 Type *LoadTy = Load->getType();
394 unsigned AS = Load->getPointerAddressSpace();
395
396 // Original pattern: insert_subvector (load PtrOp)
397 // This conservatively assumes that the cost of a subvector insert into an
398 // undef value is 0. We could add that cost if the cost model accurately
399 // reflects the real cost of that operation.
400 InstructionCost OldCost =
401 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: LoadTy, Alignment, AddressSpace: AS, CostKind);
402
403 // New pattern: load PtrOp
404 InstructionCost NewCost =
405 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: Ty, Alignment, AddressSpace: AS, CostKind);
406
407 // We can aggressively convert to the vector form because the backend can
408 // invert this transform if it does not result in a performance win.
409 if (OldCost < NewCost || !NewCost.isValid())
410 return false;
411
412 IRBuilder<> Builder(Load);
413 Value *CastedPtr =
414 Builder.CreatePointerBitCastOrAddrSpaceCast(V: SrcPtr, DestTy: Builder.getPtrTy(AddrSpace: AS));
415 Value *VecLd = Builder.CreateAlignedLoad(Ty, Ptr: CastedPtr, Align: Alignment);
416 replaceValue(Old&: I, New&: *VecLd);
417 ++NumVecLoad;
418 return true;
419}
420
421/// Determine which, if any, of the inputs should be replaced by a shuffle
422/// followed by extract from a different index.
423ExtractElementInst *VectorCombine::getShuffleExtract(
424 ExtractElementInst *Ext0, ExtractElementInst *Ext1,
425 unsigned PreferredExtractIndex = InvalidIndex) const {
426 auto *Index0C = dyn_cast<ConstantInt>(Val: Ext0->getIndexOperand());
427 auto *Index1C = dyn_cast<ConstantInt>(Val: Ext1->getIndexOperand());
428 assert(Index0C && Index1C && "Expected constant extract indexes");
429
430 unsigned Index0 = Index0C->getZExtValue();
431 unsigned Index1 = Index1C->getZExtValue();
432
433 // If the extract indexes are identical, no shuffle is needed.
434 if (Index0 == Index1)
435 return nullptr;
436
437 Type *VecTy = Ext0->getVectorOperand()->getType();
438 assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
439 InstructionCost Cost0 =
440 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
441 InstructionCost Cost1 =
442 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
443
444 // If both costs are invalid no shuffle is needed
445 if (!Cost0.isValid() && !Cost1.isValid())
446 return nullptr;
447
448 // We are extracting from 2 different indexes, so one operand must be shuffled
449 // before performing a vector operation and/or extract. The more expensive
450 // extract will be replaced by a shuffle.
451 if (Cost0 > Cost1)
452 return Ext0;
453 if (Cost1 > Cost0)
454 return Ext1;
455
456 // If the costs are equal and there is a preferred extract index, shuffle the
457 // opposite operand.
458 if (PreferredExtractIndex == Index0)
459 return Ext1;
460 if (PreferredExtractIndex == Index1)
461 return Ext0;
462
463 // Otherwise, replace the extract with the higher index.
464 return Index0 > Index1 ? Ext0 : Ext1;
465}
466
467/// Compare the relative costs of 2 extracts followed by scalar operation vs.
468/// vector operation(s) followed by extract. Return true if the existing
469/// instructions are cheaper than a vector alternative. Otherwise, return false
470/// and if one of the extracts should be transformed to a shufflevector, set
471/// \p ConvertToShuffle to that extract instruction.
472bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
473 ExtractElementInst *Ext1,
474 const Instruction &I,
475 ExtractElementInst *&ConvertToShuffle,
476 unsigned PreferredExtractIndex) {
477 auto *Ext0IndexC = dyn_cast<ConstantInt>(Val: Ext0->getIndexOperand());
478 auto *Ext1IndexC = dyn_cast<ConstantInt>(Val: Ext1->getIndexOperand());
479 assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");
480
481 unsigned Opcode = I.getOpcode();
482 Value *Ext0Src = Ext0->getVectorOperand();
483 Value *Ext1Src = Ext1->getVectorOperand();
484 Type *ScalarTy = Ext0->getType();
485 auto *VecTy = cast<VectorType>(Val: Ext0Src->getType());
486 InstructionCost ScalarOpCost, VectorOpCost;
487
488 // Get cost estimates for scalar and vector versions of the operation.
489 bool IsBinOp = Instruction::isBinaryOp(Opcode);
490 if (IsBinOp) {
491 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
492 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
493 } else {
494 assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
495 "Expected a compare");
496 CmpInst::Predicate Pred = cast<CmpInst>(Val: I).getPredicate();
497 ScalarOpCost = TTI.getCmpSelInstrCost(
498 Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
499 VectorOpCost = TTI.getCmpSelInstrCost(
500 Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
501 }
502
503 // Get cost estimates for the extract elements. These costs will factor into
504 // both sequences.
505 unsigned Ext0Index = Ext0IndexC->getZExtValue();
506 unsigned Ext1Index = Ext1IndexC->getZExtValue();
507
508 InstructionCost Extract0Cost =
509 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Ext0Index);
510 InstructionCost Extract1Cost =
511 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Ext1Index);
512
513 // A more expensive extract will always be replaced by a splat shuffle.
514 // For example, if Ext0 is more expensive:
515 // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
516 // extelt (opcode (splat V0, Ext0), V1), Ext1
517 // TODO: Evaluate whether that always results in lowest cost. Alternatively,
518 // check the cost of creating a broadcast shuffle and shuffling both
519 // operands to element 0.
520 unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
521 unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
522 InstructionCost CheapExtractCost = std::min(a: Extract0Cost, b: Extract1Cost);
523
524 // Extra uses of the extracts mean that we include those costs in the
525 // vector total because those instructions will not be eliminated.
526 InstructionCost OldCost, NewCost;
527 if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
528 // Handle a special case. If the 2 extracts are identical, adjust the
529 // formulas to account for that. The extra use charge allows for either the
530 // CSE'd pattern or an unoptimized form with identical values:
531 // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
532 bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(N: 2)
533 : !Ext0->hasOneUse() || !Ext1->hasOneUse();
534 OldCost = CheapExtractCost + ScalarOpCost;
535 NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
536 } else {
537 // Handle the general case. Each extract is actually a different value:
538 // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
539 OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
540 NewCost = VectorOpCost + CheapExtractCost +
541 !Ext0->hasOneUse() * Extract0Cost +
542 !Ext1->hasOneUse() * Extract1Cost;
543 }
544
545 ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
546 if (ConvertToShuffle) {
547 if (IsBinOp && DisableBinopExtractShuffle)
548 return true;
549
550 // If we are extracting from 2 different indexes, then one operand must be
551 // shuffled before performing the vector operation. The shuffle mask is
552 // poison except for 1 lane that is being translated to the remaining
553 // extraction lane. Therefore, it is a splat shuffle. Ex:
554 // ShufMask = { poison, poison, 0, poison }
555 // TODO: The cost model has an option for a "broadcast" shuffle
556 // (splat-from-element-0), but no option for a more general splat.
557 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(Val: VecTy)) {
558 SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
559 PoisonMaskElem);
560 ShuffleMask[BestInsIndex] = BestExtIndex;
561 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
562 DstTy: VecTy, SrcTy: VecTy, Mask: ShuffleMask, CostKind, Index: 0,
563 SubTp: nullptr, Args: {ConvertToShuffle});
564 } else {
565 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
566 DstTy: VecTy, SrcTy: VecTy, Mask: {}, CostKind, Index: 0, SubTp: nullptr,
567 Args: {ConvertToShuffle});
568 }
569 }
570
571 // Aggressively form a vector op if the cost is equal because the transform
572 // may enable further optimization.
573 // Codegen can reverse this transform (scalarize) if it was not profitable.
574 return OldCost < NewCost;
575}
576
577/// Create a shuffle that translates (shifts) 1 element from the input vector
578/// to a new element location.
579static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
580 unsigned NewIndex, IRBuilderBase &Builder) {
581 // The shuffle mask is poison except for 1 lane that is being translated
582 // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
583 // ShufMask = { 2, poison, poison, poison }
584 auto *VecTy = cast<FixedVectorType>(Val: Vec->getType());
585 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
586 ShufMask[NewIndex] = OldIndex;
587 return Builder.CreateShuffleVector(V: Vec, Mask: ShufMask, Name: "shift");
588}
589
590/// Given an extract element instruction with constant index operand, shuffle
591/// the source vector (shift the scalar element) to a NewIndex for extraction.
592/// Return null if the input can be constant folded, so that we are not creating
593/// unnecessary instructions.
594static Value *translateExtract(ExtractElementInst *ExtElt, unsigned NewIndex,
595 IRBuilderBase &Builder) {
596 // Shufflevectors can only be created for fixed-width vectors.
597 Value *X = ExtElt->getVectorOperand();
598 if (!isa<FixedVectorType>(Val: X->getType()))
599 return nullptr;
600
601 // If the extract can be constant-folded, this code is unsimplified. Defer
602 // to other passes to handle that.
603 Value *C = ExtElt->getIndexOperand();
604 assert(isa<ConstantInt>(C) && "Expected a constant index operand");
605 if (isa<Constant>(Val: X))
606 return nullptr;
607
608 Value *Shuf = createShiftShuffle(Vec: X, OldIndex: cast<ConstantInt>(Val: C)->getZExtValue(),
609 NewIndex, Builder);
610 return Shuf;
611}
612
613/// Try to reduce extract element costs by converting scalar compares to vector
614/// compares followed by extract.
615/// cmp (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
616Value *VectorCombine::foldExtExtCmp(Value *V0, Value *V1, Value *ExtIndex,
617 Instruction &I) {
618 assert(isa<CmpInst>(&I) && "Expected a compare");
619
620 // cmp Pred (extelt V0, ExtIndex), (extelt V1, ExtIndex)
621 // --> extelt (cmp Pred V0, V1), ExtIndex
622 ++NumVecCmp;
623 CmpInst::Predicate Pred = cast<CmpInst>(Val: &I)->getPredicate();
624 Value *VecCmp = Builder.CreateCmp(Pred, LHS: V0, RHS: V1);
625 return Builder.CreateExtractElement(Vec: VecCmp, Idx: ExtIndex, Name: "foldExtExtCmp");
626}
627
628/// Try to reduce extract element costs by converting scalar binops to vector
629/// binops followed by extract.
630/// bo (ext0 V0, ExtIndex), (ext1 V1, ExtIndex)
631Value *VectorCombine::foldExtExtBinop(Value *V0, Value *V1, Value *ExtIndex,
632 Instruction &I) {
633 assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
634
635 // bo (extelt V0, ExtIndex), (extelt V1, ExtIndex)
636 // --> extelt (bo V0, V1), ExtIndex
637 ++NumVecBO;
638 Value *VecBO = Builder.CreateBinOp(Opc: cast<BinaryOperator>(Val: &I)->getOpcode(), LHS: V0,
639 RHS: V1, Name: "foldExtExtBinop");
640
641 // All IR flags are safe to back-propagate because any potential poison
642 // created in unused vector elements is discarded by the extract.
643 if (auto *VecBOInst = dyn_cast<Instruction>(Val: VecBO))
644 VecBOInst->copyIRFlags(V: &I);
645
646 return Builder.CreateExtractElement(Vec: VecBO, Idx: ExtIndex, Name: "foldExtExtBinop");
647}
648
649/// Match an instruction with extracted vector operands.
650bool VectorCombine::foldExtractExtract(Instruction &I) {
651 // It is not safe to transform things like div, urem, etc. because we may
652 // create undefined behavior when executing those on unknown vector elements.
653 if (!isSafeToSpeculativelyExecute(I: &I))
654 return false;
655
656 Instruction *I0, *I1;
657 CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
658 if (!match(V: &I, P: m_Cmp(Pred, L: m_Instruction(I&: I0), R: m_Instruction(I&: I1))) &&
659 !match(V: &I, P: m_BinOp(L: m_Instruction(I&: I0), R: m_Instruction(I&: I1))))
660 return false;
661
662 Value *V0, *V1;
663 uint64_t C0, C1;
664 if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: V0), Idx: m_ConstantInt(V&: C0))) ||
665 !match(V: I1, P: m_ExtractElt(Val: m_Value(V&: V1), Idx: m_ConstantInt(V&: C1))) ||
666 V0->getType() != V1->getType())
667 return false;
668
669 // For fixed-width vectors, reject out-of-bounds extract indexes
670 if (auto *FixedVecTy = dyn_cast<FixedVectorType>(Val: V0->getType())) {
671 unsigned NumElts = FixedVecTy->getNumElements();
672 if (C0 >= NumElts || C1 >= NumElts)
673 return false;
674 }
675
676 // If the scalar value 'I' is going to be re-inserted into a vector, then try
677 // to create an extract to that same element. The extract/insert can be
678 // reduced to a "select shuffle".
679 // TODO: If we add a larger pattern match that starts from an insert, this
680 // probably becomes unnecessary.
681 auto *Ext0 = cast<ExtractElementInst>(Val: I0);
682 auto *Ext1 = cast<ExtractElementInst>(Val: I1);
683 uint64_t InsertIndex = InvalidIndex;
684 if (I.hasOneUse())
685 match(V: I.user_back(),
686 P: m_InsertElt(Val: m_Value(), Elt: m_Value(), Idx: m_ConstantInt(V&: InsertIndex)));
687
688 ExtractElementInst *ExtractToChange;
689 if (isExtractExtractCheap(Ext0, Ext1, I, ConvertToShuffle&: ExtractToChange, PreferredExtractIndex: InsertIndex))
690 return false;
691
692 Value *ExtOp0 = Ext0->getVectorOperand();
693 Value *ExtOp1 = Ext1->getVectorOperand();
694
695 if (ExtractToChange) {
696 unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
697 Value *NewExtOp =
698 translateExtract(ExtElt: ExtractToChange, NewIndex: CheapExtractIdx, Builder);
699 if (!NewExtOp)
700 return false;
701 if (ExtractToChange == Ext0)
702 ExtOp0 = NewExtOp;
703 else
704 ExtOp1 = NewExtOp;
705 }
706
707 Value *ExtIndex = ExtractToChange == Ext0 ? Ext1->getIndexOperand()
708 : Ext0->getIndexOperand();
709 Value *NewExt = Pred != CmpInst::BAD_ICMP_PREDICATE
710 ? foldExtExtCmp(V0: ExtOp0, V1: ExtOp1, ExtIndex, I)
711 : foldExtExtBinop(V0: ExtOp0, V1: ExtOp1, ExtIndex, I);
712 Worklist.push(I: Ext0);
713 Worklist.push(I: Ext1);
714 replaceValue(Old&: I, New&: *NewExt);
715 return true;
716}
717
718/// Try to replace an extract + scalar fneg + insert with a vector fneg +
719/// shuffle.
720bool VectorCombine::foldInsExtFNeg(Instruction &I) {
721 // Match an insert (op (extract)) pattern.
722 Value *DstVec;
723 uint64_t ExtIdx, InsIdx;
724 Instruction *FNeg;
725 if (!match(V: &I, P: m_InsertElt(Val: m_Value(V&: DstVec), Elt: m_OneUse(SubPattern: m_Instruction(I&: FNeg)),
726 Idx: m_ConstantInt(V&: InsIdx))))
727 return false;
728
729 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
730 Value *SrcVec;
731 Instruction *Extract;
732 if (!match(V: FNeg, P: m_FNeg(X: m_CombineAnd(
733 Ps: m_Instruction(I&: Extract),
734 Ps: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx))))))
735 return false;
736
737 auto *DstVecTy = cast<FixedVectorType>(Val: DstVec->getType());
738 auto *DstVecScalarTy = DstVecTy->getScalarType();
739 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
740 if (!SrcVecTy || DstVecScalarTy != SrcVecTy->getScalarType())
741 return false;
742
743 // Ignore if insert/extract index is out of bounds or destination vector has
744 // one element
745 unsigned NumDstElts = DstVecTy->getNumElements();
746 unsigned NumSrcElts = SrcVecTy->getNumElements();
747 if (ExtIdx > NumSrcElts || InsIdx >= NumDstElts || NumDstElts == 1)
748 return false;
749
750 // We are inserting the negated element into the same lane that we extracted
751 // from. This is equivalent to a select-shuffle that chooses all but the
752 // negated element from the destination vector.
753 SmallVector<int> Mask(NumDstElts);
754 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
755 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
756 InstructionCost OldCost =
757 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: DstVecScalarTy, CostKind) +
758 TTI.getVectorInstrCost(I, Val: DstVecTy, CostKind, Index: InsIdx);
759
760 // If the extract has one use, it will be eliminated, so count it in the
761 // original cost. If it has more than one use, ignore the cost because it will
762 // be the same before/after.
763 if (Extract->hasOneUse())
764 OldCost += TTI.getVectorInstrCost(I: *Extract, Val: SrcVecTy, CostKind, Index: ExtIdx);
765
766 InstructionCost NewCost =
767 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: SrcVecTy, CostKind) +
768 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: DstVecTy,
769 SrcTy: DstVecTy, Mask, CostKind);
770
771 bool NeedLenChg = SrcVecTy->getNumElements() != NumDstElts;
772 // If the lengths of the two vectors are not equal,
773 // we need to add a length-change vector. Add this cost.
774 SmallVector<int> SrcMask;
775 if (NeedLenChg) {
776 SrcMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
777 SrcMask[ExtIdx % NumDstElts] = ExtIdx;
778 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
779 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: SrcMask, CostKind);
780 }
781
782 LLVM_DEBUG(dbgs() << "Found an insertion of (extract)fneg : " << I
783 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
784 << "\n");
785 if (NewCost > OldCost)
786 return false;
787
788 Value *NewShuf, *LenChgShuf = nullptr;
789 // insertelt DstVec, (fneg (extractelt SrcVec, Index)), Index
790 Value *VecFNeg = Builder.CreateFNegFMF(V: SrcVec, FMFSource: FNeg);
791 if (NeedLenChg) {
792 // shuffle DstVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
793 LenChgShuf = Builder.CreateShuffleVector(V: VecFNeg, Mask: SrcMask);
794 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: LenChgShuf, Mask);
795 Worklist.pushValue(V: LenChgShuf);
796 } else {
797 // shuffle DstVec, (fneg SrcVec), Mask
798 NewShuf = Builder.CreateShuffleVector(V1: DstVec, V2: VecFNeg, Mask);
799 }
800
801 Worklist.pushValue(V: VecFNeg);
802 replaceValue(Old&: I, New&: *NewShuf);
803 return true;
804}
805
806/// Try to fold insert(binop(x,y),binop(a,b),idx)
807/// --> binop(insert(x,a,idx),insert(y,b,idx))
808bool VectorCombine::foldInsExtBinop(Instruction &I) {
809 BinaryOperator *VecBinOp, *SclBinOp;
810 uint64_t Index;
811 if (!match(V: &I,
812 P: m_InsertElt(Val: m_OneUse(SubPattern: m_BinOp(I&: VecBinOp)),
813 Elt: m_OneUse(SubPattern: m_BinOp(I&: SclBinOp)), Idx: m_ConstantInt(V&: Index))))
814 return false;
815
816 // TODO: Add support for addlike etc.
817 Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
818 if (BinOpcode != SclBinOp->getOpcode())
819 return false;
820
821 auto *ResultTy = dyn_cast<FixedVectorType>(Val: I.getType());
822 if (!ResultTy)
823 return false;
824
825 // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
826 // shuffle?
827
828 InstructionCost OldCost = TTI.getInstructionCost(U: &I, CostKind) +
829 TTI.getInstructionCost(U: VecBinOp, CostKind) +
830 TTI.getInstructionCost(U: SclBinOp, CostKind);
831 InstructionCost NewCost =
832 TTI.getArithmeticInstrCost(Opcode: BinOpcode, Ty: ResultTy, CostKind) +
833 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
834 Index, Op0: VecBinOp->getOperand(i_nocapture: 0),
835 Op1: SclBinOp->getOperand(i_nocapture: 0)) +
836 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
837 Index, Op0: VecBinOp->getOperand(i_nocapture: 1),
838 Op1: SclBinOp->getOperand(i_nocapture: 1));
839
840 LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
841 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
842 << "\n");
843 if (NewCost > OldCost)
844 return false;
845
846 Value *NewIns0 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 0),
847 NewElt: SclBinOp->getOperand(i_nocapture: 0), Idx: Index);
848 Value *NewIns1 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 1),
849 NewElt: SclBinOp->getOperand(i_nocapture: 1), Idx: Index);
850 Value *NewBO = Builder.CreateBinOp(Opc: BinOpcode, LHS: NewIns0, RHS: NewIns1);
851
852 // Intersect flags from the old binops.
853 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
854 NewInst->copyIRFlags(V: VecBinOp);
855 NewInst->andIRFlags(V: SclBinOp);
856 }
857
858 Worklist.pushValue(V: NewIns0);
859 Worklist.pushValue(V: NewIns1);
860 replaceValue(Old&: I, New&: *NewBO);
861 return true;
862}
863
864/// Match: bitop(castop(x), castop(y)) -> castop(bitop(x, y))
865/// Supports: bitcast, trunc, sext, zext
866bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
867 // Check if this is a bitwise logic operation
868 auto *BinOp = dyn_cast<BinaryOperator>(Val: &I);
869 if (!BinOp || !BinOp->isBitwiseLogicOp())
870 return false;
871
872 // Get the cast instructions
873 auto *LHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 0));
874 auto *RHSCast = dyn_cast<CastInst>(Val: BinOp->getOperand(i_nocapture: 1));
875 if (!LHSCast || !RHSCast) {
876 LLVM_DEBUG(dbgs() << " One or both operands are not cast instructions\n");
877 return false;
878 }
879
880 // Both casts must be the same type
881 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
882 if (CastOpcode != RHSCast->getOpcode())
883 return false;
884
885 // Only handle supported cast operations
886 switch (CastOpcode) {
887 case Instruction::BitCast:
888 case Instruction::Trunc:
889 case Instruction::SExt:
890 case Instruction::ZExt:
891 break;
892 default:
893 return false;
894 }
895
896 Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);
897 Value *RHSSrc = RHSCast->getOperand(i_nocapture: 0);
898
899 // Source types must match
900 if (LHSSrc->getType() != RHSSrc->getType())
901 return false;
902
903 auto *SrcTy = LHSSrc->getType();
904 auto *DstTy = I.getType();
905 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
906 // Other casts only handle vector types with integer elements.
907 if (CastOpcode != Instruction::BitCast &&
908 (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
909 return false;
910
911 // Only integer scalar/vector values are legal for bitwise logic operations.
912 if (!SrcTy->getScalarType()->isIntegerTy() ||
913 !DstTy->getScalarType()->isIntegerTy())
914 return false;
915
916 // Cost Check :
917 // OldCost = bitlogic + 2*casts
918 // NewCost = bitlogic + cast
919
920 // Calculate specific costs for each cast with instruction context
921 InstructionCost LHSCastCost = TTI.getCastInstrCost(
922 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);
923 InstructionCost RHSCastCost = TTI.getCastInstrCost(
924 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: RHSCast);
925
926 InstructionCost OldCost =
927 TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: DstTy, CostKind) +
928 LHSCastCost + RHSCastCost;
929
930 // For new cost, we can't provide an instruction (it doesn't exist yet)
931 InstructionCost GenericCastCost = TTI.getCastInstrCost(
932 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);
933
934 InstructionCost NewCost =
935 TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: SrcTy, CostKind) +
936 GenericCastCost;
937
938 // Account for multi-use casts using specific costs
939 if (!LHSCast->hasOneUse())
940 NewCost += LHSCastCost;
941 if (!RHSCast->hasOneUse())
942 NewCost += RHSCastCost;
943
944 LLVM_DEBUG(dbgs() << "foldBitOpOfCastops: OldCost=" << OldCost
945 << " NewCost=" << NewCost << "\n");
946
947 if (NewCost > OldCost)
948 return false;
949
950 // Create the operation on the source type
951 Value *NewOp = Builder.CreateBinOp(Opc: BinOp->getOpcode(), LHS: LHSSrc, RHS: RHSSrc,
952 Name: BinOp->getName() + ".inner");
953 if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
954 NewBinOp->copyIRFlags(V: BinOp);
955
956 Worklist.pushValue(V: NewOp);
957
958 // Create the cast operation directly to ensure we get a new instruction
959 Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());
960
961 // Preserve cast instruction flags
962 NewCast->copyIRFlags(V: LHSCast);
963 NewCast->andIRFlags(V: RHSCast);
964
965 // Insert the new instruction
966 Value *Result = Builder.Insert(I: NewCast);
967
968 replaceValue(Old&: I, New&: *Result);
969 return true;
970}
971
972/// Match:
973// bitop(castop(x), C) ->
974// bitop(castop(x), castop(InvC)) ->
975// castop(bitop(x, InvC))
976// Supports: bitcast
977bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
978 Instruction *LHS;
979 Constant *C;
980
981 // Check if this is a bitwise logic operation
982 if (!match(V: &I, P: m_c_BitwiseLogic(L: m_Instruction(I&: LHS), R: m_Constant(C))))
983 return false;
984
985 // Get the cast instructions
986 auto *LHSCast = dyn_cast<CastInst>(Val: LHS);
987 if (!LHSCast)
988 return false;
989
990 Instruction::CastOps CastOpcode = LHSCast->getOpcode();
991
992 // Only handle supported cast operations
993 switch (CastOpcode) {
994 case Instruction::BitCast:
995 case Instruction::ZExt:
996 case Instruction::SExt:
997 case Instruction::Trunc:
998 break;
999 default:
1000 return false;
1001 }
1002
1003 Value *LHSSrc = LHSCast->getOperand(i_nocapture: 0);
1004
1005 auto *SrcTy = LHSSrc->getType();
1006 auto *DstTy = I.getType();
1007 // Bitcasts can handle scalar/vector mixes, such as i16 -> <16 x i1>.
1008 // Other casts only handle vector types with integer elements.
1009 if (CastOpcode != Instruction::BitCast &&
1010 (!isa<FixedVectorType>(Val: SrcTy) || !isa<FixedVectorType>(Val: DstTy)))
1011 return false;
1012
1013 // Only integer scalar/vector values are legal for bitwise logic operations.
1014 if (!SrcTy->getScalarType()->isIntegerTy() ||
1015 !DstTy->getScalarType()->isIntegerTy())
1016 return false;
1017
1018 // Find the constant InvC, such that castop(InvC) equals to C.
1019 PreservedCastFlags RHSFlags;
1020 Constant *InvC = getLosslessInvCast(C, InvCastTo: SrcTy, CastOp: CastOpcode, DL: *DL, Flags: &RHSFlags);
1021 if (!InvC)
1022 return false;
1023
1024 // Cost Check :
1025 // OldCost = bitlogic + cast
1026 // NewCost = bitlogic + cast
1027
1028 // Calculate specific costs for each cast with instruction context
1029 InstructionCost LHSCastCost = TTI.getCastInstrCost(
1030 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind, I: LHSCast);
1031
1032 InstructionCost OldCost =
1033 TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: DstTy, CostKind) + LHSCastCost;
1034
1035 // For new cost, we can't provide an instruction (it doesn't exist yet)
1036 InstructionCost GenericCastCost = TTI.getCastInstrCost(
1037 Opcode: CastOpcode, Dst: DstTy, Src: SrcTy, CCH: TTI::CastContextHint::None, CostKind);
1038
1039 InstructionCost NewCost =
1040 TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: SrcTy, CostKind) +
1041 GenericCastCost;
1042
1043 // Account for multi-use casts using specific costs
1044 if (!LHSCast->hasOneUse())
1045 NewCost += LHSCastCost;
1046
1047 LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
1048 << " NewCost=" << NewCost << "\n");
1049
1050 if (NewCost > OldCost)
1051 return false;
1052
1053 // Create the operation on the source type
1054 Value *NewOp = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(),
1055 LHS: LHSSrc, RHS: InvC, Name: I.getName() + ".inner");
1056 if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
1057 NewBinOp->copyIRFlags(V: &I);
1058
1059 Worklist.pushValue(V: NewOp);
1060
1061 // Create the cast operation directly to ensure we get a new instruction
1062 Instruction *NewCast = CastInst::Create(CastOpcode, S: NewOp, Ty: I.getType());
1063
1064 // Preserve cast instruction flags
1065 if (RHSFlags.NNeg)
1066 NewCast->setNonNeg();
1067 if (RHSFlags.NUW)
1068 NewCast->setHasNoUnsignedWrap();
1069 if (RHSFlags.NSW)
1070 NewCast->setHasNoSignedWrap();
1071
1072 NewCast->andIRFlags(V: LHSCast);
1073
1074 // Insert the new instruction
1075 Value *Result = Builder.Insert(I: NewCast);
1076
1077 replaceValue(Old&: I, New&: *Result);
1078 return true;
1079}
1080
1081/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
1082/// destination type followed by shuffle. This can enable further transforms by
1083/// moving bitcasts or shuffles together.
1084bool VectorCombine::foldBitcastShuffle(Instruction &I) {
1085 Value *V0, *V1;
1086 ArrayRef<int> Mask;
1087 if (!match(V: &I, P: m_BitCast(Op: m_OneUse(
1088 SubPattern: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(Mask))))))
1089 return false;
1090
1091 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
1092 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
1093 // mask for scalable type is a splat or not.
1094 // 2) Disallow non-vector casts.
1095 // TODO: We could allow any shuffle.
1096 auto *DestTy = dyn_cast<FixedVectorType>(Val: I.getType());
1097 auto *SrcTy = dyn_cast<FixedVectorType>(Val: V0->getType());
1098 if (!DestTy || !SrcTy)
1099 return false;
1100
1101 unsigned DestEltSize = DestTy->getScalarSizeInBits();
1102 unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
1103 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
1104 return false;
1105
1106 bool IsUnary = isa<UndefValue>(Val: V1);
1107
1108 // For binary shuffles, only fold bitcast(shuffle(X,Y))
1109 // if it won't increase the number of bitcasts.
1110 if (!IsUnary) {
1111 auto *BCTy0 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V0)->getType());
1112 auto *BCTy1 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V1)->getType());
1113 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
1114 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
1115 return false;
1116 }
1117
1118 SmallVector<int, 16> NewMask;
1119 if (DestEltSize <= SrcEltSize) {
1120 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
1121 // always be expanded to the equivalent form choosing narrower elements.
1122 if (SrcEltSize % DestEltSize != 0)
1123 return false;
1124 unsigned ScaleFactor = SrcEltSize / DestEltSize;
1125 narrowShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask);
1126 } else {
1127 // The bitcast is from narrow elements to wide elements. The shuffle mask
1128 // must choose consecutive elements to allow casting first.
1129 if (DestEltSize % SrcEltSize != 0)
1130 return false;
1131 unsigned ScaleFactor = DestEltSize / SrcEltSize;
1132 if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask))
1133 return false;
1134 }
1135
1136 // Bitcast the shuffle src - keep its original width but using the destination
1137 // scalar type.
1138 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
1139 auto *NewShuffleTy =
1140 FixedVectorType::get(ElementType: DestTy->getScalarType(), NumElts: NumSrcElts);
1141 auto *OldShuffleTy =
1142 FixedVectorType::get(ElementType: SrcTy->getScalarType(), NumElts: Mask.size());
1143 unsigned NumOps = IsUnary ? 1 : 2;
1144
1145 // The new shuffle must not cost more than the old shuffle.
1146 TargetTransformInfo::ShuffleKind SK =
1147 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
1148 : TargetTransformInfo::SK_PermuteTwoSrc;
1149
1150 InstructionCost NewCost =
1151 TTI.getShuffleCost(Kind: SK, DstTy: DestTy, SrcTy: NewShuffleTy, Mask: NewMask, CostKind) +
1152 (NumOps * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: NewShuffleTy, Src: SrcTy,
1153 CCH: TargetTransformInfo::CastContextHint::None,
1154 CostKind));
1155 InstructionCost OldCost =
1156 TTI.getShuffleCost(Kind: SK, DstTy: OldShuffleTy, SrcTy, Mask, CostKind) +
1157 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DestTy, Src: OldShuffleTy,
1158 CCH: TargetTransformInfo::CastContextHint::None,
1159 CostKind);
1160
1161 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
1162 << OldCost << " vs NewCost: " << NewCost << "\n");
1163
1164 if (NewCost > OldCost || !NewCost.isValid())
1165 return false;
1166
1167 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
1168 ++NumShufOfBitcast;
1169 Value *CastV0 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V0), DestTy: NewShuffleTy);
1170 Value *CastV1 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V1), DestTy: NewShuffleTy);
1171 Value *Shuf = Builder.CreateShuffleVector(V1: CastV0, V2: CastV1, Mask: NewMask);
1172 replaceValue(Old&: I, New&: *Shuf);
1173 return true;
1174}
1175
1176/// VP Intrinsics whose vector operands are both splat values may be simplified
1177/// into the scalar version of the operation and the result splatted. This
1178/// can lead to scalarization down the line.
1179bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
1180 if (!isa<VPIntrinsic>(Val: I))
1181 return false;
1182 VPIntrinsic &VPI = cast<VPIntrinsic>(Val&: I);
1183 Value *Op0 = VPI.getArgOperand(i: 0);
1184 Value *Op1 = VPI.getArgOperand(i: 1);
1185
1186 if (!isSplatValue(V: Op0) || !isSplatValue(V: Op1))
1187 return false;
1188
1189 // Check getSplatValue early in this function, to avoid doing unnecessary
1190 // work.
1191 Value *ScalarOp0 = getSplatValue(V: Op0);
1192 Value *ScalarOp1 = getSplatValue(V: Op1);
1193 if (!ScalarOp0 || !ScalarOp1)
1194 return false;
1195
1196 // For the binary VP intrinsics supported here, the result on disabled lanes
1197 // is a poison value. For now, only do this simplification if all lanes
1198 // are active.
1199 // TODO: Relax the condition that all lanes are active by using insertelement
1200 // on inactive lanes.
1201 auto IsAllTrueMask = [](Value *MaskVal) {
1202 if (Value *SplattedVal = getSplatValue(V: MaskVal))
1203 if (auto *ConstValue = dyn_cast<Constant>(Val: SplattedVal))
1204 return ConstValue->isAllOnesValue();
1205 return false;
1206 };
1207 if (!IsAllTrueMask(VPI.getArgOperand(i: 2)))
1208 return false;
1209
1210 // Check to make sure we support scalarization of the intrinsic
1211 Intrinsic::ID IntrID = VPI.getIntrinsicID();
1212 if (!VPBinOpIntrinsic::isVPBinOp(ID: IntrID))
1213 return false;
1214
1215 // Calculate cost of splatting both operands into vectors and the vector
1216 // intrinsic
1217 VectorType *VecTy = cast<VectorType>(Val: VPI.getType());
1218 SmallVector<int> Mask;
1219 if (auto *FVTy = dyn_cast<FixedVectorType>(Val: VecTy))
1220 Mask.resize(N: FVTy->getNumElements(), NV: 0);
1221 InstructionCost SplatCost =
1222 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: 0) +
1223 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_Broadcast, DstTy: VecTy, SrcTy: VecTy, Mask,
1224 CostKind);
1225
1226 // Calculate the cost of the VP Intrinsic
1227 SmallVector<Type *, 4> Args;
1228 for (Value *V : VPI.args())
1229 Args.push_back(Elt: V->getType());
1230 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
1231 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
1232 InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
1233
1234 // Determine scalar opcode
1235 std::optional<unsigned> FunctionalOpcode =
1236 VPI.getFunctionalOpcode();
1237 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
1238 if (!FunctionalOpcode) {
1239 ScalarIntrID = VPI.getFunctionalIntrinsicID();
1240 if (!ScalarIntrID)
1241 return false;
1242 }
1243
1244 // Calculate cost of scalarizing
1245 InstructionCost ScalarOpCost = 0;
1246 if (ScalarIntrID) {
1247 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
1248 ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
1249 } else {
1250 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode: *FunctionalOpcode,
1251 Ty: VecTy->getScalarType(), CostKind);
1252 }
1253
1254 // The existing splats may be kept around if other instructions use them.
1255 InstructionCost CostToKeepSplats =
1256 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
1257 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
1258
1259 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
1260 << "\n");
1261 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
1262 << ", Cost of scalarizing:" << NewCost << "\n");
1263
1264 // We want to scalarize unless the vector variant actually has lower cost.
1265 if (OldCost < NewCost || !NewCost.isValid())
1266 return false;
1267
1268 // Scalarize the intrinsic
1269 ElementCount EC = cast<VectorType>(Val: Op0->getType())->getElementCount();
1270 Value *EVL = VPI.getArgOperand(i: 3);
1271
1272 // If the VP op might introduce UB or poison, we can scalarize it provided
1273 // that we know the EVL > 0: If the EVL is zero, then the original VP op
1274 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
1275 // scalarizing it.
1276 bool SafeToSpeculate;
1277 if (ScalarIntrID)
1278 SafeToSpeculate = Intrinsic::getFnAttributes(C&: I.getContext(), id: *ScalarIntrID)
1279 .hasAttribute(Kind: Attribute::AttrKind::Speculatable);
1280 else
1281 SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
1282 Opcode: *FunctionalOpcode, Inst: &VPI, CtxI: nullptr, AC: SQ.AC, DT: SQ.DT);
1283 if (!SafeToSpeculate &&
1284 !isKnownNonZero(V: EVL, Q: SimplifyQuery(*DL, SQ.DT, SQ.AC, &VPI)))
1285 return false;
1286
1287 Value *ScalarVal =
1288 ScalarIntrID
1289 ? Builder.CreateIntrinsic(RetTy: VecTy->getScalarType(), ID: *ScalarIntrID,
1290 Args: {ScalarOp0, ScalarOp1})
1291 : Builder.CreateBinOp(Opc: (Instruction::BinaryOps)(*FunctionalOpcode),
1292 LHS: ScalarOp0, RHS: ScalarOp1);
1293
1294 replaceValue(Old&: VPI, New&: *Builder.CreateVectorSplat(EC, V: ScalarVal));
1295 return true;
1296}
1297
1298/// Match a vector op/compare/intrinsic with at least one
1299/// inserted scalar operand and convert to scalar op/cmp/intrinsic followed
1300/// by insertelement.
1301bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
1302 auto *UO = dyn_cast<UnaryOperator>(Val: &I);
1303 auto *BO = dyn_cast<BinaryOperator>(Val: &I);
1304 auto *CI = dyn_cast<CmpInst>(Val: &I);
1305 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
1306 if (!UO && !BO && !CI && !II)
1307 return false;
1308
1309 // TODO: Allow intrinsics with different argument types
1310 if (II) {
1311 if (!isTriviallyVectorizable(ID: II->getIntrinsicID()))
1312 return false;
1313 for (auto [Idx, Arg] : enumerate(First: II->args()))
1314 if (Arg->getType() != II->getType() &&
1315 !isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI))
1316 return false;
1317 }
1318
1319 // Do not convert the vector condition of a vector select into a scalar
1320 // condition. That may cause problems for codegen because of differences in
1321 // boolean formats and register-file transfers.
1322 // TODO: Can we account for that in the cost model?
1323 if (CI)
1324 for (User *U : I.users())
1325 if (match(V: U, P: m_Select(C: m_Specific(V: &I), L: m_Value(), R: m_Value())))
1326 return false;
1327
1328 // Match constant vectors or scalars being inserted into constant vectors:
1329 // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1330 SmallVector<Value *> VecCs, ScalarOps;
1331 std::optional<uint64_t> Index;
1332
1333 auto Ops = II ? II->args() : I.operands();
1334 for (auto [OpNum, Op] : enumerate(First&: Ops)) {
1335 Constant *VecC;
1336 Value *V;
1337 uint64_t InsIdx = 0;
1338 if (match(V: Op.get(), P: m_InsertElt(Val: m_Constant(C&: VecC), Elt: m_Value(V),
1339 Idx: m_ConstantInt(V&: InsIdx)))) {
1340 // Bail if any inserts are out of bounds.
1341 VectorType *OpTy = cast<VectorType>(Val: Op->getType());
1342 if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
1343 return false;
1344 // All inserts must have the same index.
1345 // TODO: Deal with mismatched index constants and variable indexes?
1346 if (!Index)
1347 Index = InsIdx;
1348 else if (InsIdx != *Index)
1349 return false;
1350 VecCs.push_back(Elt: VecC);
1351 ScalarOps.push_back(Elt: V);
1352 } else if (II && isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(),
1353 ScalarOpdIdx: OpNum, TTI: &TTI)) {
1354 VecCs.push_back(Elt: Op.get());
1355 ScalarOps.push_back(Elt: Op.get());
1356 } else if (match(V: Op.get(), P: m_Constant(C&: VecC))) {
1357 VecCs.push_back(Elt: VecC);
1358 ScalarOps.push_back(Elt: nullptr);
1359 } else {
1360 return false;
1361 }
1362 }
1363
1364 // Bail if all operands are constant.
1365 if (!Index.has_value())
1366 return false;
1367
1368 VectorType *VecTy = cast<VectorType>(Val: I.getType());
1369 Type *ScalarTy = VecTy->getScalarType();
1370 assert(VecTy->isVectorTy() &&
1371 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1372 ScalarTy->isPointerTy()) &&
1373 "Unexpected types for insert element into binop or cmp");
1374
1375 unsigned Opcode = I.getOpcode();
1376 InstructionCost ScalarOpCost, VectorOpCost;
1377 if (CI) {
1378 CmpInst::Predicate Pred = CI->getPredicate();
1379 ScalarOpCost = TTI.getCmpSelInstrCost(
1380 Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
1381 VectorOpCost = TTI.getCmpSelInstrCost(
1382 Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
1383 } else if (UO || BO) {
1384 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
1385 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
1386 } else {
1387 IntrinsicCostAttributes ScalarICA(
1388 II->getIntrinsicID(), ScalarTy,
1389 SmallVector<Type *>(II->arg_size(), ScalarTy));
1390 ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: ScalarICA, CostKind);
1391 IntrinsicCostAttributes VectorICA(
1392 II->getIntrinsicID(), VecTy,
1393 SmallVector<Type *>(II->arg_size(), VecTy));
1394 VectorOpCost = TTI.getIntrinsicInstrCost(ICA: VectorICA, CostKind);
1395 }
1396
1397 // Fold the vector constants in the original vectors into a new base vector to
1398 // get more accurate cost modelling.
1399 Value *NewVecC = nullptr;
1400 if (CI)
1401 NewVecC = simplifyCmpInst(Predicate: CI->getPredicate(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
1402 else if (UO)
1403 NewVecC =
1404 simplifyUnOp(Opcode: UO->getOpcode(), Op: VecCs[0], FMF: UO->getFastMathFlags(), Q: SQ);
1405 else if (BO)
1406 NewVecC = simplifyBinOp(Opcode: BO->getOpcode(), LHS: VecCs[0], RHS: VecCs[1], Q: SQ);
1407 else if (II)
1408 NewVecC = simplifyCall(Call: II, Callee: II->getCalledOperand(), Args: VecCs, Q: SQ);
1409
1410 if (!NewVecC)
1411 return false;
1412
1413 // Get cost estimate for the insert element. This cost will factor into
1414 // both sequences.
1415 InstructionCost OldCost = VectorOpCost;
1416 InstructionCost NewCost =
1417 ScalarOpCost + TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy,
1418 CostKind, Index: *Index, Op0: NewVecC);
1419
1420 for (auto [Idx, Op, VecC, Scalar] : enumerate(First&: Ops, Rest&: VecCs, Rest&: ScalarOps)) {
1421 if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
1422 ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI)))
1423 continue;
1424 InstructionCost InsertCost = TTI.getVectorInstrCost(
1425 Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: *Index, Op0: VecC, Op1: Scalar);
1426 OldCost += InsertCost;
1427 NewCost += !Op->hasOneUse() * InsertCost;
1428 }
1429
1430 // We want to scalarize unless the vector variant actually has lower cost.
1431 if (OldCost < NewCost || !NewCost.isValid())
1432 return false;
1433
1434 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1435 // inselt NewVecC, (scalar_op V0, V1), Index
1436 if (CI)
1437 ++NumScalarCmp;
1438 else if (UO || BO)
1439 ++NumScalarOps;
1440 else
1441 ++NumScalarIntrinsic;
1442
1443 // For constant cases, extract the scalar element, this should constant fold.
1444 for (auto [OpIdx, Scalar, VecC] : enumerate(First&: ScalarOps, Rest&: VecCs))
1445 if (!Scalar)
1446 ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
1447 Vec: cast<Constant>(Val: VecC), Idx: Builder.getInt64(C: *Index));
1448
1449 Value *Scalar;
1450 if (CI)
1451 Scalar = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: ScalarOps[0], RHS: ScalarOps[1]);
1452 else if (UO || BO)
1453 Scalar = Builder.CreateNAryOp(Opc: Opcode, Ops: ScalarOps);
1454 else
1455 Scalar = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: II->getIntrinsicID(), Args: ScalarOps);
1456
1457 Scalar->setName(I.getName() + ".scalar");
1458
1459 // All IR flags are safe to back-propagate. There is no potential for extra
1460 // poison to be created by the scalar instruction.
1461 if (auto *ScalarInst = dyn_cast<Instruction>(Val: Scalar))
1462 ScalarInst->copyIRFlags(V: &I);
1463
1464 Value *Insert = Builder.CreateInsertElement(Vec: NewVecC, NewElt: Scalar, Idx: *Index);
1465 replaceValue(Old&: I, New&: *Insert);
1466 return true;
1467}
1468
1469/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
1470/// a vector into vector operations followed by extract. Note: The SLP pass
1471/// may miss this pattern because of implementation problems.
1472bool VectorCombine::foldExtractedCmps(Instruction &I) {
1473 auto *BI = dyn_cast<BinaryOperator>(Val: &I);
1474
1475 // We are looking for a scalar binop of booleans.
1476 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1477 if (!BI || !I.getType()->isIntegerTy(BitWidth: 1))
1478 return false;
1479
1480 // The compare predicates should match, and each compare should have a
1481 // constant operand.
1482 Value *B0 = I.getOperand(i: 0), *B1 = I.getOperand(i: 1);
1483 Instruction *I0, *I1;
1484 Constant *C0, *C1;
1485 CmpPredicate P0, P1;
1486 if (!match(V: B0, P: m_Cmp(Pred&: P0, L: m_Instruction(I&: I0), R: m_Constant(C&: C0))) ||
1487 !match(V: B1, P: m_Cmp(Pred&: P1, L: m_Instruction(I&: I1), R: m_Constant(C&: C1))))
1488 return false;
1489
1490 auto MatchingPred = CmpPredicate::getMatching(A: P0, B: P1);
1491 if (!MatchingPred)
1492 return false;
1493
1494 // The compare operands must be extracts of the same vector with constant
1495 // extract indexes.
1496 Value *X;
1497 uint64_t Index0, Index1;
1498 if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ConstantInt(V&: Index0))) ||
1499 !match(V: I1, P: m_ExtractElt(Val: m_Specific(V: X), Idx: m_ConstantInt(V&: Index1))))
1500 return false;
1501
1502 auto *Ext0 = cast<ExtractElementInst>(Val: I0);
1503 auto *Ext1 = cast<ExtractElementInst>(Val: I1);
1504 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex: CostKind);
1505 if (!ConvertToShuf)
1506 return false;
1507 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1508 "Unknown ExtractElementInst");
1509
1510 // The original scalar pattern is:
1511 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1512 CmpInst::Predicate Pred = *MatchingPred;
1513 unsigned CmpOpcode =
1514 CmpInst::isFPPredicate(P: Pred) ? Instruction::FCmp : Instruction::ICmp;
1515 auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
1516 if (!VecTy)
1517 return false;
1518
1519 if (Index0 >= VecTy->getNumElements() || Index1 >= VecTy->getNumElements())
1520 return false;
1521
1522 InstructionCost Ext0Cost =
1523 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
1524 InstructionCost Ext1Cost =
1525 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
1526 InstructionCost CmpCost = TTI.getCmpSelInstrCost(
1527 Opcode: CmpOpcode, ValTy: I0->getType(), CondTy: CmpInst::makeCmpResultType(opnd_type: I0->getType()), VecPred: Pred,
1528 CostKind);
1529
1530 InstructionCost OldCost =
1531 Ext0Cost + Ext1Cost + CmpCost * 2 +
1532 TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: I.getType(), CostKind);
1533
1534 // The proposed vector pattern is:
1535 // vcmp = cmp Pred X, VecC
1536 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1537 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1538 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1539 auto *CmpTy = cast<FixedVectorType>(Val: CmpInst::makeCmpResultType(opnd_type: VecTy));
1540 InstructionCost NewCost = TTI.getCmpSelInstrCost(
1541 Opcode: CmpOpcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
1542 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1543 ShufMask[CheapIndex] = ExpensiveIndex;
1544 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: CmpTy,
1545 SrcTy: CmpTy, Mask: ShufMask, CostKind);
1546 NewCost += TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: CmpTy, CostKind);
1547 NewCost += TTI.getVectorInstrCost(I: *Ext0, Val: CmpTy, CostKind, Index: CheapIndex);
1548 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1549 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1550
1551 // Aggressively form vector ops if the cost is equal because the transform
1552 // may enable further optimization.
1553 // Codegen can reverse this transform (scalarize) if it was not profitable.
1554 if (OldCost < NewCost || !NewCost.isValid())
1555 return false;
1556
1557 // Create a vector constant from the 2 scalar constants.
1558 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1559 PoisonValue::get(T: VecTy->getElementType()));
1560 CmpC[Index0] = C0;
1561 CmpC[Index1] = C1;
1562 Value *VCmp = Builder.CreateCmp(Pred, LHS: X, RHS: ConstantVector::get(V: CmpC));
1563 Value *Shuf = createShiftShuffle(Vec: VCmp, OldIndex: ExpensiveIndex, NewIndex: CheapIndex, Builder);
1564 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1565 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1566 Value *VecLogic = Builder.CreateBinOp(Opc: BI->getOpcode(), LHS, RHS);
1567 Value *NewExt = Builder.CreateExtractElement(Vec: VecLogic, Idx: CheapIndex);
1568 replaceValue(Old&: I, New&: *NewExt);
1569 ++NumVecCmpBO;
1570 return true;
1571}
1572
1573/// Try to fold scalar selects that select between extracted elements and zero
1574/// into extracting from a vector select. This is rooted at the bitcast.
1575///
1576/// This pattern arises when a vector is bitcast to a smaller element type,
1577/// elements are extracted, and then conditionally selected with zero:
1578///
1579/// %bc = bitcast <4 x i32> %src to <16 x i8>
1580/// %e0 = extractelement <16 x i8> %bc, i32 0
1581/// %s0 = select i1 %cond, i8 %e0, i8 0
1582/// %e1 = extractelement <16 x i8> %bc, i32 1
1583/// %s1 = select i1 %cond, i8 %e1, i8 0
1584/// ...
1585///
1586/// Transforms to:
1587/// %sel = select i1 %cond, <4 x i32> %src, <4 x i32> zeroinitializer
1588/// %bc = bitcast <4 x i32> %sel to <16 x i8>
1589/// %e0 = extractelement <16 x i8> %bc, i32 0
1590/// %e1 = extractelement <16 x i8> %bc, i32 1
1591/// ...
1592///
1593/// This is profitable because vector select on wider types produces fewer
1594/// select/cndmask instructions than scalar selects on each element.
1595bool VectorCombine::foldSelectsFromBitcast(Instruction &I) {
1596 auto *BC = dyn_cast<BitCastInst>(Val: &I);
1597 if (!BC)
1598 return false;
1599
1600 FixedVectorType *SrcVecTy = dyn_cast<FixedVectorType>(Val: BC->getSrcTy());
1601 FixedVectorType *DstVecTy = dyn_cast<FixedVectorType>(Val: BC->getDestTy());
1602 if (!SrcVecTy || !DstVecTy)
1603 return false;
1604
1605 // Source must be 32-bit or 64-bit elements, destination must be smaller
1606 // integer elements. Zero in all these types is all-bits-zero.
1607 Type *SrcEltTy = SrcVecTy->getElementType();
1608 Type *DstEltTy = DstVecTy->getElementType();
1609 unsigned SrcEltBits = SrcEltTy->getPrimitiveSizeInBits();
1610 unsigned DstEltBits = DstEltTy->getPrimitiveSizeInBits();
1611
1612 if (SrcEltBits != 32 && SrcEltBits != 64)
1613 return false;
1614
1615 if (!DstEltTy->isIntegerTy() || DstEltBits >= SrcEltBits)
1616 return false;
1617
1618 // Check profitability using TTI before collecting users.
1619 Type *CondTy = CmpInst::makeCmpResultType(opnd_type: DstEltTy);
1620 Type *VecCondTy = CmpInst::makeCmpResultType(opnd_type: SrcVecTy);
1621
1622 InstructionCost ScalarSelCost =
1623 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: DstEltTy, CondTy,
1624 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1625 InstructionCost VecSelCost =
1626 TTI.getCmpSelInstrCost(Opcode: Instruction::Select, ValTy: SrcVecTy, CondTy: VecCondTy,
1627 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
1628
1629 // We need at least this many selects for vectorization to be profitable.
1630 // VecSelCost < ScalarSelCost * NumSelects => NumSelects > VecSelCost /
1631 // ScalarSelCost
1632 if (!ScalarSelCost.isValid() || ScalarSelCost == 0)
1633 return false;
1634
1635 unsigned MinSelects = (VecSelCost.getValue() / ScalarSelCost.getValue()) + 1;
1636
1637 // Quick check: if bitcast doesn't have enough users, bail early.
1638 if (!BC->hasNUsesOrMore(N: MinSelects))
1639 return false;
1640
1641 // Collect all select users that match the pattern, grouped by condition.
1642 // Pattern: select i1 %cond, (extractelement %bc, idx), 0
1643 DenseMap<Value *, SmallVector<SelectInst *, 8>> CondToSelects;
1644
1645 for (User *U : BC->users()) {
1646 auto *Ext = dyn_cast<ExtractElementInst>(Val: U);
1647 if (!Ext)
1648 continue;
1649
1650 for (User *ExtUser : Ext->users()) {
1651 Value *Cond;
1652 // Match: select i1 %cond, %ext, 0
1653 if (match(V: ExtUser, P: m_Select(C: m_Value(V&: Cond), L: m_Specific(V: Ext), R: m_Zero())) &&
1654 Cond->getType()->isIntegerTy(BitWidth: 1))
1655 CondToSelects[Cond].push_back(Elt: cast<SelectInst>(Val: ExtUser));
1656 }
1657 }
1658
1659 if (CondToSelects.empty())
1660 return false;
1661
1662 bool MadeChange = false;
1663 Value *SrcVec = BC->getOperand(i_nocapture: 0);
1664
1665 // Process each group of selects with the same condition.
1666 for (auto [Cond, Selects] : CondToSelects) {
1667 // Only profitable if vector select cost < total scalar select cost.
1668 if (Selects.size() < MinSelects) {
1669 LLVM_DEBUG(dbgs() << "VectorCombine: foldSelectsFromBitcast not "
1670 << "profitable (VecCost=" << VecSelCost
1671 << ", ScalarCost=" << ScalarSelCost
1672 << ", NumSelects=" << Selects.size() << ")\n");
1673 continue;
1674 }
1675
1676 // Create the vector select and bitcast once for this condition.
1677 auto InsertPt = std::next(x: BC->getIterator());
1678
1679 if (auto *CondInst = dyn_cast<Instruction>(Val: Cond))
1680 if (DT.dominates(Def: BC, User: CondInst))
1681 InsertPt = std::next(x: CondInst->getIterator());
1682
1683 Builder.SetInsertPoint(InsertPt);
1684 Value *VecSel =
1685 Builder.CreateSelect(C: Cond, True: SrcVec, False: Constant::getNullValue(Ty: SrcVecTy));
1686 Value *NewBC = Builder.CreateBitCast(V: VecSel, DestTy: DstVecTy);
1687
1688 // Replace each scalar select with an extract from the new bitcast.
1689 for (SelectInst *Sel : Selects) {
1690 auto *Ext = cast<ExtractElementInst>(Val: Sel->getTrueValue());
1691 Value *Idx = Ext->getIndexOperand();
1692
1693 Builder.SetInsertPoint(Sel);
1694 Value *NewExt = Builder.CreateExtractElement(Vec: NewBC, Idx);
1695 replaceValue(Old&: *Sel, New&: *NewExt);
1696 MadeChange = true;
1697 }
1698
1699 LLVM_DEBUG(dbgs() << "VectorCombine: folded " << Selects.size()
1700 << " selects into vector select\n");
1701 }
1702
1703 return MadeChange;
1704}
1705
1706static void analyzeCostOfVecReduction(const IntrinsicInst &II,
1707 TTI::TargetCostKind CostKind,
1708 const TargetTransformInfo &TTI,
1709 InstructionCost &CostBeforeReduction,
1710 InstructionCost &CostAfterReduction) {
1711 Instruction *Op0, *Op1;
1712 auto *RedOp = dyn_cast<Instruction>(Val: II.getOperand(i_nocapture: 0));
1713 auto *VecRedTy = cast<VectorType>(Val: II.getOperand(i_nocapture: 0)->getType());
1714 unsigned ReductionOpc =
1715 getArithmeticReductionInstruction(RdxID: II.getIntrinsicID());
1716 if (RedOp && match(V: RedOp, P: m_ZExtOrSExt(Op: m_Value()))) {
1717 bool IsUnsigned = isa<ZExtInst>(Val: RedOp);
1718 auto *ExtType = cast<VectorType>(Val: RedOp->getOperand(i: 0)->getType());
1719
1720 CostBeforeReduction =
1721 TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: ExtType,
1722 CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
1723 CostAfterReduction =
1724 TTI.getExtendedReductionCost(Opcode: ReductionOpc, IsUnsigned, ResTy: II.getType(),
1725 Ty: ExtType, FMF: FastMathFlags(), CostKind);
1726 return;
1727 }
1728 if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1729 match(V: RedOp,
1730 P: m_ZExtOrSExt(Op: m_Mul(L: m_Instruction(I&: Op0), R: m_Instruction(I&: Op1)))) &&
1731 match(V: Op0, P: m_ZExtOrSExt(Op: m_Value())) &&
1732 Op0->getOpcode() == Op1->getOpcode() &&
1733 Op0->getOperand(i: 0)->getType() == Op1->getOperand(i: 0)->getType() &&
1734 (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
1735 // Matched reduce.add(ext(mul(ext(A), ext(B)))
1736 bool IsUnsigned = isa<ZExtInst>(Val: Op0);
1737 auto *ExtType = cast<VectorType>(Val: Op0->getOperand(i: 0)->getType());
1738 VectorType *MulType = VectorType::get(ElementType: Op0->getType(), Other: VecRedTy);
1739
1740 InstructionCost ExtCost =
1741 TTI.getCastInstrCost(Opcode: Op0->getOpcode(), Dst: MulType, Src: ExtType,
1742 CCH: TTI::CastContextHint::None, CostKind, I: Op0);
1743 InstructionCost MulCost =
1744 TTI.getArithmeticInstrCost(Opcode: Instruction::Mul, Ty: MulType, CostKind);
1745 InstructionCost Ext2Cost =
1746 TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: MulType,
1747 CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
1748
1749 CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1750 CostAfterReduction = TTI.getMulAccReductionCost(
1751 IsUnsigned, RedOpcode: ReductionOpc, ResTy: II.getType(), Ty: ExtType, CostKind);
1752 return;
1753 }
1754 CostAfterReduction = TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecRedTy,
1755 FMF: std::nullopt, CostKind);
1756}
1757
1758bool VectorCombine::foldBinopOfReductions(Instruction &I) {
1759 Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(Val: &I)->getOpcode();
1760 Intrinsic::ID ReductionIID = getReductionForBinop(Opc: BinOpOpc);
1761 if (BinOpOpc == Instruction::Sub)
1762 ReductionIID = Intrinsic::vector_reduce_add;
1763 if (ReductionIID == Intrinsic::not_intrinsic)
1764 return false;
1765 // FP reductions have a start-value operand that this fold doesn't handle.
1766 if (ReductionIID == Intrinsic::vector_reduce_fadd ||
1767 ReductionIID == Intrinsic::vector_reduce_fmul)
1768 return false;
1769
1770 auto checkIntrinsicAndGetItsArgument = [](Value *V,
1771 Intrinsic::ID IID) -> Value * {
1772 auto *II = dyn_cast<IntrinsicInst>(Val: V);
1773 if (!II)
1774 return nullptr;
1775 if (II->getIntrinsicID() == IID && II->hasOneUse())
1776 return II->getArgOperand(i: 0);
1777 return nullptr;
1778 };
1779
1780 Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 0), ReductionIID);
1781 if (!V0)
1782 return false;
1783 Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 1), ReductionIID);
1784 if (!V1)
1785 return false;
1786
1787 auto *VTy = cast<VectorType>(Val: V0->getType());
1788 if (V1->getType() != VTy)
1789 return false;
1790 const auto &II0 = *cast<IntrinsicInst>(Val: I.getOperand(i: 0));
1791 const auto &II1 = *cast<IntrinsicInst>(Val: I.getOperand(i: 1));
1792 unsigned ReductionOpc =
1793 getArithmeticReductionInstruction(RdxID: II0.getIntrinsicID());
1794
1795 InstructionCost OldCost = 0;
1796 InstructionCost NewCost = 0;
1797 InstructionCost CostOfRedOperand0 = 0;
1798 InstructionCost CostOfRed0 = 0;
1799 InstructionCost CostOfRedOperand1 = 0;
1800 InstructionCost CostOfRed1 = 0;
1801 analyzeCostOfVecReduction(II: II0, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand0, CostAfterReduction&: CostOfRed0);
1802 analyzeCostOfVecReduction(II: II1, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand1, CostAfterReduction&: CostOfRed1);
1803 OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(U: &I, CostKind);
1804 NewCost =
1805 CostOfRedOperand0 + CostOfRedOperand1 +
1806 TTI.getArithmeticInstrCost(Opcode: BinOpOpc, Ty: VTy, CostKind) +
1807 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VTy, FMF: std::nullopt, CostKind);
1808 if (NewCost >= OldCost || !NewCost.isValid())
1809 return false;
1810
1811 LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
1812 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1813 << "\n");
1814 Value *VectorBO;
1815 if (BinOpOpc == Instruction::Or)
1816 VectorBO = Builder.CreateOr(LHS: V0, RHS: V1, Name: "",
1817 IsDisjoint: cast<PossiblyDisjointInst>(Val&: I).isDisjoint());
1818 else
1819 VectorBO = Builder.CreateBinOp(Opc: BinOpOpc, LHS: V0, RHS: V1);
1820
1821 Value *Rdx = Builder.CreateIntrinsic(ID: ReductionIID, OverloadTypes: {VTy}, Args: {VectorBO});
1822 replaceValue(Old&: I, New&: *Rdx);
1823 return true;
1824}
1825
1826// Check if memory loc modified between two instrs in the same BB
1827static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1828 BasicBlock::iterator End,
1829 const MemoryLocation &Loc, AAResults &AA) {
1830 unsigned NumScanned = 0;
1831 return std::any_of(first: Begin, last: End, pred: [&](const Instruction &Instr) {
1832 return isModSet(MRI: AA.getModRefInfo(I: &Instr, OptLoc: Loc)) ||
1833 ++NumScanned > MaxInstrsToScan;
1834 });
1835}
1836
1837namespace {
1838/// Helper class to indicate whether a vector index can be safely scalarized and
1839/// if a freeze needs to be inserted.
1840class ScalarizationResult {
1841 enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1842
1843 StatusTy Status;
1844 Value *ToFreeze;
1845
1846 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1847 : Status(Status), ToFreeze(ToFreeze) {}
1848
1849public:
1850 ScalarizationResult(const ScalarizationResult &Other) = default;
1851 ~ScalarizationResult() {
1852 assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1853 }
1854
1855 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1856 static ScalarizationResult safe() { return {StatusTy::Safe}; }
1857 static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1858 return {StatusTy::SafeWithFreeze, ToFreeze};
1859 }
1860
1861 /// Returns true if the index can be scalarize without requiring a freeze.
1862 bool isSafe() const { return Status == StatusTy::Safe; }
1863 /// Returns true if the index cannot be scalarized.
1864 bool isUnsafe() const { return Status == StatusTy::Unsafe; }
1865 /// Returns true if the index can be scalarize, but requires inserting a
1866 /// freeze.
1867 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1868
1869 /// Reset the state of Unsafe and clear ToFreze if set.
1870 void discard() {
1871 ToFreeze = nullptr;
1872 Status = StatusTy::Unsafe;
1873 }
1874
1875 /// Freeze the ToFreeze and update the use in \p User to use it.
1876 void freeze(IRBuilderBase &Builder, Instruction &UserI) {
1877 assert(isSafeWithFreeze() &&
1878 "should only be used when freezing is required");
1879 assert(is_contained(ToFreeze->users(), &UserI) &&
1880 "UserI must be a user of ToFreeze");
1881 IRBuilder<>::InsertPointGuard Guard(Builder);
1882 Builder.SetInsertPoint(cast<Instruction>(Val: &UserI));
1883 Value *Frozen =
1884 Builder.CreateFreeze(V: ToFreeze, Name: ToFreeze->getName() + ".frozen");
1885 for (Use &U : make_early_inc_range(Range: (UserI.operands())))
1886 if (U.get() == ToFreeze)
1887 U.set(Frozen);
1888
1889 ToFreeze = nullptr;
1890 }
1891};
1892} // namespace
1893
1894/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1895/// Idx. \p Idx must access a valid vector element.
1896static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1897 const SimplifyQuery &SQ) {
1898 // We do checks for both fixed vector types and scalable vector types.
1899 // This is the number of elements of fixed vector types,
1900 // or the minimum number of elements of scalable vector types.
1901 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1902 unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1903
1904 if (auto *C = dyn_cast<ConstantInt>(Val: Idx)) {
1905 if (C->getValue().ult(RHS: NumElements))
1906 return ScalarizationResult::safe();
1907 return ScalarizationResult::unsafe();
1908 }
1909
1910 // Always unsafe if the index type can't handle all inbound values.
1911 if (!llvm::isUIntN(N: IntWidth, x: NumElements))
1912 return ScalarizationResult::unsafe();
1913
1914 APInt Zero(IntWidth, 0);
1915 APInt MaxElts(IntWidth, NumElements);
1916 ConstantRange ValidIndices(Zero, MaxElts);
1917 ConstantRange IdxRange(IntWidth, true);
1918
1919 if (isGuaranteedNotToBePoison(V: Idx, AC: SQ.AC, CtxI: SQ.CxtI, DT: SQ.DT)) {
1920 if (ValidIndices.contains(
1921 CR: computeConstantRange(V: Idx, /*ForSigned=*/false, SQ)))
1922 return ScalarizationResult::safe();
1923 return ScalarizationResult::unsafe();
1924 }
1925
1926 // If the index may be poison, check if we can insert a freeze before the
1927 // range of the index is restricted.
1928 Value *IdxBase;
1929 ConstantInt *CI;
1930 if (match(V: Idx, P: m_And(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
1931 IdxRange = IdxRange.binaryAnd(Other: CI->getValue());
1932 } else if (match(V: Idx, P: m_URem(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
1933 IdxRange = IdxRange.urem(Other: CI->getValue());
1934 }
1935
1936 if (ValidIndices.contains(CR: IdxRange))
1937 return ScalarizationResult::safeWithFreeze(ToFreeze: IdxBase);
1938 return ScalarizationResult::unsafe();
1939}
1940
1941/// The memory operation on a vector of \p ScalarType had alignment of
1942/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1943/// alignment that will be valid for the memory operation on a single scalar
1944/// element of the same type with index \p Idx.
1945static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1946 Type *ScalarType, Value *Idx,
1947 const DataLayout &DL) {
1948 if (auto *C = dyn_cast<ConstantInt>(Val: Idx))
1949 return commonAlignment(A: VectorAlignment,
1950 Offset: C->getZExtValue() * DL.getTypeStoreSize(Ty: ScalarType));
1951 return commonAlignment(A: VectorAlignment, Offset: DL.getTypeStoreSize(Ty: ScalarType));
1952}
1953
1954// Combine patterns like:
1955// %0 = load <4 x i32>, <4 x i32>* %a
1956// %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1957// store <4 x i32> %1, <4 x i32>* %a
1958// to:
1959// %0 = bitcast <4 x i32>* %a to i32*
1960// %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1961// store i32 %b, i32* %1
1962bool VectorCombine::foldSingleElementStore(Instruction &I) {
1963 if (!TTI.allowVectorElementIndexingUsingGEP())
1964 return false;
1965 auto *SI = cast<StoreInst>(Val: &I);
1966 if (!SI->isSimple() || !isa<VectorType>(Val: SI->getValueOperand()->getType()))
1967 return false;
1968
1969 // TODO: Combine more complicated patterns (multiple insert) by referencing
1970 // TargetTransformInfo.
1971 Instruction *Source;
1972 Value *NewElement;
1973 Value *Idx;
1974 if (!match(V: SI->getValueOperand(),
1975 P: m_InsertElt(Val: m_Instruction(I&: Source), Elt: m_Value(V&: NewElement),
1976 Idx: m_Value(V&: Idx))))
1977 return false;
1978
1979 if (auto *Load = dyn_cast<LoadInst>(Val: Source)) {
1980 auto VecTy = cast<VectorType>(Val: SI->getValueOperand()->getType());
1981 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
1982 // Don't optimize for atomic/volatile load or store. Ensure memory is not
1983 // modified between, vector type matches store size, and index is inbounds.
1984 if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1985 !DL->typeSizeEqualsStoreSize(Ty: Load->getType()->getScalarType()) ||
1986 SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1987 return false;
1988
1989 auto ScalarizableIdx =
1990 canScalarizeAccess(VecTy, Idx, SQ: SQ.getWithInstruction(I: Load));
1991 if (ScalarizableIdx.isUnsafe() ||
1992 isMemModifiedBetween(Begin: Load->getIterator(), End: SI->getIterator(),
1993 Loc: MemoryLocation::get(SI), AA))
1994 return false;
1995
1996 // Ensure we add the load back to the worklist BEFORE its users so they can
1997 // erased in the correct order.
1998 Worklist.push(I: Load);
1999
2000 if (ScalarizableIdx.isSafeWithFreeze())
2001 ScalarizableIdx.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
2002 Value *GEP = Builder.CreateInBoundsGEP(
2003 Ty: SI->getValueOperand()->getType(), Ptr: SI->getPointerOperand(),
2004 IdxList: {ConstantInt::get(Ty: Idx->getType(), V: 0), Idx});
2005 StoreInst *NSI = Builder.CreateStore(Val: NewElement, Ptr: GEP);
2006 NSI->copyMetadata(SrcInst: *SI);
2007 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
2008 VectorAlignment: std::max(a: SI->getAlign(), b: Load->getAlign()), ScalarType: NewElement->getType(), Idx,
2009 DL: *DL);
2010 NSI->setAlignment(ScalarOpAlignment);
2011 replaceValue(Old&: I, New&: *NSI);
2012 eraseInstruction(I);
2013 return true;
2014 }
2015
2016 return false;
2017}
2018
2019/// Try to scalarize vector loads feeding extractelement or bitcast
2020/// instructions.
2021bool VectorCombine::scalarizeLoad(Instruction &I) {
2022 Value *Ptr;
2023 if (!match(V: &I, P: m_Load(Op: m_Value(V&: Ptr))))
2024 return false;
2025
2026 auto *LI = cast<LoadInst>(Val: &I);
2027 auto *VecTy = cast<VectorType>(Val: LI->getType());
2028
2029 // The isSimple() check could be isUnordered(), but for now we cowardly
2030 // refuse to handle even unordered atomics.
2031 if (!LI->isSimple() || !DL->typeSizeEqualsStoreSize(Ty: VecTy->getScalarType()))
2032 return false;
2033
2034 bool AllExtracts = true;
2035 bool AllBitcasts = true;
2036 Instruction *LastCheckedInst = LI;
2037 unsigned NumInstChecked = 0;
2038
2039 // Check what type of users we have (must either all be extracts or
2040 // bitcasts) and ensure no memory modifications between the load and
2041 // its users.
2042 for (User *U : LI->users()) {
2043 auto *UI = dyn_cast<Instruction>(Val: U);
2044 if (!UI || UI->getParent() != LI->getParent())
2045 return false;
2046
2047 // If any user is waiting to be erased, then bail out as this will
2048 // distort the cost calculation and possibly lead to infinite loops.
2049 if (UI->use_empty())
2050 return false;
2051
2052 if (!isa<ExtractElementInst>(Val: UI))
2053 AllExtracts = false;
2054 if (!isa<BitCastInst>(Val: UI))
2055 AllBitcasts = false;
2056
2057 // Check if any instruction between the load and the user may modify memory.
2058 if (LastCheckedInst->comesBefore(Other: UI)) {
2059 for (Instruction &I :
2060 make_range(x: std::next(x: LI->getIterator()), y: UI->getIterator())) {
2061 // Bail out if we reached the check limit or the instruction may write
2062 // to memory.
2063 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
2064 return false;
2065 NumInstChecked++;
2066 }
2067 LastCheckedInst = UI;
2068 }
2069 }
2070
2071 if (AllExtracts)
2072 return scalarizeLoadExtract(LI, VecTy, Ptr);
2073 if (AllBitcasts)
2074 return scalarizeLoadBitcast(LI, VecTy, Ptr);
2075 return false;
2076}
2077
2078/// Try to scalarize vector loads feeding extractelement instructions.
2079bool VectorCombine::scalarizeLoadExtract(LoadInst *LI, VectorType *VecTy,
2080 Value *Ptr) {
2081 if (!TTI.allowVectorElementIndexingUsingGEP())
2082 return false;
2083
2084 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
2085 llvm::scope_exit FailureGuard([&]() {
2086 // If the transform is aborted, discard the ScalarizationResults.
2087 for (auto &Pair : NeedFreeze)
2088 Pair.second.discard();
2089 });
2090
2091 InstructionCost OriginalCost =
2092 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
2093 AddressSpace: LI->getPointerAddressSpace(), CostKind);
2094 InstructionCost ScalarizedCost = 0;
2095
2096 for (User *U : LI->users()) {
2097 auto *UI = cast<ExtractElementInst>(Val: U);
2098
2099 auto ScalarIdx = canScalarizeAccess(VecTy, Idx: UI->getIndexOperand(),
2100 SQ: SQ.getWithInstruction(I: LI));
2101 if (ScalarIdx.isUnsafe())
2102 return false;
2103 if (ScalarIdx.isSafeWithFreeze()) {
2104 NeedFreeze.try_emplace(Key: UI, Args&: ScalarIdx);
2105 ScalarIdx.discard();
2106 }
2107
2108 auto *Index = dyn_cast<ConstantInt>(Val: UI->getIndexOperand());
2109 OriginalCost +=
2110 TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: VecTy, CostKind,
2111 Index: Index ? Index->getZExtValue() : -1);
2112 ScalarizedCost +=
2113 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy->getElementType(),
2114 Alignment: Align(1), AddressSpace: LI->getPointerAddressSpace(), CostKind);
2115 ScalarizedCost += TTI.getAddressComputationCost(PtrTy: LI->getPointerOperandType(),
2116 SE: nullptr, Ptr: nullptr, CostKind);
2117 }
2118
2119 LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << *LI
2120 << "\n LoadExtractCost: " << OriginalCost
2121 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2122
2123 if (ScalarizedCost >= OriginalCost)
2124 return false;
2125
2126 // Ensure we add the load back to the worklist BEFORE its users so they can
2127 // erased in the correct order.
2128 Worklist.push(I: LI);
2129
2130 Type *ElemType = VecTy->getElementType();
2131
2132 // Replace extracts with narrow scalar loads.
2133 for (User *U : LI->users()) {
2134 auto *EI = cast<ExtractElementInst>(Val: U);
2135 Value *Idx = EI->getIndexOperand();
2136
2137 // Insert 'freeze' for poison indexes.
2138 auto It = NeedFreeze.find(Val: EI);
2139 if (It != NeedFreeze.end())
2140 It->second.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
2141
2142 Builder.SetInsertPoint(EI);
2143 Value *GEP =
2144 Builder.CreateInBoundsGEP(Ty: VecTy, Ptr, IdxList: {Builder.getInt32(C: 0), Idx});
2145 auto *NewLoad = cast<LoadInst>(
2146 Val: Builder.CreateLoad(Ty: ElemType, Ptr: GEP, Name: EI->getName() + ".scalar"));
2147
2148 Align ScalarOpAlignment =
2149 computeAlignmentAfterScalarization(VectorAlignment: LI->getAlign(), ScalarType: ElemType, Idx, DL: *DL);
2150 NewLoad->setAlignment(ScalarOpAlignment);
2151
2152 if (auto *ConstIdx = dyn_cast<ConstantInt>(Val: Idx)) {
2153 size_t Offset = ConstIdx->getZExtValue() * DL->getTypeStoreSize(Ty: ElemType);
2154 AAMDNodes OldAAMD = LI->getAAMetadata();
2155 NewLoad->setAAMetadata(OldAAMD.adjustForAccess(Offset, AccessTy: ElemType, DL: *DL));
2156 }
2157
2158 replaceValue(Old&: *EI, New&: *NewLoad, Erase: false);
2159 }
2160
2161 FailureGuard.release();
2162 return true;
2163}
2164
2165/// Try to scalarize vector loads feeding bitcast instructions.
2166bool VectorCombine::scalarizeLoadBitcast(LoadInst *LI, VectorType *VecTy,
2167 Value *Ptr) {
2168 InstructionCost OriginalCost =
2169 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
2170 AddressSpace: LI->getPointerAddressSpace(), CostKind);
2171
2172 Type *TargetScalarType = nullptr;
2173 unsigned VecBitWidth = DL->getTypeSizeInBits(Ty: VecTy);
2174
2175 for (User *U : LI->users()) {
2176 auto *BC = cast<BitCastInst>(Val: U);
2177
2178 Type *DestTy = BC->getDestTy();
2179 if (!DestTy->isIntegerTy() && !DestTy->isFloatingPointTy())
2180 return false;
2181
2182 unsigned DestBitWidth = DL->getTypeSizeInBits(Ty: DestTy);
2183 if (DestBitWidth != VecBitWidth)
2184 return false;
2185
2186 // All bitcasts must target the same scalar type.
2187 if (!TargetScalarType)
2188 TargetScalarType = DestTy;
2189 else if (TargetScalarType != DestTy)
2190 return false;
2191
2192 OriginalCost +=
2193 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: TargetScalarType, Src: VecTy,
2194 CCH: TTI.getCastContextHint(I: BC), CostKind, I: BC);
2195 }
2196
2197 if (!TargetScalarType)
2198 return false;
2199
2200 assert(!LI->user_empty() && "Unexpected load without bitcast users");
2201 InstructionCost ScalarizedCost =
2202 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: TargetScalarType, Alignment: LI->getAlign(),
2203 AddressSpace: LI->getPointerAddressSpace(), CostKind);
2204
2205 LLVM_DEBUG(dbgs() << "Found vector load feeding only bitcasts: " << *LI
2206 << "\n OriginalCost: " << OriginalCost
2207 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
2208
2209 if (ScalarizedCost >= OriginalCost)
2210 return false;
2211
2212 // Ensure we add the load back to the worklist BEFORE its users so they can
2213 // erased in the correct order.
2214 Worklist.push(I: LI);
2215
2216 Builder.SetInsertPoint(LI);
2217 auto *ScalarLoad =
2218 Builder.CreateLoad(Ty: TargetScalarType, Ptr, Name: LI->getName() + ".scalar");
2219 ScalarLoad->setAlignment(LI->getAlign());
2220 ScalarLoad->copyMetadata(SrcInst: *LI);
2221
2222 // Replace all bitcast users with the scalar load.
2223 for (User *U : LI->users()) {
2224 auto *BC = cast<BitCastInst>(Val: U);
2225 replaceValue(Old&: *BC, New&: *ScalarLoad, Erase: false);
2226 }
2227
2228 return true;
2229}
2230
2231bool VectorCombine::scalarizeExtExtract(Instruction &I) {
2232 if (!TTI.allowVectorElementIndexingUsingGEP())
2233 return false;
2234 auto *Ext = dyn_cast<ZExtInst>(Val: &I);
2235 if (!Ext)
2236 return false;
2237
2238 // Try to convert a vector zext feeding only extracts to a set of scalar
2239 // (Src << ExtIdx *Size) & (Size -1)
2240 // if profitable .
2241 auto *SrcTy = dyn_cast<FixedVectorType>(Val: Ext->getOperand(i_nocapture: 0)->getType());
2242 if (!SrcTy)
2243 return false;
2244 auto *DstTy = cast<FixedVectorType>(Val: Ext->getType());
2245
2246 Type *ScalarDstTy = DstTy->getElementType();
2247 if (DL->getTypeSizeInBits(Ty: SrcTy) != DL->getTypeSizeInBits(Ty: ScalarDstTy))
2248 return false;
2249
2250 InstructionCost VectorCost =
2251 TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: DstTy, Src: SrcTy,
2252 CCH: TTI::CastContextHint::None, CostKind, I: Ext);
2253 unsigned ExtCnt = 0;
2254 bool ExtLane0 = false;
2255 for (User *U : Ext->users()) {
2256 uint64_t Idx;
2257 if (!match(V: U, P: m_ExtractElt(Val: m_Value(), Idx: m_ConstantInt(V&: Idx))))
2258 return false;
2259 if (cast<Instruction>(Val: U)->use_empty())
2260 continue;
2261 ExtCnt += 1;
2262 ExtLane0 |= !Idx;
2263 VectorCost += TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: DstTy,
2264 CostKind, Index: Idx, Op0: U);
2265 }
2266
2267 InstructionCost ScalarCost =
2268 ExtCnt * TTI.getArithmeticInstrCost(
2269 Opcode: Instruction::And, Ty: ScalarDstTy, CostKind,
2270 Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
2271 Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None}) +
2272 (ExtCnt - ExtLane0) *
2273 TTI.getArithmeticInstrCost(
2274 Opcode: Instruction::LShr, Ty: ScalarDstTy, CostKind,
2275 Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
2276 Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None});
2277 if (ScalarCost > VectorCost)
2278 return false;
2279
2280 Value *ScalarV = Ext->getOperand(i_nocapture: 0);
2281 if (!isGuaranteedNotToBePoison(V: ScalarV, AC: SQ.AC, CtxI: dyn_cast<Instruction>(Val: ScalarV),
2282 DT: SQ.DT)) {
2283 // Check wether all lanes are extracted, all extracts trigger UB
2284 // on poison, and the last extract (and hence all previous ones)
2285 // are guaranteed to execute if Ext executes. If so, we do not
2286 // need to insert a freeze.
2287 SmallDenseSet<ConstantInt *, 8> ExtractedLanes;
2288 bool AllExtractsTriggerUB = true;
2289 ExtractElementInst *LastExtract = nullptr;
2290 BasicBlock *ExtBB = Ext->getParent();
2291 for (User *U : Ext->users()) {
2292 auto *Extract = cast<ExtractElementInst>(Val: U);
2293 if (Extract->getParent() != ExtBB || !programUndefinedIfPoison(Inst: Extract)) {
2294 AllExtractsTriggerUB = false;
2295 break;
2296 }
2297 ExtractedLanes.insert(V: cast<ConstantInt>(Val: Extract->getIndexOperand()));
2298 if (!LastExtract || LastExtract->comesBefore(Other: Extract))
2299 LastExtract = Extract;
2300 }
2301 if (ExtractedLanes.size() != DstTy->getNumElements() ||
2302 !AllExtractsTriggerUB ||
2303 !isGuaranteedToTransferExecutionToSuccessor(Begin: Ext->getIterator(),
2304 End: LastExtract->getIterator()))
2305 ScalarV = Builder.CreateFreeze(V: ScalarV);
2306 }
2307 ScalarV = Builder.CreateBitCast(
2308 V: ScalarV,
2309 DestTy: IntegerType::get(C&: SrcTy->getContext(), NumBits: DL->getTypeSizeInBits(Ty: SrcTy)));
2310 uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(Ty: SrcTy->getElementType());
2311 uint64_t TotalBits = DL->getTypeSizeInBits(Ty: SrcTy);
2312 APInt EltBitMask = APInt::getLowBitsSet(numBits: TotalBits, loBitsSet: SrcEltSizeInBits);
2313 Type *PackedTy = IntegerType::get(C&: SrcTy->getContext(), NumBits: TotalBits);
2314 Value *Mask = ConstantInt::get(Ty: PackedTy, V: EltBitMask);
2315 for (User *U : Ext->users()) {
2316 auto *Extract = cast<ExtractElementInst>(Val: U);
2317 uint64_t Idx =
2318 cast<ConstantInt>(Val: Extract->getIndexOperand())->getZExtValue();
2319 uint64_t ShiftAmt =
2320 DL->isBigEndian()
2321 ? (TotalBits - SrcEltSizeInBits - Idx * SrcEltSizeInBits)
2322 : (Idx * SrcEltSizeInBits);
2323 Value *LShr = Builder.CreateLShr(LHS: ScalarV, RHS: ShiftAmt);
2324 Value *And = Builder.CreateAnd(LHS: LShr, RHS: Mask);
2325 U->replaceAllUsesWith(V: And);
2326 }
2327 return true;
2328}
2329
2330/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
2331/// to "(bitcast (concat X, Y))"
2332/// where X/Y are bitcasted from i1 mask vectors.
2333bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
2334 Type *Ty = I.getType();
2335 if (!Ty->isIntegerTy())
2336 return false;
2337
2338 // TODO: Add big endian test coverage
2339 if (DL->isBigEndian())
2340 return false;
2341
2342 // Restrict to disjoint cases so the mask vectors aren't overlapping.
2343 Instruction *X, *Y;
2344 if (!match(V: &I, P: m_DisjointOr(L: m_Instruction(I&: X), R: m_Instruction(I&: Y))))
2345 return false;
2346
2347 // Allow both sources to contain shl, to handle more generic pattern:
2348 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
2349 Value *SrcX;
2350 uint64_t ShAmtX = 0;
2351 if (!match(V: X, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX)))))) &&
2352 !match(V: X, P: m_OneUse(
2353 SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX))))),
2354 R: m_ConstantInt(V&: ShAmtX)))))
2355 return false;
2356
2357 Value *SrcY;
2358 uint64_t ShAmtY = 0;
2359 if (!match(V: Y, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY)))))) &&
2360 !match(V: Y, P: m_OneUse(
2361 SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY))))),
2362 R: m_ConstantInt(V&: ShAmtY)))))
2363 return false;
2364
2365 // Canonicalize larger shift to the RHS.
2366 if (ShAmtX > ShAmtY) {
2367 std::swap(a&: X, b&: Y);
2368 std::swap(a&: SrcX, b&: SrcY);
2369 std::swap(a&: ShAmtX, b&: ShAmtY);
2370 }
2371
2372 // Ensure both sources are matching vXi1 bool mask types, and that the shift
2373 // difference is the mask width so they can be easily concatenated together.
2374 uint64_t ShAmtDiff = ShAmtY - ShAmtX;
2375 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
2376 unsigned BitWidth = Ty->getPrimitiveSizeInBits();
2377 auto *MaskTy = dyn_cast<FixedVectorType>(Val: SrcX->getType());
2378 if (!MaskTy || SrcX->getType() != SrcY->getType() ||
2379 !MaskTy->getElementType()->isIntegerTy(BitWidth: 1) ||
2380 MaskTy->getNumElements() != ShAmtDiff ||
2381 MaskTy->getNumElements() > (BitWidth / 2))
2382 return false;
2383
2384 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(VTy: MaskTy);
2385 auto *ConcatIntTy =
2386 Type::getIntNTy(C&: Ty->getContext(), N: ConcatTy->getNumElements());
2387 auto *MaskIntTy = Type::getIntNTy(C&: Ty->getContext(), N: ShAmtDiff);
2388
2389 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
2390 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
2391
2392 // TODO: Is it worth supporting multi use cases?
2393 InstructionCost OldCost = 0;
2394 OldCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Or, Ty, CostKind);
2395 OldCost +=
2396 NumSHL * TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
2397 OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: MaskIntTy,
2398 CCH: TTI::CastContextHint::None, CostKind);
2399 OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: MaskIntTy, Src: MaskTy,
2400 CCH: TTI::CastContextHint::None, CostKind);
2401
2402 InstructionCost NewCost = 0;
2403 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ConcatTy,
2404 SrcTy: MaskTy, Mask: ConcatMask, CostKind);
2405 NewCost += TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: ConcatIntTy, Src: ConcatTy,
2406 CCH: TTI::CastContextHint::None, CostKind);
2407 if (Ty != ConcatIntTy)
2408 NewCost += TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: ConcatIntTy,
2409 CCH: TTI::CastContextHint::None, CostKind);
2410 if (ShAmtX > 0)
2411 NewCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
2412
2413 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
2414 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2415 << "\n");
2416
2417 if (NewCost > OldCost)
2418 return false;
2419
2420 // Build bool mask concatenation, bitcast back to scalar integer, and perform
2421 // any residual zero-extension or shifting.
2422 Value *Concat = Builder.CreateShuffleVector(V1: SrcX, V2: SrcY, Mask: ConcatMask);
2423 Worklist.pushValue(V: Concat);
2424
2425 Value *Result = Builder.CreateBitCast(V: Concat, DestTy: ConcatIntTy);
2426
2427 if (Ty != ConcatIntTy) {
2428 Worklist.pushValue(V: Result);
2429 Result = Builder.CreateZExt(V: Result, DestTy: Ty);
2430 }
2431
2432 if (ShAmtX > 0) {
2433 Worklist.pushValue(V: Result);
2434 Result = Builder.CreateShl(LHS: Result, RHS: ShAmtX);
2435 }
2436
2437 replaceValue(Old&: I, New&: *Result);
2438 return true;
2439}
2440
2441/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
2442/// --> "binop (shuffle), (shuffle)".
2443bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
2444 BinaryOperator *BinOp;
2445 ArrayRef<int> OuterMask;
2446 if (!match(V: &I, P: m_Shuffle(v1: m_BinOp(I&: BinOp), v2: m_Undef(), mask: m_Mask(OuterMask))))
2447 return false;
2448
2449 // Don't introduce poison into div/rem.
2450 if (BinOp->isIntDivRem() && llvm::is_contained(Range&: OuterMask, Element: PoisonMaskElem))
2451 return false;
2452
2453 Value *Op00, *Op01, *Op10, *Op11;
2454 ArrayRef<int> Mask0, Mask1;
2455 bool Match0 = match(V: BinOp->getOperand(i_nocapture: 0),
2456 P: m_Shuffle(v1: m_Value(V&: Op00), v2: m_Value(V&: Op01), mask: m_Mask(Mask0)));
2457 bool Match1 = match(V: BinOp->getOperand(i_nocapture: 1),
2458 P: m_Shuffle(v1: m_Value(V&: Op10), v2: m_Value(V&: Op11), mask: m_Mask(Mask1)));
2459 if (!Match0 && !Match1)
2460 return false;
2461
2462 Op00 = Match0 ? Op00 : BinOp->getOperand(i_nocapture: 0);
2463 Op01 = Match0 ? Op01 : BinOp->getOperand(i_nocapture: 0);
2464 Op10 = Match1 ? Op10 : BinOp->getOperand(i_nocapture: 1);
2465 Op11 = Match1 ? Op11 : BinOp->getOperand(i_nocapture: 1);
2466
2467 Instruction::BinaryOps Opcode = BinOp->getOpcode();
2468 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2469 auto *BinOpTy = dyn_cast<FixedVectorType>(Val: BinOp->getType());
2470 auto *Op0Ty = dyn_cast<FixedVectorType>(Val: Op00->getType());
2471 auto *Op1Ty = dyn_cast<FixedVectorType>(Val: Op10->getType());
2472 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
2473 return false;
2474
2475 unsigned NumSrcElts = BinOpTy->getNumElements();
2476
2477 // Don't accept shuffles that reference the second operand in
2478 // div/rem or if its an undef arg.
2479 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(Val: I.getOperand(i: 1))) &&
2480 any_of(Range&: OuterMask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
2481 return false;
2482
2483 // Merge outer / inner (or identity if no match) shuffles.
2484 SmallVector<int> NewMask0, NewMask1;
2485 for (int M : OuterMask) {
2486 if (M < 0 || M >= (int)NumSrcElts) {
2487 NewMask0.push_back(Elt: PoisonMaskElem);
2488 NewMask1.push_back(Elt: PoisonMaskElem);
2489 } else {
2490 NewMask0.push_back(Elt: Match0 ? Mask0[M] : M);
2491 NewMask1.push_back(Elt: Match1 ? Mask1[M] : M);
2492 }
2493 }
2494
2495 unsigned NumOpElts = Op0Ty->getNumElements();
2496 bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
2497 all_of(Range&: NewMask0, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2498 ShuffleVectorInst::isIdentityMask(Mask: NewMask0, NumSrcElts: NumOpElts);
2499 bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
2500 all_of(Range&: NewMask1, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2501 ShuffleVectorInst::isIdentityMask(Mask: NewMask1, NumSrcElts: NumOpElts);
2502
2503 InstructionCost NewCost = 0;
2504 // Try to merge shuffles across the binop if the new shuffles are not costly.
2505 InstructionCost BinOpCost =
2506 TTI.getArithmeticInstrCost(Opcode, Ty: BinOpTy, CostKind);
2507 InstructionCost OldCost =
2508 BinOpCost + TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
2509 DstTy: ShuffleDstTy, SrcTy: BinOpTy, Mask: OuterMask, CostKind,
2510 Index: 0, SubTp: nullptr, Args: {BinOp}, CxtI: &I);
2511 if (!BinOp->hasOneUse())
2512 NewCost += BinOpCost;
2513
2514 if (Match0) {
2515 InstructionCost Shuf0Cost = TTI.getShuffleCost(
2516 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op0Ty, Mask: Mask0, CostKind,
2517 Index: 0, SubTp: nullptr, Args: {Op00, Op01}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 0)));
2518 OldCost += Shuf0Cost;
2519 if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 0)->hasOneUse())
2520 NewCost += Shuf0Cost;
2521 }
2522 if (Match1) {
2523 InstructionCost Shuf1Cost = TTI.getShuffleCost(
2524 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op1Ty, Mask: Mask1, CostKind,
2525 Index: 0, SubTp: nullptr, Args: {Op10, Op11}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 1)));
2526 OldCost += Shuf1Cost;
2527 if (!BinOp->hasOneUse() || !BinOp->getOperand(i_nocapture: 1)->hasOneUse())
2528 NewCost += Shuf1Cost;
2529 }
2530
2531 NewCost += TTI.getArithmeticInstrCost(Opcode, Ty: ShuffleDstTy, CostKind);
2532
2533 if (!IsIdentity0)
2534 NewCost +=
2535 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2536 SrcTy: Op0Ty, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {Op00, Op01});
2537 if (!IsIdentity1)
2538 NewCost +=
2539 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2540 SrcTy: Op1Ty, Mask: NewMask1, CostKind, Index: 0, SubTp: nullptr, Args: {Op10, Op11});
2541
2542 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
2543 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2544 << "\n");
2545
2546 // If costs are equal, still fold as we reduce instruction count.
2547 if (NewCost > OldCost)
2548 return false;
2549
2550 Value *LHS =
2551 IsIdentity0 ? Op00 : Builder.CreateShuffleVector(V1: Op00, V2: Op01, Mask: NewMask0);
2552 Value *RHS =
2553 IsIdentity1 ? Op10 : Builder.CreateShuffleVector(V1: Op10, V2: Op11, Mask: NewMask1);
2554 Value *NewBO = Builder.CreateBinOp(Opc: Opcode, LHS, RHS);
2555
2556 // Intersect flags from the old binops.
2557 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO))
2558 NewInst->copyIRFlags(V: BinOp);
2559
2560 Worklist.pushValue(V: LHS);
2561 Worklist.pushValue(V: RHS);
2562 replaceValue(Old&: I, New&: *NewBO);
2563 return true;
2564}
2565
2566/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
2567/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
2568bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
2569 ArrayRef<int> OldMask;
2570 Instruction *LHS, *RHS;
2571 if (!match(V: &I, P: m_Shuffle(v1: m_Instruction(I&: LHS), v2: m_Instruction(I&: RHS),
2572 mask: m_Mask(OldMask))))
2573 return false;
2574
2575 // TODO: Add support for addlike etc.
2576 if (LHS->getOpcode() != RHS->getOpcode())
2577 return false;
2578
2579 Value *X, *Y, *Z, *W;
2580 bool IsCommutative = false;
2581 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
2582 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
2583 if (match(V: LHS, P: m_BinOp(L: m_Value(V&: X), R: m_Value(V&: Y))) &&
2584 match(V: RHS, P: m_BinOp(L: m_Value(V&: Z), R: m_Value(V&: W)))) {
2585 auto *BO = cast<BinaryOperator>(Val: LHS);
2586 // Don't introduce poison into div/rem.
2587 if (llvm::is_contained(Range&: OldMask, Element: PoisonMaskElem) && BO->isIntDivRem())
2588 return false;
2589 IsCommutative = BinaryOperator::isCommutative(Opcode: BO->getOpcode());
2590 } else if (match(V: LHS, P: m_Cmp(Pred&: PredLHS, L: m_Value(V&: X), R: m_Value(V&: Y))) &&
2591 match(V: RHS, P: m_Cmp(Pred&: PredRHS, L: m_Value(V&: Z), R: m_Value(V&: W))) &&
2592 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
2593 IsCommutative = cast<CmpInst>(Val: LHS)->isCommutative();
2594 } else
2595 return false;
2596
2597 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2598 auto *BinResTy = dyn_cast<FixedVectorType>(Val: LHS->getType());
2599 auto *BinOpTy = dyn_cast<FixedVectorType>(Val: X->getType());
2600 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
2601 return false;
2602
2603 bool SameBinOp = LHS == RHS;
2604 unsigned NumSrcElts = BinOpTy->getNumElements();
2605
2606 // If we have something like "add X, Y" and "add Z, X", swap ops to match.
2607 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
2608 std::swap(a&: X, b&: Y);
2609
2610 auto ConvertToUnary = [NumSrcElts](int &M) {
2611 if (M >= (int)NumSrcElts)
2612 M -= NumSrcElts;
2613 };
2614
2615 SmallVector<int> NewMask0(OldMask);
2616 TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
2617 TTI::OperandValueInfo Op0Info = TTI.commonOperandInfo(X, Y: Z);
2618 if (X == Z) {
2619 llvm::for_each(Range&: NewMask0, F: ConvertToUnary);
2620 SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
2621 Z = PoisonValue::get(T: BinOpTy);
2622 }
2623
2624 SmallVector<int> NewMask1(OldMask);
2625 TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
2626 TTI::OperandValueInfo Op1Info = TTI.commonOperandInfo(X: Y, Y: W);
2627 if (Y == W) {
2628 llvm::for_each(Range&: NewMask1, F: ConvertToUnary);
2629 SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
2630 W = PoisonValue::get(T: BinOpTy);
2631 }
2632
2633 // Try to replace a binop with a shuffle if the shuffle is not costly.
2634 // When SameBinOp, only count the binop cost once.
2635 InstructionCost LHSCost = TTI.getInstructionCost(U: LHS, CostKind);
2636 InstructionCost RHSCost = TTI.getInstructionCost(U: RHS, CostKind);
2637
2638 InstructionCost OldCost = LHSCost;
2639 if (!SameBinOp) {
2640 OldCost += RHSCost;
2641 }
2642 OldCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc,
2643 DstTy: ShuffleDstTy, SrcTy: BinResTy, Mask: OldMask, CostKind, Index: 0,
2644 SubTp: nullptr, Args: {LHS, RHS}, CxtI: &I);
2645
2646 // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
2647 // where one use shuffles have gotten split across the binop/cmp. These
2648 // often allow a major reduction in total cost that wouldn't happen as
2649 // individual folds.
2650 auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
2651 TTI::TargetCostKind CostKind) -> bool {
2652 Value *InnerOp;
2653 ArrayRef<int> InnerMask;
2654 if (match(V: Op, P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: InnerOp), v2: m_Undef(),
2655 mask: m_Mask(InnerMask)))) &&
2656 InnerOp->getType() == Op->getType() &&
2657 all_of(Range&: InnerMask,
2658 P: [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
2659 for (int &M : Mask)
2660 if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
2661 M = InnerMask[M - Offset];
2662 M = 0 <= M ? M + Offset : M;
2663 }
2664 OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: Op), CostKind);
2665 Op = InnerOp;
2666 return true;
2667 }
2668 return false;
2669 };
2670 bool ReducedInstCount = false;
2671 ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
2672 ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
2673 ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
2674 ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
2675 bool SingleSrcBinOp = (X == Y) && (Z == W) && (NewMask0 == NewMask1);
2676 // SingleSrcBinOp only reduces instruction count if we also eliminate the
2677 // original binop(s). If binops have multiple uses, they won't be eliminated.
2678 ReducedInstCount |= SingleSrcBinOp && LHS->hasOneUser() && RHS->hasOneUser();
2679
2680 auto *ShuffleCmpTy =
2681 FixedVectorType::get(ElementType: BinOpTy->getElementType(), FVTy: ShuffleDstTy);
2682 InstructionCost NewCost = TTI.getShuffleCost(
2683 Kind: SK0, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {X, Z});
2684 if (!SingleSrcBinOp)
2685 NewCost += TTI.getShuffleCost(Kind: SK1, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask1,
2686 CostKind, Index: 0, SubTp: nullptr, Args: {Y, W});
2687
2688 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
2689 NewCost += TTI.getArithmeticInstrCost(Opcode: LHS->getOpcode(), Ty: ShuffleDstTy,
2690 CostKind, Opd1Info: Op0Info, Opd2Info: Op1Info);
2691 } else {
2692 NewCost +=
2693 TTI.getCmpSelInstrCost(Opcode: LHS->getOpcode(), ValTy: ShuffleCmpTy, CondTy: ShuffleDstTy,
2694 VecPred: PredLHS, CostKind, Op1Info: Op0Info, Op2Info: Op1Info);
2695 }
2696 // If LHS/RHS have other uses, we need to account for the cost of keeping
2697 // the original instructions. When SameBinOp, only add the cost once.
2698 if (!LHS->hasOneUser())
2699 NewCost += LHSCost;
2700 if (!SameBinOp && !RHS->hasOneUser())
2701 NewCost += RHSCost;
2702
2703 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
2704 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2705 << "\n");
2706
2707 // If either shuffle will constant fold away, then fold for the same cost as
2708 // we will reduce the instruction count.
2709 ReducedInstCount |= (isa<Constant>(Val: X) && isa<Constant>(Val: Z)) ||
2710 (isa<Constant>(Val: Y) && isa<Constant>(Val: W));
2711 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
2712 return false;
2713
2714 Value *Shuf0 = Builder.CreateShuffleVector(V1: X, V2: Z, Mask: NewMask0);
2715 Value *Shuf1 =
2716 SingleSrcBinOp ? Shuf0 : Builder.CreateShuffleVector(V1: Y, V2: W, Mask: NewMask1);
2717 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
2718 ? Builder.CreateBinOp(
2719 Opc: cast<BinaryOperator>(Val: LHS)->getOpcode(), LHS: Shuf0, RHS: Shuf1)
2720 : Builder.CreateCmp(Pred: PredLHS, LHS: Shuf0, RHS: Shuf1);
2721
2722 // Intersect flags from the old binops.
2723 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
2724 NewInst->copyIRFlags(V: LHS);
2725 NewInst->andIRFlags(V: RHS);
2726 }
2727
2728 Worklist.pushValue(V: Shuf0);
2729 Worklist.pushValue(V: Shuf1);
2730 replaceValue(Old&: I, New&: *NewBO);
2731 return true;
2732}
2733
2734/// Try to convert,
2735/// (shuffle(select(c1,t1,f1)), (select(c2,t2,f2)), m) into
2736/// (select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m))
2737bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
2738 ArrayRef<int> Mask;
2739 Value *C1, *T1, *F1, *C2, *T2, *F2;
2740 if (!match(V: &I, P: m_Shuffle(v1: m_Select(C: m_Value(V&: C1), L: m_Value(V&: T1), R: m_Value(V&: F1)),
2741 v2: m_Select(C: m_Value(V&: C2), L: m_Value(V&: T2), R: m_Value(V&: F2)),
2742 mask: m_Mask(Mask))))
2743 return false;
2744
2745 auto *Sel1 = cast<Instruction>(Val: I.getOperand(i: 0));
2746 auto *Sel2 = cast<Instruction>(Val: I.getOperand(i: 1));
2747
2748 auto *C1VecTy = dyn_cast<FixedVectorType>(Val: C1->getType());
2749 auto *C2VecTy = dyn_cast<FixedVectorType>(Val: C2->getType());
2750 if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
2751 return false;
2752
2753 auto *SI0FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 0));
2754 auto *SI1FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 1));
2755 // SelectInsts must have the same FMF.
2756 if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
2757 ((SI0FOp != nullptr) &&
2758 (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
2759 return false;
2760
2761 auto *SrcVecTy = cast<FixedVectorType>(Val: T1->getType());
2762 auto *DstVecTy = cast<FixedVectorType>(Val: I.getType());
2763 auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
2764 auto SelOp = Instruction::Select;
2765
2766 InstructionCost CostSel1 = TTI.getCmpSelInstrCost(
2767 Opcode: SelOp, ValTy: SrcVecTy, CondTy: C1VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2768 InstructionCost CostSel2 = TTI.getCmpSelInstrCost(
2769 Opcode: SelOp, ValTy: SrcVecTy, CondTy: C2VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2770
2771 InstructionCost OldCost =
2772 CostSel1 + CostSel2 +
2773 TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
2774 Args: {I.getOperand(i: 0), I.getOperand(i: 1)}, CxtI: &I);
2775
2776 InstructionCost NewCost = TTI.getShuffleCost(
2777 Kind: SK, DstTy: FixedVectorType::get(ElementType: C1VecTy->getScalarType(), NumElts: Mask.size()), SrcTy: C1VecTy,
2778 Mask, CostKind, Index: 0, SubTp: nullptr, Args: {C1, C2});
2779 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2780 SubTp: nullptr, Args: {T1, T2});
2781 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2782 SubTp: nullptr, Args: {F1, F2});
2783 auto *C1C2ShuffledVecTy = FixedVectorType::get(
2784 ElementType: Type::getInt1Ty(C&: I.getContext()), NumElts: DstVecTy->getNumElements());
2785 NewCost += TTI.getCmpSelInstrCost(Opcode: SelOp, ValTy: DstVecTy, CondTy: C1C2ShuffledVecTy,
2786 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2787
2788 if (!Sel1->hasOneUse())
2789 NewCost += CostSel1;
2790 if (!Sel2->hasOneUse())
2791 NewCost += CostSel2;
2792
2793 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
2794 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2795 << "\n");
2796 if (NewCost > OldCost)
2797 return false;
2798
2799 Value *ShuffleCmp = Builder.CreateShuffleVector(V1: C1, V2: C2, Mask);
2800 Value *ShuffleTrue = Builder.CreateShuffleVector(V1: T1, V2: T2, Mask);
2801 Value *ShuffleFalse = Builder.CreateShuffleVector(V1: F1, V2: F2, Mask);
2802 Value *NewSel;
2803 // We presuppose that the SelectInsts have the same FMF.
2804 if (SI0FOp)
2805 NewSel = Builder.CreateSelectFMF(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse,
2806 FMFSource: SI0FOp->getFastMathFlags());
2807 else
2808 NewSel = Builder.CreateSelect(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse);
2809
2810 Worklist.pushValue(V: ShuffleCmp);
2811 Worklist.pushValue(V: ShuffleTrue);
2812 Worklist.pushValue(V: ShuffleFalse);
2813 replaceValue(Old&: I, New&: *NewSel);
2814 return true;
2815}
2816
2817/// Try to convert "shuffle (castop), (castop)" with a shared castop operand
2818/// into "castop (shuffle)".
2819bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2820 Value *V0, *V1;
2821 ArrayRef<int> OldMask;
2822 if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
2823 return false;
2824
2825 // Check whether this is a binary shuffle.
2826 bool IsBinaryShuffle = !isa<UndefValue>(Val: V1);
2827
2828 auto *C0 = dyn_cast<CastInst>(Val: V0);
2829 auto *C1 = dyn_cast<CastInst>(Val: V1);
2830 if (!C0 || (IsBinaryShuffle && !C1))
2831 return false;
2832
2833 Instruction::CastOps Opcode = C0->getOpcode();
2834
2835 // If this is allowed, foldShuffleOfCastops can get stuck in a loop
2836 // with foldBitcastOfShuffle. Reject in favor of foldBitcastOfShuffle.
2837 if (!IsBinaryShuffle && Opcode == Instruction::BitCast)
2838 return false;
2839
2840 if (IsBinaryShuffle) {
2841 if (C0->getSrcTy() != C1->getSrcTy())
2842 return false;
2843 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2844 if (Opcode != C1->getOpcode()) {
2845 if (match(V: C0, P: m_SExtLike(Op: m_Value())) && match(V: C1, P: m_SExtLike(Op: m_Value())))
2846 Opcode = Instruction::SExt;
2847 else
2848 return false;
2849 }
2850 }
2851
2852 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2853 auto *CastDstTy = dyn_cast<FixedVectorType>(Val: C0->getDestTy());
2854 auto *CastSrcTy = dyn_cast<FixedVectorType>(Val: C0->getSrcTy());
2855 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
2856 return false;
2857
2858 unsigned NumSrcElts = CastSrcTy->getNumElements();
2859 unsigned NumDstElts = CastDstTy->getNumElements();
2860 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
2861 "Only bitcasts expected to alter src/dst element counts");
2862
2863 // Check for bitcasting of unscalable vector types.
2864 // e.g. <32 x i40> -> <40 x i32>
2865 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
2866 (NumDstElts % NumSrcElts) != 0)
2867 return false;
2868
2869 SmallVector<int, 16> NewMask;
2870 if (NumSrcElts >= NumDstElts) {
2871 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
2872 // always be expanded to the equivalent form choosing narrower elements.
2873 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
2874 unsigned ScaleFactor = NumSrcElts / NumDstElts;
2875 narrowShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask);
2876 } else {
2877 // The bitcast is from narrow elements to wide elements. The shuffle mask
2878 // must choose consecutive elements to allow casting first.
2879 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
2880 unsigned ScaleFactor = NumDstElts / NumSrcElts;
2881 if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask))
2882 return false;
2883 }
2884
2885 auto *NewShuffleDstTy =
2886 FixedVectorType::get(ElementType: CastSrcTy->getScalarType(), NumElts: NewMask.size());
2887
2888 // Try to replace a castop with a shuffle if the shuffle is not costly.
2889 InstructionCost CostC0 =
2890 TTI.getCastInstrCost(Opcode: C0->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
2891 CCH: TTI::CastContextHint::None, CostKind, I: C0);
2892
2893 TargetTransformInfo::ShuffleKind ShuffleKind;
2894 if (IsBinaryShuffle)
2895 ShuffleKind = TargetTransformInfo::SK_PermuteTwoSrc;
2896 else
2897 ShuffleKind = TargetTransformInfo::SK_PermuteSingleSrc;
2898
2899 InstructionCost OldCost = CostC0;
2900 OldCost += TTI.getShuffleCost(Kind: ShuffleKind, DstTy: ShuffleDstTy, SrcTy: CastDstTy, Mask: OldMask,
2901 CostKind, Index: 0, SubTp: nullptr, Args: {}, CxtI: &I);
2902
2903 InstructionCost NewCost = TTI.getShuffleCost(Kind: ShuffleKind, DstTy: NewShuffleDstTy,
2904 SrcTy: CastSrcTy, Mask: NewMask, CostKind);
2905 NewCost += TTI.getCastInstrCost(Opcode, Dst: ShuffleDstTy, Src: NewShuffleDstTy,
2906 CCH: TTI::CastContextHint::None, CostKind);
2907 if (!C0->hasOneUse())
2908 NewCost += CostC0;
2909 if (IsBinaryShuffle) {
2910 InstructionCost CostC1 =
2911 TTI.getCastInstrCost(Opcode: C1->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
2912 CCH: TTI::CastContextHint::None, CostKind, I: C1);
2913 OldCost += CostC1;
2914 if (!C1->hasOneUse())
2915 NewCost += CostC1;
2916 }
2917
2918 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
2919 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2920 << "\n");
2921 if (NewCost > OldCost)
2922 return false;
2923
2924 Value *Shuf;
2925 if (IsBinaryShuffle)
2926 Shuf = Builder.CreateShuffleVector(V1: C0->getOperand(i_nocapture: 0), V2: C1->getOperand(i_nocapture: 0),
2927 Mask: NewMask);
2928 else
2929 Shuf = Builder.CreateShuffleVector(V: C0->getOperand(i_nocapture: 0), Mask: NewMask);
2930
2931 Value *Cast = Builder.CreateCast(Op: Opcode, V: Shuf, DestTy: ShuffleDstTy);
2932
2933 // Intersect flags from the old casts.
2934 if (auto *NewInst = dyn_cast<Instruction>(Val: Cast)) {
2935 NewInst->copyIRFlags(V: C0);
2936 if (IsBinaryShuffle)
2937 NewInst->andIRFlags(V: C1);
2938 }
2939
2940 Worklist.pushValue(V: Shuf);
2941 replaceValue(Old&: I, New&: *Cast);
2942 return true;
2943}
2944
2945/// Try to convert any of:
2946/// "shuffle (shuffle x, y), (shuffle y, x)"
2947/// "shuffle (shuffle x, undef), (shuffle y, undef)"
2948/// "shuffle (shuffle x, undef), y"
2949/// "shuffle x, (shuffle y, undef)"
2950/// into "shuffle x, y".
2951bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2952 ArrayRef<int> OuterMask;
2953 Value *OuterV0, *OuterV1;
2954 if (!match(V: &I,
2955 P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1), mask: m_Mask(OuterMask))))
2956 return false;
2957
2958 ArrayRef<int> InnerMask0, InnerMask1;
2959 Value *X0, *X1, *Y0, *Y1;
2960 bool Match0 =
2961 match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: X0), v2: m_Value(V&: Y0), mask: m_Mask(InnerMask0)));
2962 bool Match1 =
2963 match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: X1), v2: m_Value(V&: Y1), mask: m_Mask(InnerMask1)));
2964 if (!Match0 && !Match1)
2965 return false;
2966
2967 // If the outer shuffle is a permute, then create a fake inner all-poison
2968 // shuffle. This is easier than accounting for length-changing shuffles below.
2969 SmallVector<int, 16> PoisonMask1;
2970 if (!Match1 && isa<PoisonValue>(Val: OuterV1)) {
2971 X1 = X0;
2972 Y1 = Y0;
2973 PoisonMask1.append(NumInputs: InnerMask0.size(), Elt: PoisonMaskElem);
2974 InnerMask1 = PoisonMask1;
2975 Match1 = true; // fake match
2976 }
2977
2978 X0 = Match0 ? X0 : OuterV0;
2979 Y0 = Match0 ? Y0 : OuterV0;
2980 X1 = Match1 ? X1 : OuterV1;
2981 Y1 = Match1 ? Y1 : OuterV1;
2982 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2983 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(Val: X0->getType());
2984 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(Val: OuterV0->getType());
2985 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2986 X0->getType() != X1->getType())
2987 return false;
2988
2989 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2990 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2991
2992 // Attempt to merge shuffles, matching upto 2 source operands.
2993 // Replace index to a poison arg with PoisonMaskElem.
2994 // Bail if either inner masks reference an undef arg.
2995 SmallVector<int, 16> NewMask(OuterMask);
2996 Value *NewX = nullptr, *NewY = nullptr;
2997 for (int &M : NewMask) {
2998 Value *Src = nullptr;
2999 if (0 <= M && M < (int)NumImmElts) {
3000 Src = OuterV0;
3001 if (Match0) {
3002 M = InnerMask0[M];
3003 Src = M >= (int)NumSrcElts ? Y0 : X0;
3004 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
3005 }
3006 } else if (M >= (int)NumImmElts) {
3007 Src = OuterV1;
3008 M -= NumImmElts;
3009 if (Match1) {
3010 M = InnerMask1[M];
3011 Src = M >= (int)NumSrcElts ? Y1 : X1;
3012 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
3013 }
3014 }
3015 if (Src && M != PoisonMaskElem) {
3016 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
3017 if (isa<UndefValue>(Val: Src)) {
3018 // We've referenced an undef element - if its poison, update the shuffle
3019 // mask, else bail.
3020 if (!isa<PoisonValue>(Val: Src))
3021 return false;
3022 M = PoisonMaskElem;
3023 continue;
3024 }
3025 if (!NewX || NewX == Src) {
3026 NewX = Src;
3027 continue;
3028 }
3029 if (!NewY || NewY == Src) {
3030 M += NumSrcElts;
3031 NewY = Src;
3032 continue;
3033 }
3034 return false;
3035 }
3036 }
3037
3038 if (!NewX) {
3039 replaceValue(Old&: I, New&: *PoisonValue::get(T: ShuffleDstTy));
3040 return true;
3041 }
3042
3043 if (!NewY)
3044 NewY = PoisonValue::get(T: ShuffleSrcTy);
3045
3046 // Have we folded to an Identity shuffle?
3047 if (ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts)) {
3048 replaceValue(Old&: I, New&: *NewX);
3049 return true;
3050 }
3051
3052 // Try to merge the shuffles if the new shuffle is not costly.
3053 InstructionCost InnerCost0 = 0;
3054 if (Match0)
3055 InnerCost0 = TTI.getInstructionCost(U: cast<User>(Val: OuterV0), CostKind);
3056
3057 InstructionCost InnerCost1 = 0;
3058 if (Match1)
3059 InnerCost1 = TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);
3060
3061 InstructionCost OuterCost = TTI.getInstructionCost(U: &I, CostKind);
3062
3063 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
3064
3065 bool IsUnary = all_of(Range&: NewMask, P: [&](int M) { return M < (int)NumSrcElts; });
3066 TargetTransformInfo::ShuffleKind SK =
3067 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
3068 : TargetTransformInfo::SK_PermuteTwoSrc;
3069 InstructionCost NewCost =
3070 TTI.getShuffleCost(Kind: SK, DstTy: ShuffleDstTy, SrcTy: ShuffleSrcTy, Mask: NewMask, CostKind, Index: 0,
3071 SubTp: nullptr, Args: {NewX, NewY});
3072 if (!OuterV0->hasOneUse())
3073 NewCost += InnerCost0;
3074 if (!OuterV1->hasOneUse())
3075 NewCost += InnerCost1;
3076
3077 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
3078 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3079 << "\n");
3080 if (NewCost > OldCost)
3081 return false;
3082
3083 Value *Shuf = Builder.CreateShuffleVector(V1: NewX, V2: NewY, Mask: NewMask);
3084 replaceValue(Old&: I, New&: *Shuf);
3085 return true;
3086}
3087
3088/// Try to convert a chain of length-preserving shuffles that are fed by
3089/// length-changing shuffles from the same source, e.g. a chain of length 3:
3090///
3091/// "shuffle (shuffle (shuffle x, (shuffle y, undef)),
3092/// (shuffle y, undef)),
3093// (shuffle y, undef)"
3094///
3095/// into a single shuffle fed by a length-changing shuffle:
3096///
3097/// "shuffle x, (shuffle y, undef)"
3098///
3099/// Such chains arise e.g. from folding extract/insert sequences.
3100bool VectorCombine::foldShufflesOfLengthChangingShuffles(Instruction &I) {
3101 FixedVectorType *TrunkType = dyn_cast<FixedVectorType>(Val: I.getType());
3102 if (!TrunkType)
3103 return false;
3104
3105 unsigned ChainLength = 0;
3106 SmallVector<int> Mask;
3107 SmallVector<int> YMask;
3108 InstructionCost OldCost = 0;
3109 InstructionCost NewCost = 0;
3110 Value *Trunk = &I;
3111 unsigned NumTrunkElts = TrunkType->getNumElements();
3112 Value *Y = nullptr;
3113
3114 for (;;) {
3115 // Match the current trunk against (commutations of) the pattern
3116 // "shuffle trunk', (shuffle y, undef)"
3117 ArrayRef<int> OuterMask;
3118 Value *OuterV0, *OuterV1;
3119 if (ChainLength != 0 && !Trunk->hasOneUse())
3120 break;
3121 if (!match(V: Trunk, P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1),
3122 mask: m_Mask(OuterMask))))
3123 break;
3124 if (OuterV0->getType() != TrunkType) {
3125 // This shuffle is not length-preserving, so it cannot be part of the
3126 // chain.
3127 break;
3128 }
3129
3130 ArrayRef<int> InnerMask0, InnerMask1;
3131 Value *A0, *A1, *B0, *B1;
3132 bool Match0 =
3133 match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: A0), v2: m_Value(V&: B0), mask: m_Mask(InnerMask0)));
3134 bool Match1 =
3135 match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: A1), v2: m_Value(V&: B1), mask: m_Mask(InnerMask1)));
3136 bool Match0Leaf = Match0 && A0->getType() != I.getType();
3137 bool Match1Leaf = Match1 && A1->getType() != I.getType();
3138 if (Match0Leaf == Match1Leaf) {
3139 // Only handle the case of exactly one leaf in each step. The "two leaves"
3140 // case is handled by foldShuffleOfShuffles.
3141 break;
3142 }
3143
3144 SmallVector<int> CommutedOuterMask;
3145 if (Match0Leaf) {
3146 std::swap(a&: OuterV0, b&: OuterV1);
3147 std::swap(a&: InnerMask0, b&: InnerMask1);
3148 std::swap(a&: A0, b&: A1);
3149 std::swap(a&: B0, b&: B1);
3150 llvm::append_range(C&: CommutedOuterMask, R&: OuterMask);
3151 for (int &M : CommutedOuterMask) {
3152 if (M == PoisonMaskElem)
3153 continue;
3154 if (M < (int)NumTrunkElts)
3155 M += NumTrunkElts;
3156 else
3157 M -= NumTrunkElts;
3158 }
3159 OuterMask = CommutedOuterMask;
3160 }
3161 if (!OuterV1->hasOneUse())
3162 break;
3163
3164 if (!isa<UndefValue>(Val: A1)) {
3165 if (!Y)
3166 Y = A1;
3167 else if (Y != A1)
3168 break;
3169 }
3170 if (!isa<UndefValue>(Val: B1)) {
3171 if (!Y)
3172 Y = B1;
3173 else if (Y != B1)
3174 break;
3175 }
3176
3177 auto *YType = cast<FixedVectorType>(Val: A1->getType());
3178 int NumLeafElts = YType->getNumElements();
3179 SmallVector<int> LocalYMask(InnerMask1);
3180 for (int &M : LocalYMask) {
3181 if (M >= NumLeafElts)
3182 M -= NumLeafElts;
3183 }
3184
3185 InstructionCost LocalOldCost =
3186 TTI.getInstructionCost(U: cast<User>(Val: Trunk), CostKind) +
3187 TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);
3188
3189 // Handle the initial (start of chain) case.
3190 if (!ChainLength) {
3191 Mask.assign(AR: OuterMask);
3192 YMask.assign(RHS: LocalYMask);
3193 OldCost = NewCost = LocalOldCost;
3194 Trunk = OuterV0;
3195 ChainLength++;
3196 continue;
3197 }
3198
3199 // For the non-root case, first attempt to combine masks.
3200 SmallVector<int> NewYMask(YMask);
3201 bool Valid = true;
3202 for (auto [CombinedM, LeafM] : llvm::zip(t&: NewYMask, u&: LocalYMask)) {
3203 if (LeafM == -1 || CombinedM == LeafM)
3204 continue;
3205 if (CombinedM == -1) {
3206 CombinedM = LeafM;
3207 } else {
3208 Valid = false;
3209 break;
3210 }
3211 }
3212 if (!Valid)
3213 break;
3214
3215 SmallVector<int> NewMask;
3216 NewMask.reserve(N: NumTrunkElts);
3217 for (int M : Mask) {
3218 if (M < 0 || M >= static_cast<int>(NumTrunkElts))
3219 NewMask.push_back(Elt: M);
3220 else
3221 NewMask.push_back(Elt: OuterMask[M]);
3222 }
3223
3224 // Break the chain if adding this new step complicates the shuffles such
3225 // that it would increase the new cost by more than the old cost of this
3226 // step.
3227 InstructionCost LocalNewCost =
3228 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: TrunkType,
3229 SrcTy: YType, Mask: NewYMask, CostKind) +
3230 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: TrunkType,
3231 SrcTy: TrunkType, Mask: NewMask, CostKind);
3232
3233 if (LocalNewCost >= NewCost && LocalOldCost < LocalNewCost - NewCost)
3234 break;
3235
3236 LLVM_DEBUG({
3237 if (ChainLength == 1) {
3238 dbgs() << "Found chain of shuffles fed by length-changing shuffles: "
3239 << I << '\n';
3240 }
3241 dbgs() << " next chain link: " << *Trunk << '\n'
3242 << " old cost: " << (OldCost + LocalOldCost)
3243 << " new cost: " << LocalNewCost << '\n';
3244 });
3245
3246 Mask = NewMask;
3247 YMask = NewYMask;
3248 OldCost += LocalOldCost;
3249 NewCost = LocalNewCost;
3250 Trunk = OuterV0;
3251 ChainLength++;
3252 }
3253 if (ChainLength <= 1)
3254 return false;
3255
3256 if (llvm::all_of(Range&: Mask, P: [&](int M) {
3257 return M < 0 || M >= static_cast<int>(NumTrunkElts);
3258 })) {
3259 // Produce a canonical simplified form if all elements are sourced from Y.
3260 for (int &M : Mask) {
3261 if (M >= static_cast<int>(NumTrunkElts))
3262 M = YMask[M - NumTrunkElts];
3263 }
3264 Value *Root =
3265 Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask);
3266 replaceValue(Old&: I, New&: *Root);
3267 return true;
3268 }
3269
3270 Value *Leaf =
3271 Builder.CreateShuffleVector(V1: Y, V2: PoisonValue::get(T: Y->getType()), Mask: YMask);
3272 Value *Root = Builder.CreateShuffleVector(V1: Trunk, V2: Leaf, Mask);
3273 replaceValue(Old&: I, New&: *Root);
3274 return true;
3275}
3276
3277/// Try to convert
3278/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
3279bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
3280 Value *V0, *V1;
3281 ArrayRef<int> OldMask;
3282 if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
3283 return false;
3284
3285 auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
3286 auto *II1 = dyn_cast<IntrinsicInst>(Val: V1);
3287 if (!II0 || !II1)
3288 return false;
3289
3290 Intrinsic::ID IID = II0->getIntrinsicID();
3291 if (IID != II1->getIntrinsicID())
3292 return false;
3293 InstructionCost CostII0 =
3294 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
3295 InstructionCost CostII1 =
3296 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II1), CostKind);
3297
3298 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
3299 auto *II0Ty = dyn_cast<FixedVectorType>(Val: II0->getType());
3300 if (!ShuffleDstTy || !II0Ty)
3301 return false;
3302
3303 if (!isTriviallyVectorizable(ID: IID))
3304 return false;
3305
3306 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3307 Value *Arg0 = II0->getArgOperand(i: I);
3308 Value *Arg1 = II1->getArgOperand(i: I);
3309 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3310 // Scalar operands must be identical.
3311 if (Arg0 != Arg1)
3312 return false;
3313 } else if (Arg0->getType() != Arg1->getType()) {
3314 // The corresponding vector operands are shuffled together, so they must
3315 // share the same type. For intrinsics overloaded on their operand type
3316 // (e.g. llvm.fptosi.sat), two calls can produce the same result type
3317 // from different operand types; shuffling those would be invalid.
3318 return false;
3319 }
3320 }
3321
3322 InstructionCost OldCost =
3323 CostII0 + CostII1 +
3324 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
3325 SrcTy: II0Ty, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {II0, II1}, CxtI: &I);
3326
3327 SmallVector<Type *> NewArgsTy;
3328 InstructionCost NewCost = 0;
3329 SmallDenseSet<std::pair<Value *, Value *>> SeenOperandPairs;
3330 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3331 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3332 NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
3333 } else {
3334 auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
3335 auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
3336 NumElts: ShuffleDstTy->getNumElements());
3337 NewArgsTy.push_back(Elt: ArgTy);
3338 std::pair<Value *, Value *> OperandPair =
3339 std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
3340 if (!SeenOperandPairs.insert(V: OperandPair).second) {
3341 // We've already computed the cost for this operand pair.
3342 continue;
3343 }
3344 NewCost += TTI.getShuffleCost(
3345 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ArgTy, SrcTy: VecTy, Mask: OldMask,
3346 CostKind, Index: 0, SubTp: nullptr, Args: {II0->getArgOperand(i: I), II1->getArgOperand(i: I)});
3347 }
3348 }
3349 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3350
3351 NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);
3352 if (!II0->hasOneUse())
3353 NewCost += CostII0;
3354 if (II1 != II0 && !II1->hasOneUse())
3355 NewCost += CostII1;
3356
3357 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
3358 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3359 << "\n");
3360
3361 if (NewCost > OldCost)
3362 return false;
3363
3364 SmallVector<Value *> NewArgs;
3365 SmallDenseMap<std::pair<Value *, Value *>, Value *> ShuffleCache;
3366 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
3367 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3368 NewArgs.push_back(Elt: II0->getArgOperand(i: I));
3369 } else {
3370 std::pair<Value *, Value *> OperandPair =
3371 std::make_pair(x: II0->getArgOperand(i: I), y: II1->getArgOperand(i: I));
3372 auto It = ShuffleCache.find(Val: OperandPair);
3373 if (It != ShuffleCache.end()) {
3374 // Reuse previously created shuffle for this operand pair.
3375 NewArgs.push_back(Elt: It->second);
3376 continue;
3377 }
3378 Value *Shuf = Builder.CreateShuffleVector(V1: II0->getArgOperand(i: I),
3379 V2: II1->getArgOperand(i: I), Mask: OldMask);
3380 ShuffleCache[OperandPair] = Shuf;
3381 NewArgs.push_back(Elt: Shuf);
3382 Worklist.pushValue(V: Shuf);
3383 }
3384 Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);
3385
3386 // Intersect flags from the old intrinsics.
3387 if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic)) {
3388 NewInst->copyIRFlags(V: II0);
3389 NewInst->andIRFlags(V: II1);
3390 }
3391
3392 replaceValue(Old&: I, New&: *NewIntrinsic);
3393 return true;
3394}
3395
3396/// Try to convert
3397/// "shuffle (intrinsic), (poison/undef)" into "intrinsic (shuffle)".
3398bool VectorCombine::foldPermuteOfIntrinsic(Instruction &I) {
3399 Value *V0;
3400 ArrayRef<int> Mask;
3401 if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Undef(), mask: m_Mask(Mask))))
3402 return false;
3403
3404 auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
3405 if (!II0)
3406 return false;
3407
3408 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
3409 auto *IntrinsicSrcTy = dyn_cast<FixedVectorType>(Val: II0->getType());
3410 if (!ShuffleDstTy || !IntrinsicSrcTy)
3411 return false;
3412
3413 // Validate it's a pure permute, mask should only reference the first vector
3414 unsigned NumSrcElts = IntrinsicSrcTy->getNumElements();
3415 if (any_of(Range&: Mask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
3416 return false;
3417
3418 Intrinsic::ID IID = II0->getIntrinsicID();
3419 if (!isTriviallyVectorizable(ID: IID))
3420 return false;
3421
3422 // Cost analysis
3423 InstructionCost IntrinsicCost =
3424 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind);
3425 InstructionCost OldCost =
3426 IntrinsicCost +
3427 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: ShuffleDstTy,
3428 SrcTy: IntrinsicSrcTy, Mask, CostKind, Index: 0, SubTp: nullptr, Args: {V0}, CxtI: &I);
3429
3430 SmallVector<Type *> NewArgsTy;
3431 InstructionCost NewCost = 0;
3432 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3433 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3434 NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
3435 } else {
3436 auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
3437 auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
3438 NumElts: ShuffleDstTy->getNumElements());
3439 NewArgsTy.push_back(Elt: ArgTy);
3440 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
3441 DstTy: ArgTy, SrcTy: VecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
3442 Args: {II0->getArgOperand(i: I)});
3443 }
3444 }
3445 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
3446 NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);
3447
3448 // If the intrinsic has multiple uses, we need to account for the cost of
3449 // keeping the original intrinsic around.
3450 if (!II0->hasOneUse())
3451 NewCost += IntrinsicCost;
3452
3453 LLVM_DEBUG(dbgs() << "Found a permute of intrinsic: " << I << "\n OldCost: "
3454 << OldCost << " vs NewCost: " << NewCost << "\n");
3455
3456 if (NewCost > OldCost)
3457 return false;
3458
3459 // Transform
3460 SmallVector<Value *> NewArgs;
3461 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
3462 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
3463 NewArgs.push_back(Elt: II0->getArgOperand(i: I));
3464 } else {
3465 Value *Shuf = Builder.CreateShuffleVector(V: II0->getArgOperand(i: I), Mask);
3466 NewArgs.push_back(Elt: Shuf);
3467 Worklist.pushValue(V: Shuf);
3468 }
3469 }
3470
3471 Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);
3472
3473 if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic))
3474 NewInst->copyIRFlags(V: II0);
3475
3476 replaceValue(Old&: I, New&: *NewIntrinsic);
3477 return true;
3478}
3479
/// A (value, lane index) pair naming one lane of a vector value. The helpers
/// below use {nullptr, PoisonMaskElem} for a lane that has no source.
3480using InstLane = std::pair<Value *, int>;
3481
3482static InstLane lookThroughShuffles(Value *V, int Lane) {
3483 while (auto *SV = dyn_cast<ShuffleVectorInst>(Val: V)) {
3484 unsigned NumElts =
3485 cast<FixedVectorType>(Val: SV->getOperand(i_nocapture: 0)->getType())->getNumElements();
3486 int M = SV->getMaskValue(Elt: Lane);
3487 if (M < 0)
3488 return {nullptr, PoisonMaskElem};
3489 if (static_cast<unsigned>(M) < NumElts) {
3490 V = SV->getOperand(i_nocapture: 0);
3491 Lane = M;
3492 } else {
3493 V = SV->getOperand(i_nocapture: 1);
3494 Lane = M - NumElts;
3495 }
3496 }
3497 return InstLane{V, Lane};
3498}
3499
3500static SmallVector<InstLane>
3501generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
3502 SmallVector<InstLane> NItem;
3503 for (InstLane IL : Item) {
3504 auto [U, Lane] = IL;
3505 InstLane OpLane =
3506 U ? lookThroughShuffles(V: cast<Instruction>(Val: U)->getOperand(i: Op), Lane)
3507 : InstLane{nullptr, PoisonMaskElem};
3508 NItem.emplace_back(Args&: OpLane);
3509 }
3510 return NItem;
3511}
3512
3513/// Detect concat of multiple values into a vector
3514static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
3515 const TargetTransformInfo &TTI) {
3516 auto *Ty = cast<FixedVectorType>(Val: Item.front().first->getType());
3517 unsigned NumElts = Ty->getNumElements();
3518 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
3519 return false;
3520
3521 // Check that the concat is free, usually meaning that the type will be split
3522 // during legalization.
3523 SmallVector<int, 16> ConcatMask(NumElts * 2);
3524 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
3525 if (TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc,
3526 DstTy: FixedVectorType::get(ElementType: Ty->getScalarType(), NumElts: NumElts * 2),
3527 SrcTy: Ty, Mask: ConcatMask, CostKind) != 0)
3528 return false;
3529
3530 unsigned NumSlices = Item.size() / NumElts;
3531 // Currently we generate a tree of shuffles for the concats, which limits us
3532 // to a power2.
3533 if (!isPowerOf2_32(Value: NumSlices))
3534 return false;
3535 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
3536 Value *SliceV = Item[Slice * NumElts].first;
3537 if (!SliceV || SliceV->getType() != Ty)
3538 return false;
3539 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
3540 auto [V, Lane] = Item[Slice * NumElts + Elt];
3541 if (Lane != static_cast<int>(Elt) || SliceV != V)
3542 return false;
3543 }
3544 }
3545 return true;
3546}
3547
3548static Value *
3549generateNewInstTree(ArrayRef<InstLane> Item, Use *From, FixedVectorType *Ty,
3550 const DenseSet<std::pair<Value *, Use *>> &IdentityLeafs,
3551 const DenseSet<std::pair<Value *, Use *>> &SplatLeafs,
3552 const DenseSet<std::pair<Value *, Use *>> &ConcatLeafs,
3553 IRBuilderBase &Builder, const TargetTransformInfo *TTI) {
3554 auto [FrontV, FrontLane] = Item.front();
3555
3556 if (IdentityLeafs.contains(V: std::make_pair(x&: FrontV, y&: From))) {
3557 return FrontV;
3558 }
3559 if (SplatLeafs.contains(V: std::make_pair(x&: FrontV, y&: From))) {
3560 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
3561 return Builder.CreateShuffleVector(V: FrontV, Mask);
3562 }
3563 if (ConcatLeafs.contains(V: std::make_pair(x&: FrontV, y&: From))) {
3564 unsigned NumElts =
3565 cast<FixedVectorType>(Val: FrontV->getType())->getNumElements();
3566 SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
3567 for (unsigned S = 0; S < Values.size(); ++S)
3568 Values[S] = Item[S * NumElts].first;
3569
3570 while (Values.size() > 1) {
3571 NumElts *= 2;
3572 SmallVector<int, 16> Mask(NumElts, 0);
3573 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
3574 SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
3575 for (unsigned S = 0; S < NewValues.size(); ++S)
3576 NewValues[S] =
3577 Builder.CreateShuffleVector(V1: Values[S * 2], V2: Values[S * 2 + 1], Mask);
3578 Values = NewValues;
3579 }
3580 return Values[0];
3581 }
3582
3583 auto *I = cast<Instruction>(Val: FrontV);
3584 auto *II = dyn_cast<IntrinsicInst>(Val: I);
3585 unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
3586 SmallVector<Value *> Ops(NumOps);
3587 for (unsigned Idx = 0; Idx < NumOps; Idx++) {
3588 if (II &&
3589 isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI)) {
3590 Ops[Idx] = II->getOperand(i_nocapture: Idx);
3591 continue;
3592 }
3593 Ops[Idx] = generateNewInstTree(Item: generateInstLaneVectorFromOperand(Item, Op: Idx),
3594 From: &I->getOperandUse(i: Idx), Ty, IdentityLeafs,
3595 SplatLeafs, ConcatLeafs, Builder, TTI);
3596 }
3597
3598 SmallVector<Value *, 8> ValueList;
3599 for (const auto &Lane : Item)
3600 if (Lane.first)
3601 ValueList.push_back(Elt: Lane.first);
3602
3603 Type *DstTy =
3604 FixedVectorType::get(ElementType: I->getType()->getScalarType(), NumElts: Ty->getNumElements());
3605 if (auto *BI = dyn_cast<BinaryOperator>(Val: I)) {
3606 auto *Value = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)BI->getOpcode(),
3607 LHS: Ops[0], RHS: Ops[1]);
3608 propagateIRFlags(I: Value, VL: ValueList);
3609 return Value;
3610 }
3611 if (auto *CI = dyn_cast<CmpInst>(Val: I)) {
3612 auto *Value = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: Ops[0], RHS: Ops[1]);
3613 propagateIRFlags(I: Value, VL: ValueList);
3614 return Value;
3615 }
3616 if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
3617 auto *Value = Builder.CreateSelect(C: Ops[0], True: Ops[1], False: Ops[2], Name: "", MDFrom: SI);
3618 propagateIRFlags(I: Value, VL: ValueList);
3619 return Value;
3620 }
3621 if (auto *CI = dyn_cast<CastInst>(Val: I)) {
3622 auto *Value = Builder.CreateCast(Op: CI->getOpcode(), V: Ops[0], DestTy: DstTy);
3623 propagateIRFlags(I: Value, VL: ValueList);
3624 return Value;
3625 }
3626 if (II) {
3627 auto *Value = Builder.CreateIntrinsic(RetTy: DstTy, ID: II->getIntrinsicID(), Args: Ops);
3628 propagateIRFlags(I: Value, VL: ValueList);
3629 return Value;
3630 }
3631 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
3632 auto *Value =
3633 Builder.CreateUnOp(Opc: (Instruction::UnaryOps)I->getOpcode(), V: Ops[0]);
3634 propagateIRFlags(I: Value, VL: ValueList);
3635 return Value;
3636}
3637
3638// Starting from a shuffle, look up through operands tracking the shuffled index
3639// of each lane. If we can simplify away the shuffles to identities then
3640// do so.
//
// The search is a worklist walk. Each work item is the list of (value, lane)
// pairs feeding one operand position, paired with the Use it was reached
// through. An item either ends the walk as a leaf (identity, splat or free
// concat), is expanded into one item per operand, or makes the whole fold
// fail. Returns true if I was replaced by a rebuilt tree.
3641bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
3642 auto *Ty = dyn_cast<FixedVectorType>(Val: I.getType());
3643 if (!Ty || I.use_empty())
3644 return false;
3645
 // Resolve every result lane of I to the (value, lane) it reads once all
 // shuffles on the way have been looked through.
3646 SmallVector<InstLane> Start(Ty->getNumElements());
3647 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
3648 Start[M] = lookThroughShuffles(V: &I, Lane: M);
3649
 // The root item is keyed by the first use of I; the same key is handed to
 // generateNewInstTree at the end of this function.
3650 SmallVector<std::pair<SmallVector<InstLane>, Use *>> Worklist;
3651 Worklist.push_back(Elt: std::make_pair(x&: Start, y: &*I.use_begin()));
3652 DenseSet<std::pair<Value *, Use *>> IdentityLeafs, SplatLeafs, ConcatLeafs;
3653 unsigned NumVisited = 0;
3654
3655 while (!Worklist.empty()) {
 // Give up once more than MaxInstrsToScan items have been visited.
3656 if (++NumVisited > MaxInstrsToScan)
3657 return false;
3658
3659 auto ItemFrom = Worklist.pop_back_val();
3660 auto Item = ItemFrom.first;
3661 auto From = ItemFrom.second;
3662 auto [FrontV, FrontLane] = Item.front();
3663
3664 // If we found an undef first lane then bail out to keep things simple.
3665 if (!FrontV)
3666 return false;
3667
3668 // Helper to peek through bitcasts to the same value.
3669 auto IsEquiv = [&](Value *X, Value *Y) {
3670 return X->getType() == Y->getType() &&
3671 peekThroughBitcasts(V: X) == peekThroughBitcasts(V: Y);
3672 };
3673
3674 // Look for an identity value.
 // That is: the front lane is lane 0 of a value with as many elements as
 // Ty, and every other lane is either empty or lane <index> of an
 // equivalent value.
3675 if (FrontLane == 0 &&
3676 cast<FixedVectorType>(Val: FrontV->getType())->getNumElements() ==
3677 Ty->getNumElements() &&
3678 all_of(Range: drop_begin(RangeOrContainer: enumerate(First&: Item)), P: [IsEquiv, Item](const auto &E) {
3679 Value *FrontV = Item.front().first;
3680 return !E.value().first || (IsEquiv(E.value().first, FrontV) &&
3681 E.value().second == (int)E.index());
3682 })) {
3683 IdentityLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
3684 continue;
3685 }
3686 // Look for constants, for the moment only supporting constant splats.
 // Every non-empty lane must be a constant with the same splat value as the
 // front lane.
3687 if (auto *C = dyn_cast<Constant>(Val: FrontV);
3688 C && C->getSplatValue() &&
3689 all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
3690 Value *FrontV = Item.front().first;
3691 Value *V = IL.first;
3692 return !V || (isa<Constant>(Val: V) &&
3693 cast<Constant>(Val: V)->getSplatValue() ==
3694 cast<Constant>(Val: FrontV)->getSplatValue());
3695 })) {
3696 SplatLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
3697 continue;
3698 }
3699 // Look for a splat value.
 // Every non-empty lane reads the same lane of the same value as the front.
3700 if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
3701 auto [FrontV, FrontLane] = Item.front();
3702 auto [V, Lane] = IL;
3703 return !V || (V == FrontV && Lane == FrontLane);
3704 })) {
3705 SplatLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
3706 continue;
3707 }
3708
3709 // We need each element to be the same type of value, and check that each
3710 // element has a single use.
 // Beyond the value kind this also compares compare predicates, the scalar
 // source type of casts and the condition type of selects (which must be a
 // vector), rejects non-intrinsic calls, and requires intrinsics to share
 // the intrinsic ID and to carry no operand bundles.
3711 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
3712 Value *FrontV = Item.front().first;
3713 if (!IL.first)
3714 return true;
3715 Value *V = IL.first;
3716 if (auto *I = dyn_cast<Instruction>(Val: V); I && !I->hasOneUser())
3717 return false;
3718 if (V->getValueID() != FrontV->getValueID())
3719 return false;
3720 if (auto *CI = dyn_cast<CmpInst>(Val: V))
3721 if (CI->getPredicate() != cast<CmpInst>(Val: FrontV)->getPredicate())
3722 return false;
3723 if (auto *CI = dyn_cast<CastInst>(Val: V))
3724 if (CI->getSrcTy()->getScalarType() !=
3725 cast<CastInst>(Val: FrontV)->getSrcTy()->getScalarType())
3726 return false;
3727 if (auto *SI = dyn_cast<SelectInst>(Val: V))
3728 if (!isa<VectorType>(Val: SI->getOperand(i_nocapture: 0)->getType()) ||
3729 SI->getOperand(i_nocapture: 0)->getType() !=
3730 cast<SelectInst>(Val: FrontV)->getOperand(i_nocapture: 0)->getType())
3731 return false;
3732 if (isa<CallInst>(Val: V) && !isa<IntrinsicInst>(Val: V))
3733 return false;
3734 auto *II = dyn_cast<IntrinsicInst>(Val: V);
3735 return !II || (isa<IntrinsicInst>(Val: FrontV) &&
3736 II->getIntrinsicID() ==
3737 cast<IntrinsicInst>(Val: FrontV)->getIntrinsicID() &&
3738 !II->hasOperandBundles());
3739 };
 // If all lanes agree, expand the item: queue one new item per operand that
 // has to be rebuilt, keyed by that operand's Use on the front instruction.
3740 if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: CheckLaneIsEquivalentToFirst)) {
3741 // Check the operator is one that we support.
3742 if (isa<BinaryOperator, CmpInst>(Val: FrontV)) {
3743 // We exclude div/rem in case they hit UB from poison lanes.
3744 if (auto *BO = dyn_cast<BinaryOperator>(Val: FrontV);
3745 BO && BO->isIntDivRem())
3746 return false;
3747 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
3748 Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: 0));
3749 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 1),
3750 Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: 1));
3751 continue;
3752 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
3753 FPToUIInst, SIToFPInst, UIToFPInst>(Val: FrontV)) {
3754 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
3755 Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: 0));
3756 continue;
3757 } else if (auto *BitCast = dyn_cast<BitCastInst>(Val: FrontV)) {
3758 // TODO: Handle vector widening/narrowing bitcasts.
 // Only bitcasts that keep the element count are expanded; any other
 // bitcast falls through to the free-concat check below.
3759 auto *DstTy = dyn_cast<FixedVectorType>(Val: BitCast->getDestTy());
3760 auto *SrcTy = dyn_cast<FixedVectorType>(Val: BitCast->getSrcTy());
3761 if (DstTy && SrcTy &&
3762 SrcTy->getNumElements() == DstTy->getNumElements()) {
3763 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
3764 Args: &BitCast->getOperandUse(i: 0));
3765 continue;
3766 }
3767 } else if (auto *Sel = dyn_cast<SelectInst>(Val: FrontV)) {
3768 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 0),
3769 Args: &Sel->getOperandUse(i: 0));
3770 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 1),
3771 Args: &Sel->getOperandUse(i: 1));
3772 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op: 2),
3773 Args: &Sel->getOperandUse(i: 2));
3774 continue;
3775 } else if (auto *II = dyn_cast<IntrinsicInst>(Val: FrontV);
3776 II && isTriviallyVectorizable(ID: II->getIntrinsicID()) &&
3777 !II->hasOperandBundles()) {
 // Visit every operand except the last one. Scalar operands are not
 // expanded; they only have to be the same value in every lane.
3778 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
3779 if (isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Op,
3780 TTI: &TTI)) {
3781 if (!all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item, Op](InstLane &IL) {
3782 Value *FrontV = Item.front().first;
3783 Value *V = IL.first;
3784 return !V || (cast<Instruction>(Val: V)->getOperand(i: Op) ==
3785 cast<Instruction>(Val: FrontV)->getOperand(i: Op));
3786 }))
3787 return false;
3788 continue;
3789 }
3790 Worklist.emplace_back(Args: generateInstLaneVectorFromOperand(Item, Op),
3791 Args: &cast<Instruction>(Val: FrontV)->getOperandUse(i: Op));
3792 }
3793 continue;
3794 }
3795 }
3796
 // Last resort: the lanes may form a concatenation the target treats as
 // free.
3797 if (isFreeConcat(Item, CostKind, TTI)) {
3798 ConcatLeafs.insert(V: std::make_pair(x&: FrontV, y&: From));
3799 continue;
3800 }
3801
 // The item is neither a supported leaf nor a supported operation.
3802 return false;
3803 }
3804
 // Only the root item was visited, i.e. it was classified as a leaf without
 // expanding any operation, so no rewrite is done.
3805 if (NumVisited <= 1)
3806 return false;
3807
3808 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
3809
3810 // If we got this far, we know the shuffles are superfluous and can be
3811 // removed. Scan through again and generate the new tree of instructions.
3812 Builder.SetInsertPoint(&I);
3813 Value *V = generateNewInstTree(Item: Start, From: &*I.use_begin(), Ty, IdentityLeafs,
3814 SplatLeafs, ConcatLeafs, Builder, TTI: &TTI);
3815 replaceValue(Old&: I, New&: *V);
3816 return true;
3817}
3818
3819/// Given a commutative reduction, the order of the input lanes does not alter
3820/// the results. We can use this to remove certain shuffles feeding the
3821/// reduction, removing the need to shuffle at all.
3822bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
3823 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
3824 if (!II)
3825 return false;
3826 switch (II->getIntrinsicID()) {
3827 case Intrinsic::vector_reduce_add:
3828 case Intrinsic::vector_reduce_mul:
3829 case Intrinsic::vector_reduce_and:
3830 case Intrinsic::vector_reduce_or:
3831 case Intrinsic::vector_reduce_xor:
3832 case Intrinsic::vector_reduce_smin:
3833 case Intrinsic::vector_reduce_smax:
3834 case Intrinsic::vector_reduce_umin:
3835 case Intrinsic::vector_reduce_umax:
3836 break;
3837 default:
3838 return false;
3839 }
3840
3841 // Find all the inputs when looking through operations that do not alter the
3842 // lane order (binops, for example). Currently we look for a single shuffle,
3843 // and can ignore splat values.
3844 std::queue<Value *> Worklist;
3845 SmallPtrSet<Value *, 4> Visited;
3846 ShuffleVectorInst *Shuffle = nullptr;
3847 if (auto *Op = dyn_cast<Instruction>(Val: I.getOperand(i: 0)))
3848 Worklist.push(x: Op);
3849
3850 while (!Worklist.empty()) {
3851 Value *CV = Worklist.front();
3852 Worklist.pop();
3853 if (Visited.contains(Ptr: CV))
3854 continue;
3855
3856 // Splats don't change the order, so can be safely ignored.
3857 if (isSplatValue(V: CV))
3858 continue;
3859
3860 Visited.insert(Ptr: CV);
3861
3862 if (auto *CI = dyn_cast<Instruction>(Val: CV)) {
3863 if (CI->isBinaryOp()) {
3864 for (auto *Op : CI->operand_values())
3865 Worklist.push(x: Op);
3866 continue;
3867 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(Val: CI)) {
3868 if (Shuffle && Shuffle != SV)
3869 return false;
3870 Shuffle = SV;
3871 continue;
3872 }
3873 }
3874
3875 // Anything else is currently an unknown node.
3876 return false;
3877 }
3878
3879 if (!Shuffle)
3880 return false;
3881
3882 // Check all uses of the binary ops and shuffles are also included in the
3883 // lane-invariant operations (Visited should be the list of lanewise
3884 // instructions, including the shuffle that we found).
3885 for (auto *V : Visited)
3886 for (auto *U : V->users())
3887 if (!Visited.contains(Ptr: U) && U != &I)
3888 return false;
3889
3890 FixedVectorType *VecType =
3891 dyn_cast<FixedVectorType>(Val: II->getOperand(i_nocapture: 0)->getType());
3892 if (!VecType)
3893 return false;
3894 FixedVectorType *ShuffleInputType =
3895 dyn_cast<FixedVectorType>(Val: Shuffle->getOperand(i_nocapture: 0)->getType());
3896 if (!ShuffleInputType)
3897 return false;
3898 unsigned NumInputElts = ShuffleInputType->getNumElements();
3899
3900 // Find the mask from sorting the lanes into order. This is most likely to
3901 // become a identity or concat mask. Undef elements are pushed to the end.
3902 SmallVector<int> ConcatMask;
3903 Shuffle->getShuffleMask(Result&: ConcatMask);
3904 sort(C&: ConcatMask, Comp: [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3905 bool UsesSecondVec =
3906 any_of(Range&: ConcatMask, P: [&](int M) { return M >= (int)NumInputElts; });
3907
3908 InstructionCost OldCost = TTI.getShuffleCost(
3909 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3910 SrcTy: ShuffleInputType, Mask: Shuffle->getShuffleMask(), CostKind);
3911 InstructionCost NewCost = TTI.getShuffleCost(
3912 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3913 SrcTy: ShuffleInputType, Mask: ConcatMask, CostKind);
3914
3915 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3916 << "\n");
3917 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3918 << "\n");
3919 bool MadeChanges = false;
3920 if (NewCost < OldCost) {
3921 Builder.SetInsertPoint(Shuffle);
3922 Value *NewShuffle = Builder.CreateShuffleVector(
3923 V1: Shuffle->getOperand(i_nocapture: 0), V2: Shuffle->getOperand(i_nocapture: 1), Mask: ConcatMask);
3924 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3925 replaceValue(Old&: *Shuffle, New&: *NewShuffle);
3926 return true;
3927 }
3928
3929 // See if we can re-use foldSelectShuffle, getting it to reduce the size of
3930 // the shuffle into a nicer order, as it can ignore the order of the shuffles.
3931 MadeChanges |= foldSelectShuffle(I&: *Shuffle, FromReduction: true);
3932 return MadeChanges;
3933}
3934
3935/// Try to fold a chain of shuffles and ops feeding extractelement(..., 0)
3936/// into llvm.vector.reduce.*, by tracking which lanes contribute to the
3937/// extracted lane and reducing the widest vector whose lanes each contribute
3938/// once.
3939///
3940/// For example:
3941///
3942/// %lo = shufflevector <4 x i32> %a, poison, <2 x i32> <i32 0, i32 1>
3943/// %hi = shufflevector <4 x i32> %a, poison, <2 x i32> <i32 2, i32 3>
3944/// %s = add <2 x i32> %lo, %hi
3945/// %sh = shufflevector <2 x i32> %s, poison, <2 x i32> <i32 1, i32 poison>
3946/// %r = add <2 x i32> %s, %sh
3947/// %e = extractelement <2 x i32> %r, i64 0
3948///
3949/// transforms to:
3950///
3951/// %e = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %a)
3952bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3953 Value *VecOpEE;
3954 if (!match(V: &I, P: m_ExtractElt(Val: m_Value(V&: VecOpEE), Idx: m_Zero())))
3955 return false;
3956
3957 auto *FVT = dyn_cast<FixedVectorType>(Val: VecOpEE->getType());
3958 if (!FVT)
3959 return false;
3960
3961 if (FVT->getNumElements() < 2)
3962 return false;
3963
3964 std::optional<Instruction::BinaryOps> CommonBinOp;
3965 std::optional<Intrinsic::ID> CommonCallOp;
3966
3967 if (auto *BO = dyn_cast<BinaryOperator>(Val: VecOpEE)) {
3968 if (!getReductionForBinop(Opc: BO->getOpcode()))
3969 return false;
3970 CommonBinOp = BO->getOpcode();
3971 } else if (auto *MMI = dyn_cast<MinMaxIntrinsic>(Val: VecOpEE)) {
3972 CommonCallOp = MMI->getIntrinsicID();
3973 } else {
3974 return false;
3975 }
3976
3977 // For floating-point reductions, track FMF intersection across all binops.
3978 FastMathFlags CommonFMF;
3979 bool IsFloatReduction = false;
3980
3981 // A chain node is one we walk through, either a matching-opcode binop/min-max
3982 // or a single-source shuffle. Anything else is a leaf source.
3983 auto IsChainNode = [&](Value *V) {
3984 if (auto *BO = dyn_cast<BinaryOperator>(Val: V))
3985 return CommonBinOp && BO->getOpcode() == *CommonBinOp;
3986 if (auto *MMI = dyn_cast<MinMaxIntrinsic>(Val: V))
3987 return CommonCallOp && MMI->getIntrinsicID() == *CommonCallOp;
3988 if (auto *SVI = dyn_cast<ShuffleVectorInst>(Val: V))
3989 return isa<PoisonValue>(Val: SVI->getOperand(i_nocapture: 1));
3990 return false;
3991 };
3992
3993 // Collect the chain, building Nodes in postorder. Bail if the chain is empty
3994 // or exceeds MaxChainNodes.
3995 constexpr unsigned MaxChainNodes = 32;
3996 SmallSetVector<Value *, 16> Nodes;
3997 SmallSetVector<Value *, 4> Sources;
3998 unsigned NumVisited = 0;
3999 auto AddSource = [&](Value *V) {
4000 if (!isa<FixedVectorType>(Val: V->getType()))
4001 return false;
4002 Sources.insert(X: V);
4003 return true;
4004 };
4005 auto Walk = [&](Value *V, auto &&Walk) -> bool {
4006 if (Nodes.contains(key: V) || Sources.contains(key: V))
4007 return true;
4008 if (++NumVisited > MaxChainNodes)
4009 return false;
4010 if (!IsChainNode(V))
4011 return AddSource(V);
4012 // Chain shuffles always have poison as op1, so only op0 matters.
4013 auto *U = cast<Instruction>(Val: V);
4014 unsigned NumOps = isa<ShuffleVectorInst>(Val: U) ? 1 : 2;
4015 for (unsigned I = 0; I != NumOps; ++I)
4016 if (!Walk(U->getOperand(i: I), Walk))
4017 return false;
4018 if (isa<ShuffleVectorInst>(Val: U) || Nodes.contains(key: U->getOperand(i: 0)) ||
4019 Nodes.contains(key: U->getOperand(i: 1))) {
4020 Nodes.insert(X: V);
4021 return true;
4022 }
4023 // Both operands are leaves so treat this binop as a source rather than
4024 // walking into it.
4025 return AddSource(V);
4026 };
4027 if (!Walk(VecOpEE, Walk) || Nodes.empty())
4028 return false;
4029
4030 bool IsIdempotent =
4031 CommonCallOp || (CommonBinOp && Instruction::isIdempotent(Opcode: *CommonBinOp));
4032
4033 // For FP reductions, require reassoc on every binop and collect FMF.
4034 for (Value *V : Nodes) {
4035 auto *BinOp = dyn_cast<BinaryOperator>(Val: V);
4036 if (!BinOp || !BinOp->getType()->isFPOrFPVectorTy())
4037 continue;
4038 if (!BinOp->hasAllowReassoc())
4039 return false;
4040 if (!IsFloatReduction) {
4041 CommonFMF = BinOp->getFastMathFlags();
4042 IsFloatReduction = true;
4043 } else {
4044 CommonFMF &= BinOp->getFastMathFlags();
4045 }
4046 }
4047
4048 // Top-down demanded elements. For each chain value, track which lanes feed
4049 // the extracted lane 0 and which feed it more than once. Reverse postorder
4050 // visits every use before its value. A binop forwards its demand to both
4051 // operands and a shuffle follows its mask back to the source lane.
4052 struct Demand {
4053 APInt Lanes;
4054 APInt Duplicates;
4055 };
4056 DenseMap<Value *, Demand> Demands;
4057 auto DemandOf = [&](Value *V) -> Demand & {
4058 unsigned N = cast<FixedVectorType>(Val: V->getType())->getNumElements();
4059 Demand &D = Demands[V];
4060 if (D.Lanes.getBitWidth() != N)
4061 D.Lanes = D.Duplicates = APInt::getZero(numBits: N);
4062 return D;
4063 };
4064 DemandOf(VecOpEE).Lanes.setBit(0);
4065 for (Value *V : reverse(C&: Nodes)) {
4066 Demand DV = Demands.lookup(Val: V);
4067 if (DV.Lanes.isZero())
4068 continue;
4069 if (auto *SVI = dyn_cast<ShuffleVectorInst>(Val: V)) {
4070 ArrayRef<int> Mask = SVI->getShuffleMask();
4071 Demand &DS = DemandOf(SVI->getOperand(i_nocapture: 0));
4072 for (unsigned I = 0, E = Mask.size(); I != E; ++I) {
4073 // Skip lanes that are undemanded or map to poison.
4074 if (!DV.Lanes[I] || Mask[I] < 0 ||
4075 (unsigned)Mask[I] >= DS.Lanes.getBitWidth())
4076 continue;
4077 if (DS.Lanes[Mask[I]] || DV.Duplicates[I])
4078 DS.Duplicates.setBit(Mask[I]);
4079 DS.Lanes.setBit(Mask[I]);
4080 }
4081 } else {
4082 auto *U = cast<User>(Val: V);
4083 for (Value *Op : {U->getOperand(i: 0), U->getOperand(i: 1)}) {
4084 Demand &DOp = DemandOf(Op);
4085 // Lanes demanded through more than one path accumulate in Duplicates.
4086 DOp.Duplicates |= DV.Duplicates | (DOp.Lanes & DV.Lanes);
4087 DOp.Lanes |= DV.Lanes;
4088 }
4089 }
4090 }
4091
4092 // Reducing V replaces the entire chain, so every contribution to the result
4093 // must flow through V. Reject if anything above V reads outside the chain.
4094 auto CoversChain = [&](Value *V) {
4095 SmallVector<Value *, 8> Worklist(1, VecOpEE);
4096 SmallPtrSet<Value *, 8> Seen;
4097 Seen.insert(Ptr: VecOpEE);
4098 while (!Worklist.empty()) {
4099 auto *U = cast<Instruction>(Val: Worklist.pop_back_val());
4100 unsigned NumOps = isa<ShuffleVectorInst>(Val: U) ? 1 : 2;
4101 for (unsigned I = 0; I != NumOps; ++I) {
4102 Value *Op = U->getOperand(i: I);
4103 if (Op == V || !Seen.insert(Ptr: Op).second)
4104 continue;
4105 if (!Nodes.contains(key: Op))
4106 return false;
4107 Worklist.push_back(Elt: Op);
4108 }
4109 }
4110 return true;
4111 };
4112
4113 // Reduce a single cleanly demanded source if there is one, otherwise the
4114 // deepest intermediate that covers the chain.
4115 struct ReductionCut {
4116 Value *Src;
4117 APInt Elts;
4118 };
4119 std::optional<ReductionCut> Cut;
4120 for (Value *S : Sources) {
4121 auto It = Demands.find(Val: S);
4122 if (It == Demands.end() || It->second.Lanes.isZero())
4123 continue;
4124 if (Cut || (!IsIdempotent && !It->second.Duplicates.isZero())) {
4125 Cut.reset();
4126 break;
4127 }
4128 Cut = ReductionCut{.Src: S, .Elts: It->second.Lanes};
4129 }
4130 if (!Cut) {
4131 for (Value *V : Nodes) {
4132 if (!isa<BinaryOperator>(Val: V) && !isa<MinMaxIntrinsic>(Val: V))
4133 continue;
4134 auto It = Demands.find(Val: V);
4135 if (It == Demands.end() || !It->second.Lanes.isAllOnes())
4136 continue;
4137 if (!IsIdempotent && !It->second.Duplicates.isZero())
4138 continue;
4139 if (!CoversChain(V))
4140 continue;
4141 Cut = ReductionCut{.Src: V, .Elts: It->second.Lanes};
4142 break;
4143 }
4144 }
4145 // Reducing one lane is just an extract and can refold forever.
4146 if (!Cut || Cut->Elts.popcount() < 2)
4147 return false;
4148
4149 Intrinsic::ID ReducedOp =
4150 (CommonCallOp ? getMinMaxReductionIntrinsicID(IID: *CommonCallOp)
4151 : getReductionForBinop(Opc: *CommonBinOp));
4152 if (!ReducedOp)
4153 return false;
4154
4155 InstructionCost OrigCost = 0;
4156 for (Value *V : Nodes)
4157 OrigCost += TTI.getInstructionCost(U: cast<Instruction>(Val: V), CostKind);
4158
4159 auto *SrcVT = cast<FixedVectorType>(Val: Cut->Src->getType());
4160 bool IsPartialReduction = !Cut->Elts.isAllOnes();
4161 FixedVectorType *ReduceVecTy =
4162 IsPartialReduction
4163 ? FixedVectorType::get(ElementType: FVT->getElementType(), NumElts: Cut->Elts.popcount())
4164 : SrcVT;
4165
4166 SmallVector<int> ExtractMask;
4167 InstructionCost NewCost = 0;
4168 if (IsPartialReduction) {
4169 for (unsigned I = 0, E = Cut->Elts.getBitWidth(); I != E; ++I)
4170 if (Cut->Elts[I])
4171 ExtractMask.push_back(Elt: I);
4172 unsigned SubIdx = 0, SubLen;
4173 auto SK = Cut->Elts.isShiftedMask(MaskIdx&: SubIdx, MaskLen&: SubLen)
4174 ? TargetTransformInfo::SK_ExtractSubvector
4175 : TargetTransformInfo::SK_PermuteSingleSrc;
4176 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: ReduceVecTy, SrcTy: SrcVT, Mask: ExtractMask, CostKind,
4177 Index: SubIdx, SubTp: ReduceVecTy);
4178 }
4179
4180 IntrinsicCostAttributes ICA(
4181 ReducedOp, ReduceVecTy->getElementType(),
4182 IsFloatReduction
4183 ? SmallVector<Type *, 2>{ReduceVecTy->getElementType(), ReduceVecTy}
4184 : SmallVector<Type *, 2>{ReduceVecTy},
4185 IsFloatReduction ? CommonFMF : FastMathFlags());
4186 NewCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
4187
4188 LLVM_DEBUG(dbgs() << "Found reduction shuffle chain: " << I << "\n OldCost : "
4189 << OrigCost << " vs NewCost: " << NewCost << "\n");
4190
4191 if (!OrigCost.isValid() || !NewCost.isValid())
4192 return false;
4193
4194 if (VecOpEE->hasOneUse() ? (NewCost > OrigCost) : (NewCost >= OrigCost))
4195 return false;
4196
4197 Value *ReduceInput = Cut->Src;
4198 if (IsPartialReduction)
4199 ReduceInput = Builder.CreateShuffleVector(V: Cut->Src, Mask: ExtractMask);
4200
4201 Value *ReducedResult;
4202 if (IsFloatReduction) {
4203 Value *Identity = ConstantExpr::getBinOpIdentity(
4204 Opcode: *CommonBinOp, Ty: ReduceVecTy->getElementType(), /*AllowRHSConstant=*/false,
4205 NSZ: CommonFMF.noSignedZeros());
4206 ReducedResult = Builder.CreateIntrinsic(ID: ReducedOp, OverloadTypes: {ReduceVecTy},
4207 Args: {Identity, ReduceInput}, FMFSource: CommonFMF);
4208 } else {
4209 ReducedResult =
4210 Builder.CreateIntrinsic(ID: ReducedOp, OverloadTypes: {ReduceVecTy}, Args: {ReduceInput});
4211 }
4212 replaceValue(Old&: I, New&: *ReducedResult);
4213
4214 return true;
4215}
4216
4217/// Determine if its more efficient to fold:
4218/// reduce(trunc(x)) -> trunc(reduce(x)).
4219/// reduce(sext(x)) -> sext(reduce(x)).
4220/// reduce(zext(x)) -> zext(reduce(x)).
4221bool VectorCombine::foldCastFromReductions(Instruction &I) {
4222 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
4223 if (!II)
4224 return false;
4225
4226 bool TruncOnly = false;
4227 Intrinsic::ID IID = II->getIntrinsicID();
4228 switch (IID) {
4229 case Intrinsic::vector_reduce_add:
4230 case Intrinsic::vector_reduce_mul:
4231 TruncOnly = true;
4232 break;
4233 case Intrinsic::vector_reduce_and:
4234 case Intrinsic::vector_reduce_or:
4235 case Intrinsic::vector_reduce_xor:
4236 break;
4237 default:
4238 return false;
4239 }
4240
4241 unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
4242 Value *ReductionSrc = I.getOperand(i: 0);
4243
4244 Value *Src;
4245 if (!match(V: ReductionSrc, P: m_OneUse(SubPattern: m_Trunc(Op: m_Value(V&: Src)))) &&
4246 (TruncOnly || !match(V: ReductionSrc, P: m_OneUse(SubPattern: m_ZExtOrSExt(Op: m_Value(V&: Src))))))
4247 return false;
4248
4249 auto CastOpc =
4250 (Instruction::CastOps)cast<Instruction>(Val: ReductionSrc)->getOpcode();
4251
4252 auto *SrcTy = cast<VectorType>(Val: Src->getType());
4253 auto *ReductionSrcTy = cast<VectorType>(Val: ReductionSrc->getType());
4254 Type *ResultTy = I.getType();
4255
4256 InstructionCost OldCost = TTI.getArithmeticReductionCost(
4257 Opcode: ReductionOpc, Ty: ReductionSrcTy, FMF: std::nullopt, CostKind);
4258 OldCost += TTI.getCastInstrCost(Opcode: CastOpc, Dst: ReductionSrcTy, Src: SrcTy,
4259 CCH: TTI::CastContextHint::None, CostKind,
4260 I: cast<CastInst>(Val: ReductionSrc));
4261 InstructionCost NewCost =
4262 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: SrcTy, FMF: std::nullopt,
4263 CostKind) +
4264 TTI.getCastInstrCost(Opcode: CastOpc, Dst: ResultTy, Src: ReductionSrcTy->getScalarType(),
4265 CCH: TTI::CastContextHint::None, CostKind);
4266
4267 if (OldCost <= NewCost || !NewCost.isValid())
4268 return false;
4269
4270 Value *NewReduction = Builder.CreateIntrinsic(RetTy: SrcTy->getScalarType(),
4271 ID: II->getIntrinsicID(), Args: {Src});
4272 Value *NewCast = Builder.CreateCast(Op: CastOpc, V: NewReduction, DestTy: ResultTy);
4273 replaceValue(Old&: I, New&: *NewCast);
4274 return true;
4275}
4276
4277/// Fold:
4278/// icmp pred (reduce.{add,or,and,umax,umin}(signbit_extract(x))), C
4279/// into:
4280/// icmp sgt/slt (reduce.{or,umax,and,umin}(x)), -1/0
4281///
4282/// Sign-bit reductions produce values with known semantics:
4283/// - reduce.{or,umax}: 0 if no element is negative, 1 if any is
4284/// - reduce.{and,umin}: 1 if all elements are negative, 0 if any isn't
4285/// - reduce.add: count of negative elements (0 to NumElts)
4286///
4287/// Both lshr and ashr are supported:
4288/// - lshr produces 0 or 1, so reduce.add range is [0, N]
4289/// - ashr produces 0 or -1, so reduce.add range is [-N, 0]
4290///
4291/// The fold generalizes to multiple source vectors combined with the same
4292/// operation as the reduction. For example:
4293/// reduce.or(or(shr A, shr B)) conceptually extends the vector
4294/// For reduce.add, this changes the count to M*N where M is the number of
4295/// source vectors.
4296///
4297/// We transform to a direct sign check on the original vector using
4298/// reduce.{or,umax} or reduce.{and,umin}.
4299///
4300/// In spirit, it's similar to foldSignBitCheck in InstCombine.
4301bool VectorCombine::foldSignBitReductionCmp(Instruction &I) {
4302 CmpPredicate Pred;
4303 IntrinsicInst *ReduceOp;
4304 const APInt *CmpVal;
4305 if (!match(V: &I,
4306 P: m_ICmp(Pred, L: m_OneUse(SubPattern: m_AnyIntrinsic(I&: ReduceOp)), R: m_APInt(Res&: CmpVal))))
4307 return false;
4308
4309 Intrinsic::ID OrigIID = ReduceOp->getIntrinsicID();
4310 switch (OrigIID) {
4311 case Intrinsic::vector_reduce_or:
4312 case Intrinsic::vector_reduce_umax:
4313 case Intrinsic::vector_reduce_and:
4314 case Intrinsic::vector_reduce_umin:
4315 case Intrinsic::vector_reduce_add:
4316 break;
4317 default:
4318 return false;
4319 }
4320
4321 Value *ReductionSrc = ReduceOp->getArgOperand(i: 0);
4322 auto *VecTy = dyn_cast<FixedVectorType>(Val: ReductionSrc->getType());
4323 if (!VecTy)
4324 return false;
4325
4326 unsigned BitWidth = VecTy->getScalarSizeInBits();
4327 if (BitWidth == 1)
4328 return false;
4329
4330 unsigned NumElts = VecTy->getNumElements();
4331
4332 // Determine the expected tree opcode for multi-vector patterns.
4333 // The tree opcode must match the reduction's underlying operation.
4334 //
4335 // TODO: for pairs of equivalent operators, we should match both,
4336 // not only the most common.
4337 Instruction::BinaryOps TreeOpcode;
4338 switch (OrigIID) {
4339 case Intrinsic::vector_reduce_or:
4340 case Intrinsic::vector_reduce_umax:
4341 TreeOpcode = Instruction::Or;
4342 break;
4343 case Intrinsic::vector_reduce_and:
4344 case Intrinsic::vector_reduce_umin:
4345 TreeOpcode = Instruction::And;
4346 break;
4347 case Intrinsic::vector_reduce_add:
4348 TreeOpcode = Instruction::Add;
4349 break;
4350 default:
4351 llvm_unreachable("Unexpected intrinsic");
4352 }
4353
4354 // Collect sign-bit extraction leaves from an associative tree of TreeOpcode.
4355 // The tree conceptually extends the vector being reduced.
4356 SmallVector<Value *, 8> Worklist;
4357 SmallVector<Value *, 8> Sources; // Original vectors (X in shr X, BW-1)
4358 Worklist.push_back(Elt: ReductionSrc);
4359 std::optional<bool> IsAShr;
4360 constexpr unsigned MaxSources = 8;
4361
4362 // Calculate old cost: all shifts + tree ops + reduction
4363 InstructionCost OldCost = TTI.getInstructionCost(U: ReduceOp, CostKind);
4364
4365 while (!Worklist.empty() && Worklist.size() <= MaxSources &&
4366 Sources.size() <= MaxSources) {
4367 Value *V = Worklist.pop_back_val();
4368
4369 // Try to match sign-bit extraction: shr X, (bitwidth-1)
4370 Value *X;
4371 if (match(V, P: m_OneUse(SubPattern: m_Shr(L: m_Value(V&: X), R: m_SpecificInt(V: BitWidth - 1))))) {
4372 auto *Shr = cast<Instruction>(Val: V);
4373
4374 // All shifts must be the same type (all lshr or all ashr)
4375 bool ThisIsAShr = Shr->getOpcode() == Instruction::AShr;
4376 if (!IsAShr)
4377 IsAShr = ThisIsAShr;
4378 else if (*IsAShr != ThisIsAShr)
4379 return false;
4380
4381 Sources.push_back(Elt: X);
4382
4383 // As part of the fold, we remove all of the shifts, so we need to keep
4384 // track of their costs.
4385 OldCost += TTI.getInstructionCost(U: Shr, CostKind);
4386
4387 continue;
4388 }
4389
4390 // Try to extend through a tree node of the expected opcode
4391 Value *A, *B;
4392 if (!match(V, P: m_OneUse(SubPattern: m_BinOp(Opcode: TreeOpcode, L: m_Value(V&: A), R: m_Value(V&: B)))))
4393 return false;
4394
4395 // We are potentially replacing these operations as well, so we add them
4396 // to the costs.
4397 OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: V), CostKind);
4398
4399 Worklist.push_back(Elt: A);
4400 Worklist.push_back(Elt: B);
4401 }
4402
4403 // Must have at least one source and not exceed limit
4404 if (Sources.empty() || Sources.size() > MaxSources ||
4405 Worklist.size() > MaxSources || !IsAShr)
4406 return false;
4407
4408 unsigned NumSources = Sources.size();
4409
4410 // For reduce.add, the total count must fit as a signed integer.
4411 // Range is [0, M*N] for lshr or [-M*N, 0] for ashr.
4412 if (OrigIID == Intrinsic::vector_reduce_add &&
4413 !isIntN(N: BitWidth, x: NumSources * NumElts))
4414 return false;
4415
4416 // Compute the boundary value when all elements are negative:
4417 // - Per-element contribution: 1 for lshr, -1 for ashr
4418 // - For add: M*N (total elements across all sources); for others: just 1
4419 unsigned Count =
4420 (OrigIID == Intrinsic::vector_reduce_add) ? NumSources * NumElts : 1;
4421 APInt NegativeVal(CmpVal->getBitWidth(), Count);
4422 if (*IsAShr)
4423 NegativeVal.negate();
4424
4425 // Range is [min(0, AllNegVal), max(0, AllNegVal)]
4426 APInt Zero = APInt::getZero(numBits: CmpVal->getBitWidth());
4427 APInt RangeLow = APIntOps::smin(A: Zero, B: NegativeVal);
4428 APInt RangeHigh = APIntOps::smax(A: Zero, B: NegativeVal);
4429
4430 // Determine comparison semantics:
4431 // - IsEq: true for equality test, false for inequality
4432 // - TestsNegative: true if testing against AllNegVal, false for zero
4433 //
4434 // In addition to EQ/NE against 0 or AllNegVal, we support inequalities
4435 // that fold to boundary tests given the narrow value range:
4436 // < RangeHigh -> != RangeHigh
4437 // > RangeHigh-1 -> == RangeHigh
4438 // > RangeLow -> != RangeLow
4439 // < RangeLow+1 -> == RangeLow
4440 //
4441 // For inequalities, we work with signed predicates only. Unsigned predicates
4442 // are canonicalized to signed when the range is non-negative (where they are
4443 // equivalent). When the range includes negative values, unsigned predicates
4444 // would have different semantics due to wrap-around, so we reject them.
4445 if (!ICmpInst::isEquality(P: Pred) && !ICmpInst::isSigned(Pred)) {
4446 if (RangeLow.isNegative())
4447 return false;
4448 Pred = ICmpInst::getSignedPredicate(Pred);
4449 }
4450
4451 bool IsEq;
4452 bool TestsNegative;
4453 if (ICmpInst::isEquality(P: Pred)) {
4454 if (CmpVal->isZero()) {
4455 TestsNegative = false;
4456 } else if (*CmpVal == NegativeVal) {
4457 TestsNegative = true;
4458 } else {
4459 return false;
4460 }
4461 IsEq = Pred == ICmpInst::ICMP_EQ;
4462 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeHigh) {
4463 IsEq = false;
4464 TestsNegative = (RangeHigh == NegativeVal);
4465 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeHigh - 1) {
4466 IsEq = true;
4467 TestsNegative = (RangeHigh == NegativeVal);
4468 } else if (Pred == ICmpInst::ICMP_SGT && *CmpVal == RangeLow) {
4469 IsEq = false;
4470 TestsNegative = (RangeLow == NegativeVal);
4471 } else if (Pred == ICmpInst::ICMP_SLT && *CmpVal == RangeLow + 1) {
4472 IsEq = true;
4473 TestsNegative = (RangeLow == NegativeVal);
4474 } else {
4475 return false;
4476 }
4477
4478 // For this fold we support four types of checks:
4479 //
4480 // 1. All lanes are negative - AllNeg
4481 // 2. All lanes are non-negative - AllNonNeg
4482 // 3. At least one negative lane - AnyNeg
4483 // 4. At least one non-negative lane - AnyNonNeg
4484 //
4485 // For each case, we can generate the following code:
4486 //
4487 // 1. AllNeg - reduce.and/umin(X) < 0
4488 // 2. AllNonNeg - reduce.or/umax(X) > -1
4489 // 3. AnyNeg - reduce.or/umax(X) < 0
4490 // 4. AnyNonNeg - reduce.and/umin(X) > -1
4491 //
4492 // The table below shows the aggregation of all supported cases
4493 // using these four cases.
4494 //
4495 // Reduction | == 0 | != 0 | == MAX | != MAX
4496 // ------------+-----------+-----------+-----------+-----------
4497 // or/umax | AllNonNeg | AnyNeg | AnyNeg | AllNonNeg
4498 // and/umin | AnyNonNeg | AllNeg | AllNeg | AnyNonNeg
4499 // add | AllNonNeg | AnyNeg | AllNeg | AnyNonNeg
4500 //
4501 // NOTE: MAX = 1 for or/and/umax/umin, and the vector size N for add
4502 //
4503 // For easier codegen and check inversion, we use the following encoding:
4504 //
4505 // 1. Bit-3 === requires or/umax (1) or and/umin (0) check
4506 // 2. Bit-2 === checks < 0 (1) or > -1 (0)
4507 // 3. Bit-1 === universal (1) or existential (0) check
4508 //
4509 // AnyNeg = 0b110: uses or/umax, checks negative, any-check
4510 // AllNonNeg = 0b101: uses or/umax, checks non-neg, all-check
4511 // AnyNonNeg = 0b000: uses and/umin, checks non-neg, any-check
4512 // AllNeg = 0b011: uses and/umin, checks negative, all-check
4513 //
4514 // XOR with 0b011 inverts the check (swaps all/any and neg/non-neg).
4515 //
4516 enum CheckKind : unsigned {
4517 AnyNonNeg = 0b000,
4518 AllNeg = 0b011,
4519 AllNonNeg = 0b101,
4520 AnyNeg = 0b110,
4521 };
4522 // Return true if we fold this check into or/umax and false for and/umin
4523 auto RequiresOr = [](CheckKind C) -> bool { return C & 0b100; };
4524 // Return true if we should check if result is negative and false otherwise
4525 auto IsNegativeCheck = [](CheckKind C) -> bool { return C & 0b010; };
4526 // Logically invert the check
4527 auto Invert = [](CheckKind C) { return CheckKind(C ^ 0b011); };
4528
4529 CheckKind Base;
4530 switch (OrigIID) {
4531 case Intrinsic::vector_reduce_or:
4532 case Intrinsic::vector_reduce_umax:
4533 Base = TestsNegative ? AnyNeg : AllNonNeg;
4534 break;
4535 case Intrinsic::vector_reduce_and:
4536 case Intrinsic::vector_reduce_umin:
4537 Base = TestsNegative ? AllNeg : AnyNonNeg;
4538 break;
4539 case Intrinsic::vector_reduce_add:
4540 Base = TestsNegative ? AllNeg : AllNonNeg;
4541 break;
4542 default:
4543 llvm_unreachable("Unexpected intrinsic");
4544 }
4545
4546 CheckKind Check = IsEq ? Base : Invert(Base);
4547
4548 auto PickCheaper = [&](Intrinsic::ID Arith, Intrinsic::ID MinMax) {
4549 InstructionCost ArithCost =
4550 TTI.getArithmeticReductionCost(Opcode: getArithmeticReductionInstruction(RdxID: Arith),
4551 Ty: VecTy, FMF: std::nullopt, CostKind);
4552 InstructionCost MinMaxCost =
4553 TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: MinMax), Ty: VecTy,
4554 FMF: FastMathFlags(), CostKind);
4555 return ArithCost <= MinMaxCost ? std::make_pair(x&: Arith, y&: ArithCost)
4556 : std::make_pair(x&: MinMax, y&: MinMaxCost);
4557 };
4558
4559 // Choose output reduction based on encoding's MSB
4560 auto [NewIID, NewCost] = RequiresOr(Check)
4561 ? PickCheaper(Intrinsic::vector_reduce_or,
4562 Intrinsic::vector_reduce_umax)
4563 : PickCheaper(Intrinsic::vector_reduce_and,
4564 Intrinsic::vector_reduce_umin);
4565
4566 // Add cost of combining multiple sources with or/and
4567 if (NumSources > 1) {
4568 unsigned CombineOpc =
4569 RequiresOr(Check) ? Instruction::Or : Instruction::And;
4570 NewCost += TTI.getArithmeticInstrCost(Opcode: CombineOpc, Ty: VecTy, CostKind) *
4571 (NumSources - 1);
4572 }
4573
4574 LLVM_DEBUG(dbgs() << "Found sign-bit reduction cmp: " << I << "\n OldCost: "
4575 << OldCost << " vs NewCost: " << NewCost << "\n");
4576
4577 if (NewCost > OldCost)
4578 return false;
4579
4580 // Generate the combined input and reduction
4581 Builder.SetInsertPoint(&I);
4582 Type *ScalarTy = VecTy->getScalarType();
4583
4584 Value *Input;
4585 if (NumSources == 1) {
4586 Input = Sources[0];
4587 } else {
4588 // Combine sources with or/and based on check type
4589 Input = RequiresOr(Check) ? Builder.CreateOr(Ops: Sources)
4590 : Builder.CreateAnd(Ops: Sources);
4591 }
4592
4593 Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: NewIID, Args: {Input});
4594 Value *NewCmp = IsNegativeCheck(Check) ? Builder.CreateIsNeg(Arg: NewReduce)
4595 : Builder.CreateIsNotNeg(Arg: NewReduce);
4596 replaceValue(Old&: I, New&: *NewCmp);
4597 return true;
4598}
4599
4600/// Fold a zero test of reduce.or or reduce.umax into a boolean reduction.
4601///
4602/// Vectorization may produce IR that compares the result of a scalar reduction
4603/// with zero. Depending on the target, lowering a reduction and a scalar
4604/// comparison separately can cost more than reducing lane-wise comparison
4605/// results. This fold creates the latter form only when it is not costlier.
4606///
4607/// Before:
4608/// %r = call iT @llvm.vector.reduce.or.vNiT(<N x iT> %x)
4609/// %cmp = icmp ne iT %r, 0
4610///
4611/// After:
4612/// %lane.cmp = icmp ne <N x iT> %x, zeroinitializer
4613/// %cmp = call i1 @llvm.vector.reduce.or.vNi1(<N x i1> %lane.cmp)
4614///
4615/// `reduce.or` and `reduce.umax` are non-zero when at least one lane is
4616/// non-zero. Therefore, `icmp ne` uses the existential `reduce.or` test.
4617/// Conversely, `icmp eq` must check that every lane is zero, so it uses the
4618/// universal `reduce.and` test.
4619///
4620/// Before:
4621/// %r = call iT @llvm.vector.reduce.umax.vNiT(<N x iT> %x)
4622/// %cmp = icmp eq iT %r, 0
4623///
4624/// After:
4625/// %lane.cmp = icmp eq <N x iT> %x, zeroinitializer
4626/// %cmp = call i1 @llvm.vector.reduce.and.vNi1(<N x i1> %lane.cmp)
4627bool VectorCombine::foldReductionZeroTest(Instruction &I) {
4628 CmpPredicate Pred;
4629 Value *Op;
4630
4631 if (!match(V: &I, P: m_c_ICmp(Pred, L: m_Value(V&: Op), R: m_Zero())) ||
4632 !ICmpInst::isEquality(P: Pred))
4633 return false;
4634
4635 auto *II = dyn_cast<IntrinsicInst>(Val: Op);
4636 if (!II || !II->hasOneUse())
4637 return false;
4638
4639 auto ReduceID = II->getIntrinsicID();
4640 if (ReduceID != Intrinsic::vector_reduce_or &&
4641 ReduceID != Intrinsic::vector_reduce_umax)
4642 return false;
4643
4644 Value *Vec = II->getArgOperand(i: 0);
4645 auto *VecTy = dyn_cast<FixedVectorType>(Val: Vec->getType());
4646 if (!VecTy || !VecTy->getElementType()->isIntegerTy())
4647 return false;
4648
4649 // Map the scalar zero test to an any-lane or all-lane boolean reduction.
4650 Intrinsic::ID NewIID = (Pred == ICmpInst::ICMP_NE)
4651 ? Intrinsic::vector_reduce_or
4652 : Intrinsic::vector_reduce_and;
4653
4654 // This is not an unconditional canonicalization: compare the cost of the
4655 // original scalar reduction and compare with the vector compare and i1
4656 // reduction replacement for both reduce.or and reduce.umax.
4657 InstructionCost OldCost = TTI.getInstructionCost(U: II, CostKind) +
4658 TTI.getInstructionCost(U: &I, CostKind);
4659
4660 auto *CmpTy = cast<VectorType>(Val: CmpInst::makeCmpResultType(opnd_type: VecTy));
4661 InstructionCost NewCost =
4662 TTI.getCmpSelInstrCost(Opcode: Instruction::ICmp, ValTy: VecTy, CondTy: CmpTy, VecPred: Pred, CostKind);
4663 NewCost += TTI.getArithmeticReductionCost(
4664 Opcode: getArithmeticReductionInstruction(RdxID: NewIID), Ty: CmpTy, FMF: std::nullopt, CostKind);
4665
4666 LLVM_DEBUG(dbgs() << "Found a reduction zero test: " << I << "\n OldCost: "
4667 << OldCost << " vs NewCost: " << NewCost << "\n");
4668
4669 if (!OldCost.isValid() || !NewCost.isValid() || NewCost > OldCost)
4670 return false;
4671
4672 Builder.SetInsertPoint(&I);
4673 Value *NewCmp = Builder.CreateICmp(P: Pred, LHS: Vec, RHS: Constant::getNullValue(Ty: VecTy));
4674 Value *NewReduce = Builder.CreateIntrinsic(ID: NewIID, OverloadTypes: {CmpTy}, Args: {NewCmp});
4675 replaceValue(Old&: I, New&: *NewReduce);
4676 return true;
4677}
4678
4679/// vector.reduce.OP f(X_i) == 0 -> vector.reduce.OP X_i == 0
4680///
4681/// We can prove it for cases when:
4682///
4683/// 1. OP X_i == 0 <=> \forall i \in [1, N] X_i == 0
4684/// 1'. OP X_i == 0 <=> \exists j \in [1, N] X_j == 0
4685/// 2. f(x) == 0 <=> x == 0
4686///
4687/// From 1 and 2 (or 1' and 2), we can infer that
4688///
4689/// OP f(X_i) == 0 <=> OP X_i == 0.
4690///
4691/// (1)
4692/// OP f(X_i) == 0 <=> \forall i \in [1, N] f(X_i) == 0
4693/// (2)
4694/// <=> \forall i \in [1, N] X_i == 0
4695/// (1)
4696/// <=> OP(X_i) == 0
4697///
4698/// For some of the OP's and f's, we need to have domain constraints on X
4699/// to ensure properties 1 (or 1') and 2.
4700bool VectorCombine::foldICmpEqZeroVectorReduce(Instruction &I) {
4701 CmpPredicate Pred;
4702 Value *Op;
4703 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: Op), R: m_Zero())) ||
4704 !ICmpInst::isEquality(P: Pred))
4705 return false;
4706
4707 auto *II = dyn_cast<IntrinsicInst>(Val: Op);
4708 if (!II)
4709 return false;
4710
4711 switch (II->getIntrinsicID()) {
4712 case Intrinsic::vector_reduce_add:
4713 case Intrinsic::vector_reduce_or:
4714 case Intrinsic::vector_reduce_umin:
4715 case Intrinsic::vector_reduce_umax:
4716 case Intrinsic::vector_reduce_smin:
4717 case Intrinsic::vector_reduce_smax:
4718 break;
4719 default:
4720 return false;
4721 }
4722
4723 Value *InnerOp = II->getArgOperand(i: 0);
4724
4725 // TODO: fixed vector type might be too restrictive
4726 if (!II->hasOneUse() || !isa<FixedVectorType>(Val: InnerOp->getType()))
4727 return false;
4728
4729 Value *X = nullptr;
4730
4731 // Check for zero-preserving operations where f(x) = 0 <=> x = 0
4732 //
4733 // 1. f(x) = shl nuw x, y for arbitrary y
4734 // 2. f(x) = mul nuw x, c for defined c != 0
4735 // 3. f(x) = zext x
4736 // 4. f(x) = sext x
4737 // 5. f(x) = neg x
4738 //
4739 if (!(match(V: InnerOp, P: m_NUWShl(L: m_Value(V&: X), R: m_Value())) || // Case 1
4740 match(V: InnerOp, P: m_NUWMul(L: m_Value(V&: X), R: m_NonZeroInt())) || // Case 2
4741 match(V: InnerOp, P: m_ZExt(Op: m_Value(V&: X))) || // Case 3
4742 match(V: InnerOp, P: m_SExt(Op: m_Value(V&: X))) || // Case 4
4743 match(V: InnerOp, P: m_Neg(V: m_Value(V&: X))) // Case 5
4744 ))
4745 return false;
4746
4747 SimplifyQuery S = SQ.getWithInstruction(I: &I);
4748 auto *XTy = cast<FixedVectorType>(Val: X->getType());
4749
4750 // Check for domain constraints for all supported reductions.
4751 //
4752 // a. OR X_i - has property 1 for every X
4753 // b. UMAX X_i - has property 1 for every X
4754 // c. UMIN X_i - has property 1' for every X
4755 // d. SMAX X_i - has property 1 for X >= 0
4756 // e. SMIN X_i - has property 1' for X >= 0
4757 // f. ADD X_i - has property 1 for X >= 0 && ADD X_i doesn't sign wrap
4758 //
4759 // In order for the proof to work, we need 1 (or 1') to be true for both
4760 // OP f(X_i) and OP X_i and that's why below we check constraints twice.
4761 //
4762 // NOTE: ADD X_i holds property 1 for a mirror case as well, i.e. when
4763 // X <= 0 && ADD X_i doesn't sign wrap. However, due to the nature
4764 // of known bits, we can't reasonably hold knowledge of "either 0
4765 // or negative".
4766 switch (II->getIntrinsicID()) {
4767 case Intrinsic::vector_reduce_add: {
4768 // We need to check that both X_i and f(X_i) have enough leading
4769 // zeros to not overflow.
4770 KnownBits KnownX = computeKnownBits(V: X, Q: S);
4771 KnownBits KnownFX = computeKnownBits(V: InnerOp, Q: S);
4772 unsigned NumElems = XTy->getNumElements();
4773 // Adding N elements loses at most ceil(log2(N)) leading bits.
4774 unsigned LostBits = Log2_32_Ceil(Value: NumElems);
4775 unsigned LeadingZerosX = KnownX.countMinLeadingZeros();
4776 unsigned LeadingZerosFX = KnownFX.countMinLeadingZeros();
4777 // Need at least one leading zero left after summation to ensure no overflow
4778 if (LeadingZerosX <= LostBits || LeadingZerosFX <= LostBits)
4779 return false;
4780
4781 // We are not checking whether X or f(X) are positive explicitly because
4782 // we implicitly checked for it when we checked if both cases have enough
4783 // leading zeros to not wrap addition.
4784 break;
4785 }
4786 case Intrinsic::vector_reduce_smin:
4787 case Intrinsic::vector_reduce_smax:
4788 // Check whether X >= 0 and f(X) >= 0
4789 if (!isKnownNonNegative(V: InnerOp, SQ: S) || !isKnownNonNegative(V: X, SQ: S))
4790 return false;
4791
4792 break;
4793 default:
4794 break;
4795 };
4796
4797 LLVM_DEBUG(dbgs() << "Found a reduction to 0 comparison with removable op: "
4798 << *II << "\n");
4799
4800 // For zext/sext, check if the transform is profitable using cost model.
4801 // For other operations (shl, mul, neg), we're removing an instruction
4802 // while keeping the same reduction type, so it's always profitable.
4803 if (isa<ZExtInst>(Val: InnerOp) || isa<SExtInst>(Val: InnerOp)) {
4804 auto *FXTy = cast<FixedVectorType>(Val: InnerOp->getType());
4805 Intrinsic::ID IID = II->getIntrinsicID();
4806
4807 InstructionCost ExtCost = TTI.getCastInstrCost(
4808 Opcode: cast<CastInst>(Val: InnerOp)->getOpcode(), Dst: FXTy, Src: XTy,
4809 CCH: TTI::CastContextHint::None, CostKind, I: cast<CastInst>(Val: InnerOp));
4810
4811 InstructionCost OldReduceCost, NewReduceCost;
4812 switch (IID) {
4813 case Intrinsic::vector_reduce_add:
4814 case Intrinsic::vector_reduce_or:
4815 OldReduceCost = TTI.getArithmeticReductionCost(
4816 Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: FXTy, FMF: std::nullopt, CostKind);
4817 NewReduceCost = TTI.getArithmeticReductionCost(
4818 Opcode: getArithmeticReductionInstruction(RdxID: IID), Ty: XTy, FMF: std::nullopt, CostKind);
4819 break;
4820 case Intrinsic::vector_reduce_umin:
4821 case Intrinsic::vector_reduce_umax:
4822 case Intrinsic::vector_reduce_smin:
4823 case Intrinsic::vector_reduce_smax:
4824 OldReduceCost = TTI.getMinMaxReductionCost(
4825 IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: FXTy, FMF: FastMathFlags(), CostKind);
4826 NewReduceCost = TTI.getMinMaxReductionCost(
4827 IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: XTy, FMF: FastMathFlags(), CostKind);
4828 break;
4829 default:
4830 llvm_unreachable("Unexpected reduction");
4831 }
4832
4833 InstructionCost OldCost = OldReduceCost + ExtCost;
4834 InstructionCost NewCost =
4835 NewReduceCost + (InnerOp->hasOneUse() ? 0 : ExtCost);
4836
4837 LLVM_DEBUG(dbgs() << "Found a removable extension before reduction: "
4838 << *InnerOp << "\n OldCost: " << OldCost
4839 << " vs NewCost: " << NewCost << "\n");
4840
4841 // We consider transformation to still be potentially beneficial even
4842 // when the costs are the same because we might remove a use from f(X)
4843 // and unlock other optimizations. Equal costs would just mean that we
4844 // didn't make it worse in the worst case.
4845 if (NewCost > OldCost)
4846 return false;
4847 }
4848
4849 // Since we support zext and sext as f, we might change the scalar type
4850 // of the intrinsic.
4851 Type *Ty = XTy->getScalarType();
4852 Value *NewReduce = Builder.CreateIntrinsic(RetTy: Ty, ID: II->getIntrinsicID(), Args: {X});
4853 Value *NewCmp =
4854 Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::getNullValue(Ty));
4855 replaceValue(Old&: I, New&: *NewCmp);
4856 return true;
4857}
4858
4859/// Fold comparisons of reduce.or/reduce.and with reduce.umax/reduce.umin
4860/// based on cost, preserving the comparison semantics.
4861///
4862/// We use two fundamental properties for each pair:
4863///
4864/// 1. or(X) == 0 <=> umax(X) == 0
4865/// 2. or(X) == 1 <=> umax(X) == 1
4866/// 3. sign(or(X)) == sign(umax(X))
4867///
4868/// 1. and(X) == -1 <=> umin(X) == -1
4869/// 2. and(X) == -2 <=> umin(X) == -2
4870/// 3. sign(and(X)) == sign(umin(X))
4871///
4872/// From these we can infer the following transformations:
4873/// a. or(X) ==/!= 0 <-> umax(X) ==/!= 0
4874/// b. or(X) s< 0 <-> umax(X) s< 0
4875/// c. or(X) s> -1 <-> umax(X) s> -1
4876/// d. or(X) s< 1 <-> umax(X) s< 1
4877/// e. or(X) ==/!= 1 <-> umax(X) ==/!= 1
4878/// f. or(X) s< 2 <-> umax(X) s< 2
4879/// g. and(X) ==/!= -1 <-> umin(X) ==/!= -1
4880/// h. and(X) s< 0 <-> umin(X) s< 0
4881/// i. and(X) s> -1 <-> umin(X) s> -1
4882/// j. and(X) s> -2 <-> umin(X) s> -2
4883/// k. and(X) ==/!= -2 <-> umin(X) ==/!= -2
4884/// l. and(X) s> -3 <-> umin(X) s> -3
4885///
4886bool VectorCombine::foldEquivalentReductionCmp(Instruction &I) {
4887 CmpPredicate Pred;
4888 Value *ReduceOp;
4889 const APInt *CmpVal;
4890 if (!match(V: &I, P: m_ICmp(Pred, L: m_Value(V&: ReduceOp), R: m_APInt(Res&: CmpVal))))
4891 return false;
4892
4893 auto *II = dyn_cast<IntrinsicInst>(Val: ReduceOp);
4894 if (!II || !II->hasOneUse())
4895 return false;
4896
4897 const auto IsValidOrUmaxCmp = [&]() {
4898 // or === umax for i1
4899 if (CmpVal->getBitWidth() == 1)
4900 return true;
4901
4902 // Cases a and e
4903 bool IsEquality =
4904 (CmpVal->isZero() || CmpVal->isOne()) && ICmpInst::isEquality(P: Pred);
4905 // Case c
4906 bool IsPositive = CmpVal->isAllOnes() && Pred == ICmpInst::ICMP_SGT;
4907 // Cases b, d, and f
4908 bool IsNegative = (CmpVal->isZero() || CmpVal->isOne() || *CmpVal == 2) &&
4909 Pred == ICmpInst::ICMP_SLT;
4910 return IsEquality || IsPositive || IsNegative;
4911 };
4912
4913 const auto IsValidAndUminCmp = [&]() {
4914 // and === umin for i1
4915 if (CmpVal->getBitWidth() == 1)
4916 return true;
4917
4918 const auto LeadingOnes = CmpVal->countl_one();
4919
4920 // Cases g and k
4921 bool IsEquality =
4922 (CmpVal->isAllOnes() || LeadingOnes + 1 == CmpVal->getBitWidth()) &&
4923 ICmpInst::isEquality(P: Pred);
4924 // Case h
4925 bool IsNegative = CmpVal->isZero() && Pred == ICmpInst::ICMP_SLT;
4926 // Cases i, j, and l
4927 bool IsPositive =
4928 // if the number has at least N - 2 leading ones
4929 // and the two LSBs are:
4930 // - 1 x 1 -> -1
4931 // - 1 x 0 -> -2
4932 // - 0 x 1 -> -3
4933 LeadingOnes + 2 >= CmpVal->getBitWidth() &&
4934 ((*CmpVal)[0] || (*CmpVal)[1]) && Pred == ICmpInst::ICMP_SGT;
4935 return IsEquality || IsNegative || IsPositive;
4936 };
4937
4938 Intrinsic::ID OriginalIID = II->getIntrinsicID();
4939 Intrinsic::ID AlternativeIID;
4940
4941 // Check if this is a valid comparison pattern and determine the alternate
4942 // reduction intrinsic.
4943 switch (OriginalIID) {
4944 case Intrinsic::vector_reduce_or:
4945 if (!IsValidOrUmaxCmp())
4946 return false;
4947 AlternativeIID = Intrinsic::vector_reduce_umax;
4948 break;
4949 case Intrinsic::vector_reduce_umax:
4950 if (!IsValidOrUmaxCmp())
4951 return false;
4952 AlternativeIID = Intrinsic::vector_reduce_or;
4953 break;
4954 case Intrinsic::vector_reduce_and:
4955 if (!IsValidAndUminCmp())
4956 return false;
4957 AlternativeIID = Intrinsic::vector_reduce_umin;
4958 break;
4959 case Intrinsic::vector_reduce_umin:
4960 if (!IsValidAndUminCmp())
4961 return false;
4962 AlternativeIID = Intrinsic::vector_reduce_and;
4963 break;
4964 default:
4965 return false;
4966 }
4967
4968 Value *X = II->getArgOperand(i: 0);
4969 auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
4970 if (!VecTy)
4971 return false;
4972
4973 const auto GetReductionCost = [&](Intrinsic::ID IID) -> InstructionCost {
4974 unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
4975 if (ReductionOpc != Instruction::ICmp)
4976 return TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecTy, FMF: std::nullopt,
4977 CostKind);
4978 return TTI.getMinMaxReductionCost(IID: getMinMaxReductionIntrinsicOp(RdxID: IID), Ty: VecTy,
4979 FMF: FastMathFlags(), CostKind);
4980 };
4981
4982 InstructionCost OrigCost = GetReductionCost(OriginalIID);
4983 InstructionCost AltCost = GetReductionCost(AlternativeIID);
4984
4985 LLVM_DEBUG(dbgs() << "Found equivalent reduction cmp: " << I
4986 << "\n OrigCost: " << OrigCost
4987 << " vs AltCost: " << AltCost << "\n");
4988
4989 if (AltCost >= OrigCost)
4990 return false;
4991
4992 Builder.SetInsertPoint(&I);
4993 Type *ScalarTy = VecTy->getScalarType();
4994 Value *NewReduce = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: AlternativeIID, Args: {X});
4995 Value *NewCmp =
4996 Builder.CreateICmp(P: Pred, LHS: NewReduce, RHS: ConstantInt::get(Ty: ScalarTy, V: *CmpVal));
4997
4998 replaceValue(Old&: I, New&: *NewCmp);
4999 return true;
5000}
5001
5002/// Used by foldReduceAddCmpZero to check if we can prove that a value is
5003/// non-positive.
5004/// KnownBits cannot see sext <? x i1> as non-positive: each top bit equals a
5005/// single unknown input bit, which a per-bit lattice cannot track. The fold's
5006/// target shape is popcount-style sums of <N x i1> valid/invalid masks (e.g.
5007/// ray-intersection hits) tested for any-hit.
5008/// Previous attempts to approximate the known bits of such expressions were
5009/// using a fully recursive value tracking approach to infer a constant range
5010/// but ultimately turned to be too expensive in compile time.
5011static bool isKnownNonPositive(const Value *V, const SimplifyQuery &SQ,
5012 unsigned Depth = 0) {
5013 constexpr unsigned MaxLocalDepth = 2;
5014 if (Depth > MaxLocalDepth)
5015 return false;
5016
5017 auto NumSignBits = [&](const Value *X) {
5018 return ComputeNumSignBits(Op: X, DL: SQ.DL, AC: SQ.AC, CxtI: SQ.CxtI, DT: SQ.DT);
5019 };
5020 if (NumSignBits(V) == V->getType()->getScalarSizeInBits())
5021 return true;
5022
5023 Value *A, *B;
5024 if (match(V, P: m_Add(L: m_Value(V&: A), R: m_Value(V&: B))))
5025 return NumSignBits(A) >= 2 && NumSignBits(B) >= 2 &&
5026 isKnownNonPositive(V: A, SQ, Depth: Depth + 1) &&
5027 isKnownNonPositive(V: B, SQ, Depth: Depth + 1);
5028
5029 return computeKnownBits(V, Q: SQ).isNonPositive();
5030}
5031
5032/// Fold (icmp pred (reduce.add X), 0) to (icmp pred' (reduce.or X), 0) when X
5033/// has lanes known to all be non-negative or all non-positive, so that
5034/// sum == 0 iff every lane is 0. Falls back to reduce.umax if reduce.or is
5035/// more expensive on the target.
5036bool VectorCombine::foldReduceAddCmpZero(Instruction &I) {
5037 CmpPredicate Pred;
5038 Value *Vec;
5039 if (!match(V: &I, P: m_ICmp(Pred,
5040 L: m_OneUse(SubPattern: m_Intrinsic<Intrinsic::vector_reduce_add>(
5041 Op0: m_Value(V&: Vec))),
5042 R: m_Zero())))
5043 return false;
5044
5045 auto *VecTy = dyn_cast<FixedVectorType>(Val: Vec->getType());
5046 if (!VecTy || VecTy->getNumElements() < 2)
5047 return false;
5048
5049 SimplifyQuery Q = SQ.getWithInstruction(I: &I);
5050 bool IsNonNegative = isKnownNonNegative(V: Vec, SQ: Q);
5051 bool IsNonPositive = !IsNonNegative && isKnownNonPositive(V: Vec, SQ: Q);
5052 if (!IsNonNegative && !IsNonPositive)
5053 return false;
5054
5055 // Summing NumElts lanes can consume up to log2(NumElts) sign bits. Require
5056 // strictly more headroom than that so the sum cannot wrap to zero.
5057 unsigned NumElts = VecTy->getNumElements();
5058 unsigned NumSignBits = ComputeNumSignBits(Op: Vec, DL: *DL, AC: SQ.AC, CxtI: &I, DT: &DT);
5059 if (Log2_32(Value: NumElts) >= NumSignBits)
5060 return false;
5061
5062 ICmpInst::Predicate NewPred;
5063 switch (Pred) {
5064 case ICmpInst::ICMP_EQ:
5065 case ICmpInst::ICMP_ULE:
5066 case ICmpInst::ICMP_SLE:
5067 case ICmpInst::ICMP_SGE:
5068 NewPred = ICmpInst::ICMP_EQ;
5069 break;
5070 case ICmpInst::ICMP_NE:
5071 case ICmpInst::ICMP_UGT:
5072 case ICmpInst::ICMP_SGT:
5073 case ICmpInst::ICMP_SLT:
5074 NewPred = ICmpInst::ICMP_NE;
5075 break;
5076 default:
5077 return false;
5078 }
5079
5080 // SGT and SLE on a non-positive tree, and SLT and SGE on a non-negative
5081 // tree, are tautologies (always true or always false). Leave those to
5082 // InstCombine rather than mapping them here. Remaining signed inequalities
5083 // also need one extra sign bit so the sum cannot flip sign.
5084 if (!IsNonNegative &&
5085 (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE))
5086 return false;
5087 if (!IsNonPositive &&
5088 (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE))
5089 return false;
5090 if ((Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE ||
5091 Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) &&
5092 Log2_32(Value: NumElts) >= NumSignBits - 1)
5093 return false;
5094
5095 InstructionCost OrigCost = TTI.getArithmeticReductionCost(
5096 Opcode: Instruction::Add, Ty: VecTy, FMF: std::nullopt, CostKind);
5097 InstructionCost OrCost = TTI.getArithmeticReductionCost(
5098 Opcode: Instruction::Or, Ty: VecTy, FMF: std::nullopt, CostKind);
5099 InstructionCost UmaxCost = TTI.getMinMaxReductionCost(
5100 IID: Intrinsic::umax, Ty: VecTy, FMF: FastMathFlags(), CostKind);
5101 if (!OrCost.isValid() && !UmaxCost.isValid())
5102 return false;
5103 bool UseOr = OrCost.isValid() && (!UmaxCost.isValid() || OrCost <= UmaxCost);
5104 InstructionCost AltCost = UseOr ? OrCost : UmaxCost;
5105 if (AltCost > OrigCost)
5106 return false;
5107
5108 Builder.SetInsertPoint(&I);
5109 Value *NewReduce = UseOr ? Builder.CreateOrReduce(Src: Vec)
5110 : Builder.CreateIntrinsic(
5111 ID: Intrinsic::vector_reduce_umax, OverloadTypes: {VecTy}, Args: {Vec});
5112 Worklist.pushValue(V: NewReduce);
5113 Value *NewCmp = Builder.CreateICmp(
5114 P: NewPred, LHS: NewReduce, RHS: ConstantInt::getNullValue(Ty: VecTy->getScalarType()));
5115 replaceValue(Old&: I, New&: *NewCmp);
5116 return true;
5117}
5118
5119/// Returns true if this ShuffleVectorInst eventually feeds into a
5120/// vector reduction intrinsic (e.g., vector_reduce_add) by only following
5121/// chains of shuffles and binary operators (in any combination/order).
5122/// The search does not go deeper than the given Depth.
5123static bool feedsIntoVectorReduction(ShuffleVectorInst *SVI) {
5124 constexpr unsigned MaxVisited = 32;
5125 SmallPtrSet<Instruction *, 8> Visited;
5126 SmallVector<Instruction *, 4> WorkList;
5127 bool FoundReduction = false;
5128
5129 WorkList.push_back(Elt: SVI);
5130 while (!WorkList.empty()) {
5131 Instruction *I = WorkList.pop_back_val();
5132 for (User *U : I->users()) {
5133 auto *UI = cast<Instruction>(Val: U);
5134 if (!UI || !Visited.insert(Ptr: UI).second)
5135 continue;
5136 if (Visited.size() > MaxVisited)
5137 return false;
5138 if (auto *II = dyn_cast<IntrinsicInst>(Val: UI)) {
5139 // More than one reduction reached
5140 if (FoundReduction)
5141 return false;
5142 switch (II->getIntrinsicID()) {
5143 case Intrinsic::vector_reduce_add:
5144 case Intrinsic::vector_reduce_mul:
5145 case Intrinsic::vector_reduce_and:
5146 case Intrinsic::vector_reduce_or:
5147 case Intrinsic::vector_reduce_xor:
5148 case Intrinsic::vector_reduce_smin:
5149 case Intrinsic::vector_reduce_smax:
5150 case Intrinsic::vector_reduce_umin:
5151 case Intrinsic::vector_reduce_umax:
5152 FoundReduction = true;
5153 continue;
5154 default:
5155 return false;
5156 }
5157 }
5158
5159 if (!isa<BinaryOperator>(Val: UI) && !isa<ShuffleVectorInst>(Val: UI))
5160 return false;
5161
5162 WorkList.emplace_back(Args&: UI);
5163 }
5164 }
5165 return FoundReduction;
5166}
5167
5168/// This method looks for groups of shuffles acting on binops, of the form:
5169/// %x = shuffle ...
5170/// %y = shuffle ...
5171/// %a = binop %x, %y
5172/// %b = binop %x, %y
5173/// shuffle %a, %b, selectmask
5174/// We may, especially if the shuffle is wider than legal, be able to convert
5175/// the shuffle to a form where only parts of a and b need to be computed. On
5176/// architectures with no obvious "select" shuffle, this can reduce the total
5177/// number of operations if the target reports them as cheaper.
5178bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
5179 auto *SVI = cast<ShuffleVectorInst>(Val: &I);
5180 auto *VT = cast<FixedVectorType>(Val: I.getType());
5181 auto *Op0 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 0));
5182 auto *Op1 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 1));
5183 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
5184 VT != Op0->getType())
5185 return false;
5186
5187 auto *SVI0A = dyn_cast<Instruction>(Val: Op0->getOperand(i: 0));
5188 auto *SVI0B = dyn_cast<Instruction>(Val: Op0->getOperand(i: 1));
5189 auto *SVI1A = dyn_cast<Instruction>(Val: Op1->getOperand(i: 0));
5190 auto *SVI1B = dyn_cast<Instruction>(Val: Op1->getOperand(i: 1));
5191 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
5192 auto checkSVNonOpUses = [&](Instruction *I) {
5193 if (!I || I->getOperand(i: 0)->getType() != VT)
5194 return true;
5195 return any_of(Range: I->users(), P: [&](User *U) {
5196 return U != Op0 && U != Op1 &&
5197 !(isa<ShuffleVectorInst>(Val: U) &&
5198 (InputShuffles.contains(Ptr: cast<Instruction>(Val: U)) ||
5199 isInstructionTriviallyDead(I: cast<Instruction>(Val: U))));
5200 });
5201 };
5202 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
5203 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
5204 return false;
5205
5206 // Collect all the uses that are shuffles that we can transform together. We
5207 // may not have a single shuffle, but a group that can all be transformed
5208 // together profitably.
5209 SmallVector<ShuffleVectorInst *> Shuffles;
5210 auto collectShuffles = [&](Instruction *I) {
5211 for (auto *U : I->users()) {
5212 auto *SV = dyn_cast<ShuffleVectorInst>(Val: U);
5213 if (!SV || SV->getType() != VT)
5214 return false;
5215 if ((SV->getOperand(i_nocapture: 0) != Op0 && SV->getOperand(i_nocapture: 0) != Op1) ||
5216 (SV->getOperand(i_nocapture: 1) != Op0 && SV->getOperand(i_nocapture: 1) != Op1))
5217 return false;
5218 if (!llvm::is_contained(Range&: Shuffles, Element: SV))
5219 Shuffles.push_back(Elt: SV);
5220 }
5221 return true;
5222 };
5223 if (!collectShuffles(Op0) || !collectShuffles(Op1))
5224 return false;
5225 // From a reduction, we need to be processing a single shuffle, otherwise the
5226 // other uses will not be lane-invariant.
5227 if (FromReduction && Shuffles.size() > 1)
5228 return false;
5229
5230 // Add any shuffle uses for the shuffles we have found, to include them in our
5231 // cost calculations.
5232 if (!FromReduction) {
5233 for (size_t Idx = 0, E = Shuffles.size(); Idx != E; ++Idx) {
5234 for (auto *U : Shuffles[Idx]->users()) {
5235 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(Val: U);
5236 if (SSV && isa<UndefValue>(Val: SSV->getOperand(i_nocapture: 1)) && SSV->getType() == VT)
5237 Shuffles.push_back(Elt: SSV);
5238 }
5239 }
5240 }
5241
5242 // For each of the output shuffles, we try to sort all the first vector
5243 // elements to the beginning, followed by the second array elements at the
5244 // end. If the binops are legalized to smaller vectors, this may reduce total
5245 // number of binops. We compute the ReconstructMask mask needed to convert
5246 // back to the original lane order.
5247 SmallVector<std::pair<int, int>> V1, V2;
5248 SmallVector<SmallVector<int>> OrigReconstructMasks;
5249 int MaxV1Elt = 0, MaxV2Elt = 0;
5250 unsigned NumElts = VT->getNumElements();
5251 for (ShuffleVectorInst *SVN : Shuffles) {
5252 SmallVector<int> Mask;
5253 SVN->getShuffleMask(Result&: Mask);
5254
5255 // Check the operands are the same as the original, or reversed (in which
5256 // case we need to commute the mask).
5257 Value *SVOp0 = SVN->getOperand(i_nocapture: 0);
5258 Value *SVOp1 = SVN->getOperand(i_nocapture: 1);
5259 if (isa<UndefValue>(Val: SVOp1)) {
5260 auto *SSV = cast<ShuffleVectorInst>(Val: SVOp0);
5261 SVOp0 = SSV->getOperand(i_nocapture: 0);
5262 SVOp1 = SSV->getOperand(i_nocapture: 1);
5263 for (int &Elem : Mask) {
5264 if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
5265 return false;
5266 Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elt: Elem);
5267 }
5268 }
5269 if (SVOp0 == Op1 && SVOp1 == Op0) {
5270 std::swap(a&: SVOp0, b&: SVOp1);
5271 ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumElts);
5272 }
5273 if (SVOp0 != Op0 || SVOp1 != Op1)
5274 return false;
5275
5276 // Calculate the reconstruction mask for this shuffle, as the mask needed to
5277 // take the packed values from Op0/Op1 and reconstructing to the original
5278 // order.
5279 SmallVector<int> ReconstructMask;
5280 for (unsigned I = 0; I < Mask.size(); I++) {
5281 if (Mask[I] < 0) {
5282 ReconstructMask.push_back(Elt: -1);
5283 } else if (Mask[I] < static_cast<int>(NumElts)) {
5284 MaxV1Elt = std::max(a: MaxV1Elt, b: Mask[I]);
5285 auto It = find_if(Range&: V1, P: [&](const std::pair<int, int> &A) {
5286 return Mask[I] == A.first;
5287 });
5288 if (It != V1.end())
5289 ReconstructMask.push_back(Elt: It - V1.begin());
5290 else {
5291 ReconstructMask.push_back(Elt: V1.size());
5292 V1.emplace_back(Args&: Mask[I], Args: V1.size());
5293 }
5294 } else {
5295 MaxV2Elt = std::max<int>(a: MaxV2Elt, b: Mask[I] - NumElts);
5296 auto It = find_if(Range&: V2, P: [&](const std::pair<int, int> &A) {
5297 return Mask[I] - static_cast<int>(NumElts) == A.first;
5298 });
5299 if (It != V2.end())
5300 ReconstructMask.push_back(Elt: NumElts + It - V2.begin());
5301 else {
5302 ReconstructMask.push_back(Elt: NumElts + V2.size());
5303 V2.emplace_back(Args: Mask[I] - NumElts, Args: NumElts + V2.size());
5304 }
5305 }
5306 }
5307
5308 // For reductions, we know that the lane ordering out doesn't alter the
5309 // result. In-order can help simplify the shuffle away.
5310 if (FromReduction)
5311 sort(C&: ReconstructMask);
5312 OrigReconstructMasks.push_back(Elt: std::move(ReconstructMask));
5313 }
5314
5315 // If the Maximum element used from V1 and V2 are not larger than the new
5316 // vectors, the vectors are already packes and performing the optimization
5317 // again will likely not help any further. This also prevents us from getting
5318 // stuck in a cycle in case the costs do not also rule it out.
5319 if (V1.empty() || V2.empty() ||
5320 (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
5321 MaxV2Elt == static_cast<int>(V2.size()) - 1))
5322 return false;
5323
5324 // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
5325 // shuffle of another shuffle, or not a shuffle (that is treated like a
5326 // identity shuffle).
5327 auto GetBaseMaskValue = [&](Instruction *I, int M) {
5328 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5329 if (!SV)
5330 return M;
5331 if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
5332 if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
5333 if (InputShuffles.contains(Ptr: SSV))
5334 return SSV->getMaskValue(Elt: SV->getMaskValue(Elt: M));
5335 return SV->getMaskValue(Elt: M);
5336 };
5337
5338 // Attempt to sort the inputs my ascending mask values to make simpler input
5339 // shuffles and push complex shuffles down to the uses. We sort on the first
5340 // of the two input shuffle orders, to try and get at least one input into a
5341 // nice order.
5342 auto SortBase = [&](Instruction *A, std::pair<int, int> X,
5343 std::pair<int, int> Y) {
5344 int MXA = GetBaseMaskValue(A, X.first);
5345 int MYA = GetBaseMaskValue(A, Y.first);
5346 return MXA < MYA;
5347 };
5348 stable_sort(Range&: V1, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
5349 return SortBase(SVI0A, A, B);
5350 });
5351 stable_sort(Range&: V2, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
5352 return SortBase(SVI1A, A, B);
5353 });
5354 // Calculate our ReconstructMasks from the OrigReconstructMasks and the
5355 // modified order of the input shuffles.
5356 SmallVector<SmallVector<int>> ReconstructMasks;
5357 for (const auto &Mask : OrigReconstructMasks) {
5358 SmallVector<int> ReconstructMask;
5359 for (int M : Mask) {
5360 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
5361 auto It = find_if(Range: V, P: [M](auto A) { return A.second == M; });
5362 assert(It != V.end() && "Expected all entries in Mask");
5363 return std::distance(first: V.begin(), last: It);
5364 };
5365 if (M < 0)
5366 ReconstructMask.push_back(Elt: -1);
5367 else if (M < static_cast<int>(NumElts)) {
5368 ReconstructMask.push_back(Elt: FindIndex(V1, M));
5369 } else {
5370 ReconstructMask.push_back(Elt: NumElts + FindIndex(V2, M));
5371 }
5372 }
5373 ReconstructMasks.push_back(Elt: std::move(ReconstructMask));
5374 }
5375
5376 // Calculate the masks needed for the new input shuffles, which get padded
5377 // with undef
5378 SmallVector<int> V1A, V1B, V2A, V2B;
5379 for (unsigned I = 0; I < V1.size(); I++) {
5380 V1A.push_back(Elt: GetBaseMaskValue(SVI0A, V1[I].first));
5381 V1B.push_back(Elt: GetBaseMaskValue(SVI0B, V1[I].first));
5382 }
5383 for (unsigned I = 0; I < V2.size(); I++) {
5384 V2A.push_back(Elt: GetBaseMaskValue(SVI1A, V2[I].first));
5385 V2B.push_back(Elt: GetBaseMaskValue(SVI1B, V2[I].first));
5386 }
5387 while (V1A.size() < NumElts) {
5388 V1A.push_back(Elt: PoisonMaskElem);
5389 V1B.push_back(Elt: PoisonMaskElem);
5390 }
5391 while (V2A.size() < NumElts) {
5392 V2A.push_back(Elt: PoisonMaskElem);
5393 V2B.push_back(Elt: PoisonMaskElem);
5394 }
5395
5396 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
5397 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5398 if (!SV)
5399 return C;
5400 return C + TTI.getShuffleCost(Kind: isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1))
5401 ? TTI::SK_PermuteSingleSrc
5402 : TTI::SK_PermuteTwoSrc,
5403 DstTy: VT, SrcTy: VT, Mask: SV->getShuffleMask(), CostKind);
5404 };
5405 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5406 return C +
5407 TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
5408 };
5409
5410 unsigned ElementSize = VT->getElementType()->getPrimitiveSizeInBits();
5411 unsigned MaxVectorSize =
5412 TTI.getRegisterBitWidth(K: TargetTransformInfo::RGK_FixedWidthVector);
5413 unsigned MaxElementsInVector = MaxVectorSize / ElementSize;
5414 if (MaxElementsInVector == 0)
5415 return false;
5416 // When there are multiple shufflevector operations on the same input,
5417 // especially when the vector length is larger than the register size,
5418 // identical shuffle patterns may occur across different groups of elements.
5419 // To avoid overestimating the cost by counting these repeated shuffles more
5420 // than once, we only account for unique shuffle patterns. This adjustment
5421 // prevents inflated costs in the cost model for wide vectors split into
5422 // several register-sized groups.
5423 std::set<SmallVector<int, 4>> UniqueShuffles;
5424 auto AddShuffleMaskAdjustedCost = [&](InstructionCost C, ArrayRef<int> Mask) {
5425 // Compute the cost for performing the shuffle over the full vector.
5426 auto ShuffleCost =
5427 TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
5428 unsigned NumFullVectors = Mask.size() / MaxElementsInVector;
5429 if (NumFullVectors < 2)
5430 return C + ShuffleCost;
5431 SmallVector<int, 4> SubShuffle(MaxElementsInVector);
5432 unsigned NumUniqueGroups = 0;
5433 unsigned NumGroups = Mask.size() / MaxElementsInVector;
5434 // For each group of MaxElementsInVector contiguous elements,
5435 // collect their shuffle pattern and insert into the set of unique patterns.
5436 for (unsigned I = 0; I < NumFullVectors; ++I) {
5437 for (unsigned J = 0; J < MaxElementsInVector; ++J)
5438 SubShuffle[J] = Mask[MaxElementsInVector * I + J];
5439 if (UniqueShuffles.insert(x: SubShuffle).second)
5440 NumUniqueGroups += 1;
5441 }
5442 return C + ShuffleCost * NumUniqueGroups / NumGroups;
5443 };
5444 auto AddShuffleAdjustedCost = [&](InstructionCost C, Instruction *I) {
5445 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5446 if (!SV)
5447 return C;
5448 SmallVector<int, 16> Mask;
5449 SV->getShuffleMask(Result&: Mask);
5450 return AddShuffleMaskAdjustedCost(C, Mask);
5451 };
5452 // Check that input consists of ShuffleVectors applied to the same input
5453 auto AllShufflesHaveSameOperands =
5454 [](SmallPtrSetImpl<Instruction *> &InputShuffles) {
5455 if (InputShuffles.size() < 2)
5456 return false;
5457 ShuffleVectorInst *FirstSV =
5458 dyn_cast<ShuffleVectorInst>(Val: *InputShuffles.begin());
5459 if (!FirstSV)
5460 return false;
5461
5462 Value *In0 = FirstSV->getOperand(i_nocapture: 0), *In1 = FirstSV->getOperand(i_nocapture: 1);
5463 return std::all_of(
5464 first: std::next(x: InputShuffles.begin()), last: InputShuffles.end(),
5465 pred: [&](Instruction *I) {
5466 ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5467 return SV && SV->getOperand(i_nocapture: 0) == In0 && SV->getOperand(i_nocapture: 1) == In1;
5468 });
5469 };
5470
5471 // Get the costs of the shuffles + binops before and after with the new
5472 // shuffle masks.
5473 InstructionCost CostBefore =
5474 TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: VT, CostKind) +
5475 TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: VT, CostKind);
5476 CostBefore += std::accumulate(first: Shuffles.begin(), last: Shuffles.end(),
5477 init: InstructionCost(0), binary_op: AddShuffleCost);
5478 if (AllShufflesHaveSameOperands(InputShuffles)) {
5479 UniqueShuffles.clear();
5480 CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
5481 init: InstructionCost(0), binary_op: AddShuffleAdjustedCost);
5482 } else {
5483 CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
5484 init: InstructionCost(0), binary_op: AddShuffleCost);
5485 }
5486
5487 // The new binops will be unused for lanes past the used shuffle lengths.
5488 // These types attempt to get the correct cost for that from the target.
5489 FixedVectorType *Op0SmallVT =
5490 FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V1.size());
5491 FixedVectorType *Op1SmallVT =
5492 FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V2.size());
5493 InstructionCost CostAfter =
5494 TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: Op0SmallVT, CostKind) +
5495 TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: Op1SmallVT, CostKind);
5496 UniqueShuffles.clear();
5497 CostAfter += std::accumulate(first: ReconstructMasks.begin(), last: ReconstructMasks.end(),
5498 init: InstructionCost(0), binary_op: AddShuffleMaskAdjustedCost);
5499 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
5500 CostAfter +=
5501 std::accumulate(first: OutputShuffleMasks.begin(), last: OutputShuffleMasks.end(),
5502 init: InstructionCost(0), binary_op: AddShuffleMaskCost);
5503
5504 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
5505 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
5506 << " vs CostAfter: " << CostAfter << "\n");
5507 if (CostBefore < CostAfter ||
5508 (CostBefore == CostAfter && !feedsIntoVectorReduction(SVI)))
5509 return false;
5510
5511 // The cost model has passed, create the new instructions.
5512 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
5513 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
5514 if (!SV)
5515 return I;
5516 if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
5517 if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
5518 if (InputShuffles.contains(Ptr: SSV))
5519 return SSV->getOperand(i_nocapture: Op);
5520 return SV->getOperand(i_nocapture: Op);
5521 };
5522 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
5523 Value *NSV0A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0A, 0),
5524 V2: GetShuffleOperand(SVI0A, 1), Mask: V1A);
5525 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
5526 Value *NSV0B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0B, 0),
5527 V2: GetShuffleOperand(SVI0B, 1), Mask: V1B);
5528 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
5529 Value *NSV1A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1A, 0),
5530 V2: GetShuffleOperand(SVI1A, 1), Mask: V2A);
5531 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
5532 Value *NSV1B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1B, 0),
5533 V2: GetShuffleOperand(SVI1B, 1), Mask: V2B);
5534 Builder.SetInsertPoint(Op0);
5535 Value *NOp0 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op0->getOpcode(),
5536 LHS: NSV0A, RHS: NSV0B);
5537 if (auto *I = dyn_cast<Instruction>(Val: NOp0))
5538 I->copyIRFlags(V: Op0, IncludeWrapFlags: true);
5539 Builder.SetInsertPoint(Op1);
5540 Value *NOp1 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op1->getOpcode(),
5541 LHS: NSV1A, RHS: NSV1B);
5542 if (auto *I = dyn_cast<Instruction>(Val: NOp1))
5543 I->copyIRFlags(V: Op1, IncludeWrapFlags: true);
5544
5545 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
5546 Builder.SetInsertPoint(Shuffles[S]);
5547 Value *NSV = Builder.CreateShuffleVector(V1: NOp0, V2: NOp1, Mask: ReconstructMasks[S]);
5548 replaceValue(Old&: *Shuffles[S], New&: *NSV, Erase: false);
5549 }
5550
5551 Worklist.pushValue(V: NSV0A);
5552 Worklist.pushValue(V: NSV0B);
5553 Worklist.pushValue(V: NSV1A);
5554 Worklist.pushValue(V: NSV1B);
5555 return true;
5556}
5557
5558/// Check if instruction depends on ZExt and this ZExt can be moved after the
5559/// instruction. Move ZExt if it is profitable. For example:
5560/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
5561/// lshr((zext(x),y) -> zext(lshr(x,trunc(y)))
5562/// Cost model calculations takes into account if zext(x) has other users and
5563/// whether it can be propagated through them too.
5564bool VectorCombine::shrinkType(Instruction &I) {
5565 Value *ZExted, *OtherOperand;
5566 if (!match(V: &I, P: m_c_BitwiseLogic(L: m_ZExt(Op: m_Value(V&: ZExted)),
5567 R: m_Value(V&: OtherOperand))) &&
5568 !match(V: &I, P: m_LShr(L: m_ZExt(Op: m_Value(V&: ZExted)), R: m_Value(V&: OtherOperand))))
5569 return false;
5570
5571 Value *ZExtOperand = I.getOperand(i: I.getOperand(i: 0) == OtherOperand ? 1 : 0);
5572
5573 auto *BigTy = cast<FixedVectorType>(Val: I.getType());
5574 auto *SmallTy = cast<FixedVectorType>(Val: ZExted->getType());
5575 unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
5576
5577 if (I.getOpcode() == Instruction::LShr) {
5578 // Check that the shift amount is less than the number of bits in the
5579 // smaller type. Otherwise, the smaller lshr will return a poison value.
5580 KnownBits ShAmtKB = computeKnownBits(V: I.getOperand(i: 1), DL: *DL);
5581 if (ShAmtKB.getMaxValue().uge(RHS: BW))
5582 return false;
5583 } else {
5584 // Check that the expression overall uses at most the same number of bits as
5585 // ZExted
5586 KnownBits KB = computeKnownBits(V: &I, DL: *DL);
5587 if (KB.countMaxActiveBits() > BW)
5588 return false;
5589 }
5590
5591 // Calculate costs of leaving current IR as it is and moving ZExt operation
5592 // later, along with adding truncates if needed
5593 InstructionCost ZExtCost = TTI.getCastInstrCost(
5594 Opcode: Instruction::ZExt, Dst: BigTy, Src: SmallTy,
5595 CCH: TargetTransformInfo::CastContextHint::None, CostKind);
5596 InstructionCost CurrentCost = ZExtCost;
5597 InstructionCost ShrinkCost = 0;
5598
5599 // Calculate total cost and check that we can propagate through all ZExt users
5600 for (User *U : ZExtOperand->users()) {
5601 auto *UI = cast<Instruction>(Val: U);
5602 if (UI == &I) {
5603 CurrentCost +=
5604 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
5605 ShrinkCost +=
5606 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
5607 ShrinkCost += ZExtCost;
5608 continue;
5609 }
5610
5611 if (!Instruction::isBinaryOp(Opcode: UI->getOpcode()))
5612 return false;
5613
5614 // Check if we can propagate ZExt through its other users
5615 KnownBits KB = computeKnownBits(V: UI, DL: *DL);
5616 if (KB.countMaxActiveBits() > BW)
5617 return false;
5618
5619 CurrentCost += TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
5620 ShrinkCost +=
5621 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
5622 ShrinkCost += ZExtCost;
5623 }
5624
5625 // If the other instruction operand is not a constant, we'll need to
5626 // generate a truncate instruction. So we have to adjust cost
5627 if (!isa<Constant>(Val: OtherOperand))
5628 ShrinkCost += TTI.getCastInstrCost(
5629 Opcode: Instruction::Trunc, Dst: SmallTy, Src: BigTy,
5630 CCH: TargetTransformInfo::CastContextHint::None, CostKind);
5631
5632 // If the cost of shrinking types and leaving the IR is the same, we'll lean
5633 // towards modifying the IR because shrinking opens opportunities for other
5634 // shrinking optimisations.
5635 if (ShrinkCost > CurrentCost)
5636 return false;
5637
5638 Builder.SetInsertPoint(&I);
5639 Value *Op0 = ZExted;
5640 Value *Op1 = Builder.CreateTrunc(V: OtherOperand, DestTy: SmallTy);
5641 // Keep the order of operands the same
5642 if (I.getOperand(i: 0) == OtherOperand)
5643 std::swap(a&: Op0, b&: Op1);
5644 Value *NewBinOp =
5645 Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(), LHS: Op0, RHS: Op1);
5646 cast<Instruction>(Val: NewBinOp)->copyIRFlags(V: &I);
5647 cast<Instruction>(Val: NewBinOp)->copyMetadata(SrcInst: I);
5648 Value *NewZExtr = Builder.CreateZExt(V: NewBinOp, DestTy: BigTy);
5649 replaceValue(Old&: I, New&: *NewZExtr);
5650 return true;
5651}
5652
5653/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
5654/// shuffle (DstVec, SrcVec, Mask)
5655bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
5656 Value *DstVec, *SrcVec;
5657 uint64_t ExtIdx, InsIdx;
5658 if (!match(V: &I,
5659 P: m_InsertElt(Val: m_Value(V&: DstVec),
5660 Elt: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx)),
5661 Idx: m_ConstantInt(V&: InsIdx))))
5662 return false;
5663
5664 auto *DstVecTy = dyn_cast<FixedVectorType>(Val: I.getType());
5665 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
5666 // We can try combining vectors with different element sizes.
5667 if (!DstVecTy || !SrcVecTy ||
5668 SrcVecTy->getElementType() != DstVecTy->getElementType())
5669 return false;
5670
5671 unsigned NumDstElts = DstVecTy->getNumElements();
5672 unsigned NumSrcElts = SrcVecTy->getNumElements();
5673 if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
5674 return false;
5675
5676 // Insertion into poison is a cheaper single operand shuffle.
5677 TargetTransformInfo::ShuffleKind SK;
5678 SmallVector<int> Mask(NumDstElts, PoisonMaskElem);
5679
5680 bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
5681 bool NeedDstSrcSwap = isa<PoisonValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec);
5682 if (NeedDstSrcSwap) {
5683 SK = TargetTransformInfo::SK_PermuteSingleSrc;
5684 Mask[InsIdx] = ExtIdx % NumDstElts;
5685 std::swap(a&: DstVec, b&: SrcVec);
5686 } else {
5687 SK = TargetTransformInfo::SK_PermuteTwoSrc;
5688 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
5689 Mask[InsIdx] = (ExtIdx % NumDstElts) + NumDstElts;
5690 }
5691
5692 // Cost
5693 auto *Ins = cast<InsertElementInst>(Val: &I);
5694 auto *Ext = cast<ExtractElementInst>(Val: I.getOperand(i: 1));
5695 InstructionCost InsCost =
5696 TTI.getVectorInstrCost(I: *Ins, Val: DstVecTy, CostKind, Index: InsIdx);
5697 InstructionCost ExtCost =
5698 TTI.getVectorInstrCost(I: *Ext, Val: DstVecTy, CostKind, Index: ExtIdx);
5699 InstructionCost OldCost = ExtCost + InsCost;
5700
5701 InstructionCost NewCost = 0;
5702 SmallVector<int> ExtToVecMask;
5703 if (!NeedExpOrNarrow) {
5704 // Ignore 'free' identity insertion shuffle.
5705 // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
5706 if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
5707 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind, Index: 0,
5708 SubTp: nullptr, Args: {DstVec, SrcVec});
5709 } else {
5710 // When creating a length-changing-vector, always try to keep the relevant
5711 // element in an equivalent position, so that bulk shuffles are more likely
5712 // to be useful.
5713 ExtToVecMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
5714 ExtToVecMask[ExtIdx % NumDstElts] = ExtIdx;
5715 // Add cost for expanding or narrowing
5716 NewCost = TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
5717 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: ExtToVecMask, CostKind);
5718 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind);
5719 }
5720
5721 if (!Ext->hasOneUse())
5722 NewCost += ExtCost;
5723
5724 LLVM_DEBUG(dbgs() << "Found a insert/extract shuffle-like pair: " << I
5725 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
5726 << "\n");
5727
5728 if (OldCost < NewCost)
5729 return false;
5730
5731 if (NeedExpOrNarrow) {
5732 if (!NeedDstSrcSwap)
5733 SrcVec = Builder.CreateShuffleVector(V: SrcVec, Mask: ExtToVecMask);
5734 else
5735 DstVec = Builder.CreateShuffleVector(V: DstVec, Mask: ExtToVecMask);
5736 }
5737
5738 // Canonicalize undef param to RHS to help further folds.
5739 if (isa<UndefValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec)) {
5740 ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumDstElts);
5741 std::swap(a&: DstVec, b&: SrcVec);
5742 }
5743
5744 Value *Shuf = Builder.CreateShuffleVector(V1: DstVec, V2: SrcVec, Mask);
5745 replaceValue(Old&: I, New&: *Shuf);
5746
5747 return true;
5748}
5749
5750/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
5751/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
5752/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
5753/// before casting it back into `<vscale x 16 x i32>`.
5754bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
5755 const APInt *SplatVal0, *SplatVal1;
5756 if (!match(V: &I, P: m_Intrinsic<Intrinsic::vector_interleave2>(
5757 Op0: m_APInt(Res&: SplatVal0), Op1: m_APInt(Res&: SplatVal1))))
5758 return false;
5759
5760 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
5761 << "\n");
5762
5763 auto *VTy =
5764 cast<VectorType>(Val: cast<IntrinsicInst>(Val&: I).getArgOperand(i: 0)->getType());
5765 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
5766 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
5767
5768 // Just in case the cost of interleave2 intrinsic and bitcast are both
5769 // invalid, in which case we want to bail out, we use <= rather
5770 // than < here. Even they both have valid and equal costs, it's probably
5771 // not a good idea to emit a high-cost constant splat.
5772 if (TTI.getInstructionCost(U: &I, CostKind) <=
5773 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: I.getType(), Src: ExtVTy,
5774 CCH: TTI::CastContextHint::None, CostKind)) {
5775 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
5776 << *I.getType() << " is too high.\n");
5777 return false;
5778 }
5779
5780 APInt NewSplatVal = SplatVal1->zext(width: Width * 2);
5781 NewSplatVal <<= Width;
5782 NewSplatVal |= SplatVal0->zext(width: Width * 2);
5783 auto *NewSplat = ConstantVector::getSplat(
5784 EC: ExtVTy->getElementCount(), Elt: ConstantInt::get(Context&: F.getContext(), V: NewSplatVal));
5785
5786 IRBuilder<> Builder(&I);
5787 replaceValue(Old&: I, New&: *Builder.CreateBitCast(V: NewSplat, DestTy: I.getType()));
5788 return true;
5789}
5790
5791/// Given this sequence:
5792/// ```
5793/// %d = llvm.vector.deinterleave2 <vscale x 16 x i32> %v
5794/// %f0 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 0
5795/// %f1 = extractvalue { <vscale x 8 x i32>, <vscale x 8 x i32> } %d, 1
5796///
5797/// %low0 = and <vscale x 8 x i32> %f0, splat (i32 65535)
5798/// %low1 = shl <vscale x 8 x i32> %f1, splat (i32 16)
5799/// %merge0 = or disjoint <vscale x 8 x i32> %low0, %low1
5800///
5801/// %high0 = and <vscale x 8 x i32> %f1, splat (i32 -65536)
5802/// %high1 = lshr <vscale x 8 x i32> %f0, splat (i32 16)
5803/// %merge1 = or disjoint <vscale x 8 x i32> %high0, %high1
5804/// ```
5805/// It is actually just de-interleaving a 16-bit vector with double the
5806/// vector length. More generally speaking, it's de-interleaving on a vector
5807/// with half the element width as the original vector.
5808///
5809/// Therefore, we can turn it into:
5810/// ```
5811/// %narrow.v = bitcast <vscale x 16 x i32> %v to <vscale x 32 x i16>
5812/// %d = llvm.vector.deinterleave2 <vscale x 32 x i16> %narrow.v
5813/// %f0 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 0
5814/// %f1 = extractvalue { <vscale x 16 x i16>, <vscale x 16 x i16> } %d, 1
5815///
5816/// %merge0 = bitcast <vscale x 16 x i16> %f0 to <vscale x 8 x i32>
5817/// %merge1 = bitcast <vscale x 16 x i16> %f1 to <vscale x 8 x i32>
5818/// ```
5819bool VectorCombine::foldDeinterleaveIntrinsics(Instruction &I) {
5820 // This pattern involves bitcast that is not compatible with big endian.
5821 if (DL->isBigEndian())
5822 return false;
5823
5824 using namespace PatternMatch;
5825 Value *DeinterleavedVal;
5826 if (!match(V: &I, P: m_Deinterleave2(Op: m_Value(V&: DeinterleavedVal))))
5827 return false;
5828
5829 VectorType *VecTy = cast<VectorType>(Val: DeinterleavedVal->getType());
5830 IntegerType *ElementTy = dyn_cast<IntegerType>(Val: VecTy->getElementType());
5831 if (!ElementTy)
5832 return false;
5833 unsigned ElementWidth = ElementTy->getBitWidth();
5834 if (ElementWidth < 2 || !isPowerOf2_32(Value: ElementWidth))
5835 return false;
5836 unsigned HalfElementWidth = ElementWidth / 2;
5837
5838 if (!I.hasNUses(N: 2))
5839 return false;
5840 std::array<ExtractValueInst *, 2> OrigFields{};
5841 for (User *Usr : I.users()) {
5842 auto *E = dyn_cast<ExtractValueInst>(Val: Usr);
5843 // The deinterleave result can only be used by extractions.
5844 if (!E || E->getNumIndices() != 1)
5845 return false;
5846 unsigned Idx = *E->idx_begin();
5847 // A single field cannot be extracted more than once.
5848 if (Idx >= 2 || OrigFields[Idx] || !E->hasNUses(N: 2))
5849 return false;
5850 OrigFields[Idx] = E;
5851 }
5852
5853 // Find the merge instruction (i.e. OR) first.
5854 SmallVector<Instruction *, 2> MergeInsts;
5855 for (auto *FieldUsr : OrigFields[0]->users()) {
5856 if (!FieldUsr->hasOneUse() || !isa<Instruction>(Val: FieldUsr->user_back()))
5857 return false;
5858 MergeInsts.push_back(Elt: cast<Instruction>(Val: FieldUsr->user_back()));
5859 }
5860 assert(MergeInsts.size() == 2);
5861
5862 // Pattern match bottom-up from the merge instructions.
5863 auto MatchMerge = [&](void) -> bool {
5864 APInt LoMask = APInt::getLowBitsSet(numBits: ElementWidth, loBitsSet: HalfElementWidth);
5865 APInt HiMask = APInt::getHighBitsSet(numBits: ElementWidth, hiBitsSet: HalfElementWidth);
5866 return match(V: MergeInsts[0],
5867 P: m_c_Or(L: m_And(L: m_Specific(V: OrigFields[0]), R: m_SpecificInt(V: LoMask)),
5868 R: m_Shl(L: m_Specific(V: OrigFields[1]),
5869 R: m_SpecificInt(V: HalfElementWidth)))) &&
5870 match(V: MergeInsts[1],
5871 P: m_c_Or(L: m_And(L: m_Specific(V: OrigFields[1]), R: m_SpecificInt(V: HiMask)),
5872 R: m_LShr(L: m_Specific(V: OrigFields[0]),
5873 R: m_SpecificInt(V: HalfElementWidth))));
5874 };
5875 if (!MatchMerge()) {
5876 std::swap(a&: MergeInsts[0], b&: MergeInsts[1]);
5877 if (!MatchMerge())
5878 return false;
5879 }
5880
5881 // Profitability check.
5882 InstructionCost OldCost =
5883 TTI.getInstructionCost(U: MergeInsts[0], CostKind) +
5884 TTI.getInstructionCost(U: cast<Instruction>(Val: MergeInsts[0]->getOperand(i: 0)),
5885 CostKind) +
5886 TTI.getInstructionCost(U: cast<Instruction>(Val: MergeInsts[0]->getOperand(i: 1)),
5887 CostKind);
5888 // There are two fields (assuming SHL has the same cost as LSHR).
5889 OldCost *= 2;
5890
5891 auto *NewFieldTy = VecTy->getWithNewBitWidth(NewBitWidth: HalfElementWidth);
5892 auto *NewVecTy =
5893 VectorType::getDoubleElementsVectorType(VTy: cast<VectorType>(Val: NewFieldTy));
5894 InstructionCost NewCost =
5895 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: VecTy, Src: NewVecTy,
5896 CCH: TTI::CastContextHint::None, CostKind) +
5897 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: NewFieldTy,
5898 Src: MergeInsts[0]->getType(), CCH: TTI::CastContextHint::None,
5899 CostKind) *
5900 2;
5901 if (OldCost <= NewCost || !NewCost.isValid()) {
5902 LLVM_DEBUG(
5903 dbgs() << "VC: New deinterleave2 sequence cost (" << NewCost << ")"
5904 << " is higher than that of the old one (" << OldCost << ")\n");
5905 return false;
5906 }
5907
5908 // Do the replacement.
5909 IRBuilder<> Builder(&I);
5910 Value *NewVecCast = Builder.CreateBitCast(V: DeinterleavedVal, DestTy: NewVecTy);
5911 Value *NewDeinterleave = Builder.CreateIntrinsic(
5912 ID: Intrinsic::vector_deinterleave2, OverloadTypes: {NewVecTy}, Args: {NewVecCast});
5913 for (auto [Idx, MergeInst] : enumerate(First&: MergeInsts)) {
5914 Value *NewField = Builder.CreateExtractValue(Agg: NewDeinterleave, Idxs: Idx);
5915 NewField = Builder.CreateBitCast(V: NewField, DestTy: MergeInst->getType());
5916 replaceValue(Old&: *MergeInst, New&: *NewField);
5917 }
5918
5919 return true;
5920}
5921
5922bool VectorCombine::foldBitcastOfVPLoad(Instruction &I) {
5923 const DataLayout &DL = I.getDataLayout();
5924 auto *Cast = dyn_cast<CastInst>(Val: &I);
5925 if (!Cast || !Cast->isNoopCast(DL) || !isa<VectorType>(Val: Cast->getDestTy()))
5926 return false;
5927
5928 // Fold away bit casts of the loaded value by loading the desired type,
5929 // if the mask is all-ones.
5930 Value *EVL;
5931 auto *II = dyn_cast<VPIntrinsic>(Val: I.getOperand(i: 0));
5932 if (!II || !match(V: II, P: m_OneUse(SubPattern: m_Intrinsic<Intrinsic::vp_load>(
5933 Op0: m_Value(), Op1: m_AllOnes(), Op2: m_Value(V&: EVL)))))
5934 return false;
5935
5936 VectorType *OrigVecTy = cast<VectorType>(Val: II->getType());
5937 Align OrigAlign =
5938 DL.getValueOrABITypeAlignment(Alignment: II->getPointerAlignment(), Ty: OrigVecTy);
5939 ElementCount OrigVecCnt = OrigVecTy->getElementCount();
5940 VectorType *NewVecTy = cast<VectorType>(Val: Cast->getDestTy());
5941 ElementCount NewVecCnt = NewVecTy->getElementCount();
5942
5943 // Right now we only support cases where the NewVec is longer, because for
5944 // cases where it's shorter, we have to be sure that EVL can be exactly
5945 // divided, otherwise it might yield incorrect results or even page faults
5946 // (if we round-up during the division).
5947 if (!(OrigVecCnt.isScalable() == NewVecCnt.isScalable() &&
5948 NewVecCnt.hasKnownScalarFactor(RHS: OrigVecCnt)))
5949 return false;
5950
5951 InstructionCost OldCost =
5952 TTI.getMemIntrinsicInstrCost(MICA: {Intrinsic::vp_load, OrigVecTy,
5953 II->getMemoryPointerParam(), false,
5954 OrigAlign},
5955 CostKind) +
5956 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: Cast->getType(), Src: OrigVecTy,
5957 CCH: TTI::CastContextHint::None, CostKind);
5958 InstructionCost NewCost = TTI.getMemIntrinsicInstrCost(
5959 MICA: {Intrinsic::vp_load, NewVecTy, II->getMemoryPointerParam(), false,
5960 OrigAlign},
5961 CostKind);
5962 LLVM_DEBUG(dbgs() << "foldBitcastOfVPLoad: OldCost=" << OldCost
5963 << " NewCost=" << NewCost << "\n");
5964 if (NewCost > OldCost || !NewCost.isValid())
5965 return false;
5966
5967 unsigned Factor = NewVecCnt.getKnownScalarFactor(RHS: OrigVecCnt);
5968 Value *NewEVL = Builder.CreateNUWMul(LHS: EVL, RHS: Builder.getInt32(C: Factor));
5969 Value *NewMask = Builder.CreateVectorSplat(EC: NewVecCnt, V: Builder.getTrue());
5970 CallInst *NewVP = Builder.CreateIntrinsicWithoutFolding(
5971 RetTy: NewVecTy, ID: Intrinsic::vp_load,
5972 Args: {II->getMemoryPointerParam(), NewMask, NewEVL});
5973 // Preserve the original alignment.
5974 NewVP->addParamAttrs(
5975 ArgNo: 0, B: AttrBuilder(II->getContext()).addAlignmentAttr(Align: OrigAlign));
5976 replaceValue(Old&: *Cast, New&: *NewVP);
5977 return true;
5978}
5979
5980/// Fold the following cases into a single byte-level bit-reverse operation
5981/// and accepts bswap and bitreverse intrinsics:
5982/// bswap(bitreverse(x)) --> bitcast(bitreverse(bitcast(x)))
5983/// bitreverse(bswap(x)) --> bitcast(bitreverse(bitcast(x)))
5984bool VectorCombine::foldBitOrderReverseAndSwap(Instruction &I) {
5985 Value *X;
5986 if (!match(V: &I, P: m_BitReverse(Op0: m_BSwap(Op0: m_Value(V&: X)))) &&
5987 !match(V: &I, P: m_BSwap(Op0: m_BitReverse(Op0: m_Value(V&: X)))))
5988 return false;
5989
5990 Type *Ty = I.getType();
5991 Type *I8Ty = Builder.getInt8Ty();
5992 TypeSize ElementSize = DL->getTypeStoreSize(Ty);
5993 ElementCount NewVecCnt = ElementCount::get(MinVal: ElementSize.getKnownMinValue(),
5994 Scalable: ElementSize.isScalable());
5995 Type *NewVecTy = VectorType::get(ElementType: I8Ty, EC: NewVecCnt);
5996
5997 auto *II = cast<IntrinsicInst>(Val: &I);
5998 auto *InnerII = cast<IntrinsicInst>(Val: II->getArgOperand(i: 0));
5999 // OldCost = cost of bitreverse/bswap + cost of bswap/bitreverse
6000 InstructionCost OldCost = TTI.getInstructionCost(U: II, CostKind) +
6001 TTI.getInstructionCost(U: InnerII, CostKind);
6002
6003 // NewCost = cost of bitcast to byte vector +
6004 // cost of bitreverse/bswap on byte vector +
6005 // cost of bitcast back to original type
6006 InstructionCost CastToVecCost = TTI.getCastInstrCost(
6007 Opcode: Instruction::BitCast, Dst: NewVecTy, Src: Ty, CCH: TTI::CastContextHint::None, CostKind);
6008 InstructionCost CastToOrigCost = TTI.getCastInstrCost(
6009 Opcode: Instruction::BitCast, Dst: Ty, Src: NewVecTy, CCH: TTI::CastContextHint::None, CostKind);
6010
6011 IntrinsicCostAttributes ICANew(Intrinsic::bitreverse, NewVecTy, {NewVecTy});
6012 InstructionCost NewIntrinsicCost =
6013 TTI.getIntrinsicInstrCost(ICA: ICANew, CostKind);
6014 InstructionCost NewCost = CastToVecCost + NewIntrinsicCost + CastToOrigCost;
6015
6016 if (!InnerII->hasOneUse())
6017 NewCost += TTI.getInstructionCost(U: InnerII, CostKind);
6018
6019 LLVM_DEBUG(dbgs() << "Found bitorder reverse and swap: " << I
6020 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
6021 << "\n");
6022 if (!NewCost.isValid() || NewCost >= OldCost)
6023 return false;
6024
6025 // Perform transform: bitcast(arg, <N x i8>), bitreverse, bitcast back
6026 Builder.SetInsertPoint(II);
6027 Value *CastToVec = Builder.CreateBitCast(V: X, DestTy: NewVecTy);
6028 Value *NewCall =
6029 Builder.CreateUnaryIntrinsic(ID: Intrinsic::bitreverse, Op: CastToVec);
6030 Value *CastToOrig = Builder.CreateBitCast(V: NewCall, DestTy: Ty);
6031 replaceValue(Old&: I, New&: *CastToOrig);
6032 return true;
6033}
6034
6035// Attempt to shrink loads that are only used by shufflevector instructions.
6036bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
6037 auto *OldLoad = dyn_cast<LoadInst>(Val: &I);
6038 if (!OldLoad || !OldLoad->isSimple())
6039 return false;
6040
6041 auto *OldLoadTy = dyn_cast<FixedVectorType>(Val: OldLoad->getType());
6042 if (!OldLoadTy)
6043 return false;
6044
6045 unsigned const OldNumElements = OldLoadTy->getNumElements();
6046
6047 // Search all uses of load. If all uses are shufflevector instructions, and
6048 // the second operands are all poison values, find the minimum and maximum
6049 // indices of the vector elements referenced by all shuffle masks.
6050 // Otherwise return `std::nullopt`.
6051 using IndexRange = std::pair<int, int>;
6052 auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
6053 IndexRange OutputRange = IndexRange(OldNumElements, -1);
6054 for (llvm::Use &Use : I.uses()) {
6055 // Ensure all uses match the required pattern.
6056 User *Shuffle = Use.getUser();
6057 ArrayRef<int> Mask;
6058
6059 if (!match(V: Shuffle,
6060 P: m_Shuffle(v1: m_Specific(V: OldLoad), v2: m_Undef(), mask: m_Mask(Mask))))
6061 return std::nullopt;
6062
6063 // Ignore shufflevector instructions that have no uses.
6064 if (Shuffle->use_empty())
6065 continue;
6066
6067 // Find the min and max indices used by the shufflevector instruction.
6068 for (int Index : Mask) {
6069 if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
6070 OutputRange.first = std::min(a: Index, b: OutputRange.first);
6071 OutputRange.second = std::max(a: Index, b: OutputRange.second);
6072 }
6073 }
6074 }
6075
6076 if (OutputRange.second < OutputRange.first)
6077 return std::nullopt;
6078
6079 return OutputRange;
6080 };
6081
6082 // Get the range of vector elements used by shufflevector instructions.
6083 if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
6084 unsigned const NewNumElements = Indices->second + 1u;
6085
6086 // If the range of vector elements is smaller than the full load, attempt
6087 // to create a smaller load.
6088 if (NewNumElements < OldNumElements) {
6089 IRBuilder Builder(&I);
6090 Builder.SetCurrentDebugLocation(I.getDebugLoc());
6091
6092 // Calculate costs of old and new ops.
6093 Type *ElemTy = OldLoadTy->getElementType();
6094 FixedVectorType *NewLoadTy = FixedVectorType::get(ElementType: ElemTy, NumElts: NewNumElements);
6095 Value *PtrOp = OldLoad->getPointerOperand();
6096
6097 InstructionCost OldCost = TTI.getMemoryOpCost(
6098 Opcode: Instruction::Load, Src: OldLoad->getType(), Alignment: OldLoad->getAlign(),
6099 AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);
6100 InstructionCost NewCost =
6101 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: NewLoadTy, Alignment: OldLoad->getAlign(),
6102 AddressSpace: OldLoad->getPointerAddressSpace(), CostKind);
6103
6104 using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
6105 SmallVector<UseEntry, 4u> NewUses;
6106 unsigned const MaxIndex = NewNumElements * 2u;
6107
6108 for (llvm::Use &Use : I.uses()) {
6109 auto *Shuffle = cast<ShuffleVectorInst>(Val: Use.getUser());
6110
6111 // Ignore shufflevector instructions that have no uses.
6112 if (Shuffle->use_empty())
6113 continue;
6114
6115 ArrayRef<int> OldMask = Shuffle->getShuffleMask();
6116
6117 // Create entry for new use.
6118 NewUses.push_back(Elt: {Shuffle, OldMask});
6119
6120 // Validate mask indices.
6121 for (int Index : OldMask) {
6122 if (Index >= static_cast<int>(MaxIndex))
6123 return false;
6124 }
6125
6126 // Update costs.
6127 OldCost +=
6128 TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
6129 SrcTy: OldLoadTy, Mask: OldMask, CostKind);
6130 NewCost +=
6131 TTI.getShuffleCost(Kind: TTI::SK_PermuteSingleSrc, DstTy: Shuffle->getType(),
6132 SrcTy: NewLoadTy, Mask: OldMask, CostKind);
6133 }
6134
6135 LLVM_DEBUG(
6136 dbgs() << "Found a load used only by shufflevector instructions: "
6137 << I << "\n OldCost: " << OldCost
6138 << " vs NewCost: " << NewCost << "\n");
6139
6140 if (OldCost < NewCost || !NewCost.isValid())
6141 return false;
6142
6143 // Create new load of smaller vector.
6144 auto *NewLoad = cast<LoadInst>(
6145 Val: Builder.CreateAlignedLoad(Ty: NewLoadTy, Ptr: PtrOp, Align: OldLoad->getAlign()));
6146 NewLoad->copyMetadata(SrcInst: I);
6147
6148 // Replace all uses.
6149 for (UseEntry &Use : NewUses) {
6150 ShuffleVectorInst *Shuffle = Use.first;
6151 std::vector<int> &NewMask = Use.second;
6152
6153 Builder.SetInsertPoint(Shuffle);
6154 Builder.SetCurrentDebugLocation(Shuffle->getDebugLoc());
6155 Value *NewShuffle = Builder.CreateShuffleVector(
6156 V1: NewLoad, V2: PoisonValue::get(T: NewLoadTy), Mask: NewMask);
6157
6158 replaceValue(Old&: *Shuffle, New&: *NewShuffle, Erase: false);
6159 }
6160
6161 return true;
6162 }
6163 }
6164 return false;
6165}
6166
6167// Attempt to narrow a phi of shufflevector instructions where the two incoming
6168// values have the same operands but different masks. If the two shuffle masks
6169// are offsets of one another we can use one branch to rotate the incoming
6170// vector and perform one larger shuffle after the phi.
6171bool VectorCombine::shrinkPhiOfShuffles(Instruction &I) {
6172 auto *Phi = dyn_cast<PHINode>(Val: &I);
6173 if (!Phi || Phi->getNumIncomingValues() != 2u)
6174 return false;
6175
6176 Value *Op = nullptr;
6177 ArrayRef<int> Mask0;
6178 ArrayRef<int> Mask1;
6179
6180 if (!match(V: Phi->getOperand(i_nocapture: 0u),
6181 P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: Op), v2: m_Poison(), mask: m_Mask(Mask0)))) ||
6182 !match(V: Phi->getOperand(i_nocapture: 1u),
6183 P: m_OneUse(SubPattern: m_Shuffle(v1: m_Specific(V: Op), v2: m_Poison(), mask: m_Mask(Mask1)))))
6184 return false;
6185
6186 auto *Shuf = cast<ShuffleVectorInst>(Val: Phi->getOperand(i_nocapture: 0u));
6187
6188 // Ensure result vectors are wider than the argument vector.
6189 auto *InputVT = cast<FixedVectorType>(Val: Op->getType());
6190 auto *ResultVT = cast<FixedVectorType>(Val: Shuf->getType());
6191 auto const InputNumElements = InputVT->getNumElements();
6192
6193 if (InputNumElements >= ResultVT->getNumElements())
6194 return false;
6195
6196 // Take the difference of the two shuffle masks at each index. Ignore poison
6197 // values at the same index in both masks.
6198 SmallVector<int, 16> NewMask;
6199 NewMask.reserve(N: Mask0.size());
6200
6201 for (auto [M0, M1] : zip(t&: Mask0, u&: Mask1)) {
6202 if (M0 >= 0 && M1 >= 0)
6203 NewMask.push_back(Elt: M0 - M1);
6204 else if (M0 == -1 && M1 == -1)
6205 continue;
6206 else
6207 return false;
6208 }
6209
6210 // Ensure all elements of the new mask are equal. If the difference between
6211 // the incoming mask elements is the same, the two must be constant offsets
6212 // of one another.
6213 if (NewMask.empty() || !all_equal(Range&: NewMask))
6214 return false;
6215
6216 // Create new mask using difference of the two incoming masks.
6217 int MaskOffset = NewMask[0u];
6218 unsigned Index = (InputNumElements + MaskOffset) % InputNumElements;
6219 NewMask.clear();
6220
6221 for (unsigned I = 0u; I < InputNumElements; ++I) {
6222 NewMask.push_back(Elt: Index);
6223 Index = (Index + 1u) % InputNumElements;
6224 }
6225
6226 // Calculate costs for worst cases and compare.
6227 auto const Kind = TTI::SK_PermuteSingleSrc;
6228 auto OldCost =
6229 std::max(a: TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask0, CostKind),
6230 b: TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask1, CostKind));
6231 auto NewCost = TTI.getShuffleCost(Kind, DstTy: InputVT, SrcTy: InputVT, Mask: NewMask, CostKind) +
6232 TTI.getShuffleCost(Kind, DstTy: ResultVT, SrcTy: InputVT, Mask: Mask1, CostKind);
6233
6234 LLVM_DEBUG(dbgs() << "Found a phi of mergeable shuffles: " << I
6235 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
6236 << "\n");
6237
6238 if (NewCost > OldCost)
6239 return false;
6240
6241 // Create new shuffles and narrowed phi.
6242 auto Builder = IRBuilder(Shuf);
6243 Builder.SetCurrentDebugLocation(Shuf->getDebugLoc());
6244 auto *PoisonVal = PoisonValue::get(T: InputVT);
6245 auto *NewShuf0 = Builder.CreateShuffleVector(V1: Op, V2: PoisonVal, Mask: NewMask);
6246 Worklist.push(I: cast<Instruction>(Val: NewShuf0));
6247
6248 Builder.SetInsertPoint(Phi);
6249 Builder.SetCurrentDebugLocation(Phi->getDebugLoc());
6250 auto *NewPhi = Builder.CreatePHI(Ty: NewShuf0->getType(), NumReservedValues: 2u);
6251 NewPhi->addIncoming(V: NewShuf0, BB: Phi->getIncomingBlock(i: 0u));
6252 NewPhi->addIncoming(V: Op, BB: Phi->getIncomingBlock(i: 1u));
6253
6254 Builder.SetInsertPoint(*NewPhi->getInsertionPointAfterDef());
6255 PoisonVal = PoisonValue::get(T: NewPhi->getType());
6256 auto *NewShuf1 = Builder.CreateShuffleVector(V1: NewPhi, V2: PoisonVal, Mask: Mask1);
6257
6258 replaceValue(Old&: *Phi, New&: *NewShuf1);
6259 return true;
6260}
6261
6262/// This is the entry point for all transforms. Pass manager differences are
6263/// handled in the callers of this function.
6264bool VectorCombine::run() {
6265 if (DisableVectorCombine)
6266 return false;
6267
6268 // Don't attempt vectorization if the target does not support vectors.
6269 if (!TTI.getNumberOfRegisters(ClassID: TTI.getRegisterClassForType(/*Vector*/ true)))
6270 return false;
6271
6272 LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
6273
6274 auto FoldInst = [this](Instruction &I) {
6275 Builder.SetInsertPoint(&I);
6276 bool IsVectorType = isa<VectorType>(Val: I.getType());
6277 bool IsFixedVectorType = isa<FixedVectorType>(Val: I.getType());
6278 auto Opcode = I.getOpcode();
6279
6280 LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
6281
6282 // These folds should be beneficial regardless of when this pass is run
6283 // in the optimization pipeline.
6284 // The type checking is for run-time efficiency. We can avoid wasting time
6285 // dispatching to folding functions if there's no chance of matching.
6286 if (IsFixedVectorType) {
6287 switch (Opcode) {
6288 case Instruction::InsertElement:
6289 if (vectorizeLoadInsert(I))
6290 return true;
6291 break;
6292 case Instruction::ShuffleVector:
6293 if (widenSubvectorLoad(I))
6294 return true;
6295 break;
6296 default:
6297 break;
6298 }
6299 }
6300
6301 // This transform works with scalable and fixed vectors
6302 // TODO: Identify and allow other scalable transforms
6303 if (IsVectorType) {
6304 if (scalarizeOpOrCmp(I))
6305 return true;
6306 if (scalarizeLoad(I))
6307 return true;
6308 if (scalarizeExtExtract(I))
6309 return true;
6310 if (scalarizeVPIntrinsic(I))
6311 return true;
6312 if (foldInterleaveIntrinsics(I))
6313 return true;
6314 if (foldBitcastOfVPLoad(I))
6315 return true;
6316 }
6317
6318 if (foldDeinterleaveIntrinsics(I))
6319 return true;
6320
6321 if (Opcode == Instruction::Store)
6322 if (foldSingleElementStore(I))
6323 return true;
6324
6325 // If this is an early pipeline invocation of this pass, we are done.
6326 if (TryEarlyFoldsOnly)
6327 return false;
6328
6329 if (Opcode == Instruction::Call)
6330 if (foldBitOrderReverseAndSwap(I))
6331 return true;
6332
6333 // Otherwise, try folds that improve codegen but may interfere with
6334 // early IR canonicalizations.
6335 // The type checking is for run-time efficiency. We can avoid wasting time
6336 // dispatching to folding functions if there's no chance of matching.
6337 if (IsFixedVectorType) {
6338 switch (Opcode) {
6339 case Instruction::InsertElement:
6340 if (foldInsExtFNeg(I))
6341 return true;
6342 if (foldInsExtBinop(I))
6343 return true;
6344 if (foldInsExtVectorToShuffle(I))
6345 return true;
6346 break;
6347 case Instruction::ShuffleVector:
6348 if (foldPermuteOfBinops(I))
6349 return true;
6350 if (foldShuffleOfBinops(I))
6351 return true;
6352 if (foldShuffleOfSelects(I))
6353 return true;
6354 if (foldShuffleOfCastops(I))
6355 return true;
6356 if (foldShuffleOfShuffles(I))
6357 return true;
6358 if (foldPermuteOfIntrinsic(I))
6359 return true;
6360 if (foldShufflesOfLengthChangingShuffles(I))
6361 return true;
6362 if (foldShuffleOfIntrinsics(I))
6363 return true;
6364 if (foldSelectShuffle(I))
6365 return true;
6366 if (foldShuffleToIdentity(I))
6367 return true;
6368 break;
6369 case Instruction::Load:
6370 if (shrinkLoadForShuffles(I))
6371 return true;
6372 break;
6373 case Instruction::BitCast:
6374 if (foldBitcastShuffle(I))
6375 return true;
6376 if (foldSelectsFromBitcast(I))
6377 return true;
6378 break;
6379 case Instruction::And:
6380 case Instruction::Or:
6381 case Instruction::Xor:
6382 if (foldBitOpOfCastops(I))
6383 return true;
6384 if (foldBitOpOfCastConstant(I))
6385 return true;
6386 break;
6387 case Instruction::PHI:
6388 if (shrinkPhiOfShuffles(I))
6389 return true;
6390 break;
6391 default:
6392 if (shrinkType(I))
6393 return true;
6394 break;
6395 }
6396 } else {
6397 switch (Opcode) {
6398 case Instruction::Call:
6399 if (foldShuffleFromReductions(I))
6400 return true;
6401 if (foldCastFromReductions(I))
6402 return true;
6403 break;
6404 case Instruction::ExtractElement:
6405 if (foldShuffleChainsToReduce(I))
6406 return true;
6407 break;
6408 case Instruction::ICmp:
6409 if (foldSignBitReductionCmp(I))
6410 return true;
6411 if (foldICmpEqZeroVectorReduce(I))
6412 return true;
6413 if (foldReductionZeroTest(I))
6414 return true;
6415 if (foldEquivalentReductionCmp(I))
6416 return true;
6417 if (foldReduceAddCmpZero(I))
6418 return true;
6419 [[fallthrough]];
6420 case Instruction::FCmp:
6421 if (foldExtractExtract(I))
6422 return true;
6423 break;
6424 case Instruction::Or:
6425 if (foldConcatOfBoolMasks(I))
6426 return true;
6427 [[fallthrough]];
6428 default:
6429 if (Instruction::isBinaryOp(Opcode)) {
6430 if (foldExtractExtract(I))
6431 return true;
6432 if (foldExtractedCmps(I))
6433 return true;
6434 if (foldBinopOfReductions(I))
6435 return true;
6436 }
6437 break;
6438 }
6439 }
6440 return false;
6441 };
6442
6443 bool MadeChange = false;
6444 for (BasicBlock &BB : F) {
6445 // Ignore unreachable basic blocks.
6446 if (!DT.isReachableFromEntry(A: &BB))
6447 continue;
6448 // Use early increment range so that we can erase instructions in loop.
6449 // make_early_inc_range is not applicable here, as the next iterator may
6450 // be invalidated by RecursivelyDeleteTriviallyDeadInstructions.
6451 // We manually maintain the next instruction and update it when it is about
6452 // to be deleted.
6453 Instruction *I = &BB.front();
6454 while (I) {
6455 NextInst = I->getNextNode();
6456 if (!I->isDebugOrPseudoInst())
6457 MadeChange |= FoldInst(*I);
6458 I = NextInst;
6459 }
6460 }
6461
6462 NextInst = nullptr;
6463
6464 while (!Worklist.isEmpty()) {
6465 Instruction *I = Worklist.removeOne();
6466 if (!I)
6467 continue;
6468
6469 if (isInstructionTriviallyDead(I)) {
6470 eraseInstruction(I&: *I);
6471 continue;
6472 }
6473
6474 MadeChange |= FoldInst(*I);
6475 }
6476
6477 return MadeChange;
6478}
6479
6480PreservedAnalyses VectorCombinePass::run(Function &F,
6481 FunctionAnalysisManager &FAM) {
6482 auto &AC = FAM.getResult<AssumptionAnalysis>(IR&: F);
6483 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
6484 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
6485 AAResults &AA = FAM.getResult<AAManager>(IR&: F);
6486 const DataLayout *DL = &F.getDataLayout();
6487 TTI::TargetCostKind CostKind =
6488 F.hasOptSize() ? TTI::TCK_CodeSize : TTI::TCK_RecipThroughput;
6489 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, CostKind, TryEarlyFoldsOnly);
6490 if (!Combiner.run())
6491 return PreservedAnalyses::all();
6492 PreservedAnalyses PA;
6493 PA.preserveSet<CFGAnalyses>();
6494 return PA;
6495}
6496