//===------- VectorCombine.cpp - Optimize partial vector operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass optimizes scalar/vector interactions using target cost models. The
// transforms implemented here may not fit in traditional loop-based or SLP
// vectorization passes.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Vectorize/VectorCombine.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/Analysis/TargetFolder.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include <numeric>
#include <queue>
#include <set>

#define DEBUG_TYPE "vector-combine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"

using namespace llvm;
using namespace llvm::PatternMatch;

STATISTIC(NumVecLoad, "Number of vector loads formed");
STATISTIC(NumVecCmp, "Number of vector compares formed");
STATISTIC(NumVecBO, "Number of vector binops formed");
STATISTIC(NumVecCmpBO, "Number of vector compare + binop formed");
STATISTIC(NumShufOfBitcast, "Number of shuffles moved after bitcast");
STATISTIC(NumScalarOps, "Number of scalar unary + binary ops formed");
STATISTIC(NumScalarCmp, "Number of scalar compares formed");
STATISTIC(NumScalarIntrinsic, "Number of scalar intrinsic calls formed");

static cl::opt<bool> DisableVectorCombine(
    "disable-vector-combine", cl::init(false), cl::Hidden,
    cl::desc("Disable all vector combine transforms"));

static cl::opt<bool> DisableBinopExtractShuffle(
    "disable-binop-extract-shuffle", cl::init(false), cl::Hidden,
    cl::desc("Disable binop extract to shuffle transforms"));

static cl::opt<unsigned> MaxInstrsToScan(
    "vector-combine-max-scan-instrs", cl::init(30), cl::Hidden,
    cl::desc("Max number of instructions to scan for vector combining."));

static const unsigned InvalidIndex = std::numeric_limits<unsigned>::max();

namespace {
class VectorCombine {
public:
  VectorCombine(Function &F, const TargetTransformInfo &TTI,
                const DominatorTree &DT, AAResults &AA, AssumptionCache &AC,
                const DataLayout *DL, TTI::TargetCostKind CostKind,
                bool TryEarlyFoldsOnly)
      : F(F), Builder(F.getContext(), InstSimplifyFolder(*DL)), TTI(TTI),
        DT(DT), AA(AA), AC(AC), DL(DL), CostKind(CostKind),
        TryEarlyFoldsOnly(TryEarlyFoldsOnly) {}

  bool run();

private:
  Function &F;
  IRBuilder<InstSimplifyFolder> Builder;
  const TargetTransformInfo &TTI;
  const DominatorTree &DT;
  AAResults &AA;
  AssumptionCache &AC;
  const DataLayout *DL;
  TTI::TargetCostKind CostKind;

  /// If true, only perform beneficial early IR transforms. Do not introduce new
  /// vector operations.
  bool TryEarlyFoldsOnly;

  InstructionWorklist Worklist;

  // TODO: Direct calls from the top-level "run" loop use a plain "Instruction"
  // parameter. That should be updated to specific sub-classes because the
  // run loop was changed to dispatch on opcode.
  bool vectorizeLoadInsert(Instruction &I);
  bool widenSubvectorLoad(Instruction &I);
  ExtractElementInst *getShuffleExtract(ExtractElementInst *Ext0,
                                        ExtractElementInst *Ext1,
                                        unsigned PreferredExtractIndex) const;
  bool isExtractExtractCheap(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                             const Instruction &I,
                             ExtractElementInst *&ConvertToShuffle,
                             unsigned PreferredExtractIndex);
  void foldExtExtCmp(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                     Instruction &I);
  void foldExtExtBinop(ExtractElementInst *Ext0, ExtractElementInst *Ext1,
                       Instruction &I);
  bool foldExtractExtract(Instruction &I);
  bool foldInsExtFNeg(Instruction &I);
  bool foldInsExtBinop(Instruction &I);
  bool foldInsExtVectorToShuffle(Instruction &I);
  bool foldBitOpOfBitcasts(Instruction &I);
  bool foldBitcastShuffle(Instruction &I);
  bool scalarizeOpOrCmp(Instruction &I);
  bool scalarizeVPIntrinsic(Instruction &I);
  bool foldExtractedCmps(Instruction &I);
  bool foldBinopOfReductions(Instruction &I);
  bool foldSingleElementStore(Instruction &I);
  bool scalarizeLoadExtract(Instruction &I);
  bool scalarizeExtExtract(Instruction &I);
  bool foldConcatOfBoolMasks(Instruction &I);
  bool foldPermuteOfBinops(Instruction &I);
  bool foldShuffleOfBinops(Instruction &I);
  bool foldShuffleOfSelects(Instruction &I);
  bool foldShuffleOfCastops(Instruction &I);
  bool foldShuffleOfShuffles(Instruction &I);
  bool foldShuffleOfIntrinsics(Instruction &I);
  bool foldShuffleToIdentity(Instruction &I);
  bool foldShuffleFromReductions(Instruction &I);
  bool foldCastFromReductions(Instruction &I);
  bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
  bool foldInterleaveIntrinsics(Instruction &I);
  bool shrinkType(Instruction &I);

  void replaceValue(Value &Old, Value &New) {
    LLVM_DEBUG(dbgs() << "VC: Replacing: " << Old << '\n');
    LLVM_DEBUG(dbgs() << "         With: " << New << '\n');
    Old.replaceAllUsesWith(&New);
    if (auto *NewI = dyn_cast<Instruction>(&New)) {
      New.takeName(&Old);
      Worklist.pushUsersToWorkList(*NewI);
      Worklist.pushValue(NewI);
    }
    Worklist.pushValue(&Old);
  }

  void eraseInstruction(Instruction &I) {
    LLVM_DEBUG(dbgs() << "VC: Erasing: " << I << '\n');
    SmallVector<Value *> Ops(I.operands());
    Worklist.remove(&I);
    I.eraseFromParent();

    // Push remaining users of the operands and then the operand itself - allows
    // further folds that were hindered by OneUse limits.
    for (Value *Op : Ops)
      if (auto *OpI = dyn_cast<Instruction>(Op)) {
        Worklist.pushUsersToWorkList(*OpI);
        Worklist.pushValue(OpI);
      }
  }
};
} // namespace

/// Return the source operand of a potentially bitcasted value. If there is no
/// bitcast, return the input value itself.
static Value *peekThroughBitcasts(Value *V) {
  while (auto *BitCast = dyn_cast<BitCastInst>(V))
    V = BitCast->getOperand(0);
  return V;
}

static bool canWidenLoad(LoadInst *Load, const TargetTransformInfo &TTI) {
  // Do not widen load if atomic/volatile or under asan/hwasan/memtag/tsan.
  // The widened load may load data from dirty regions or create data races
  // non-existent in the source.
  if (!Load || !Load->isSimple() || !Load->hasOneUse() ||
      Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) ||
      mustSuppressSpeculation(*Load))
    return false;

  // We are potentially transforming byte-sized (8-bit) memory accesses, so make
  // sure we have all of our type-based constraints in place for this target.
  Type *ScalarTy = Load->getType()->getScalarType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();
  if (!ScalarSize || !MinVectorSize || MinVectorSize % ScalarSize != 0 ||
      ScalarSize % 8 != 0)
    return false;

  return true;
}

bool VectorCombine::vectorizeLoadInsert(Instruction &I) {
  // Match insert into fixed vector of scalar value.
  // TODO: Handle non-zero insert index.
  Value *Scalar;
  if (!match(&I,
             m_InsertElt(m_Poison(), m_OneUse(m_Value(Scalar)), m_ZeroInt())))
    return false;

  // Optionally match an extract from another vector.
  Value *X;
  bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt()));
  if (!HasExtract)
    X = Scalar;

  auto *Load = dyn_cast<LoadInst>(X);
  if (!canWidenLoad(Load, TTI))
    return false;

  Type *ScalarTy = Scalar->getType();
  uint64_t ScalarSize = ScalarTy->getPrimitiveSizeInBits();
  unsigned MinVectorSize = TTI.getMinVectorRegisterBitWidth();

  // Check safety of replacing the scalar load with a larger vector load.
  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");

  unsigned MinVecNumElts = MinVectorSize / ScalarSize;
  auto *MinVecTy = VectorType::get(ScalarTy, MinVecNumElts, false);
  unsigned OffsetEltIndex = 0;
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                   &DT)) {
    // It is not safe to load directly from the pointer, but we can still peek
    // through gep offsets and check if it is safe to load from a base address
    // with updated alignment. If it is, we can shuffle the element(s) into
    // place after loading.
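    // For example (illustrative): with a 128-bit minimum vector register and a
    // 32-bit scalar loaded from 'gep (p, 8 bytes)', the base pointer p may be
    // dereferenceable for 16 bytes even though the original address is not; we
    // can then load <4 x i32> from p and shuffle element 2 (OffsetEltIndex)
    // into place below.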
    unsigned OffsetBitWidth = DL->getIndexTypeSizeInBits(SrcPtr->getType());
    APInt Offset(OffsetBitWidth, 0);
    SrcPtr = SrcPtr->stripAndAccumulateInBoundsConstantOffsets(*DL, Offset);

    // We want to shuffle the result down from a high element of a vector, so
    // the offset must be positive.
    if (Offset.isNegative())
      return false;

    // The offset must be a multiple of the scalar element to shuffle cleanly
    // in the element's size.
    uint64_t ScalarSizeInBytes = ScalarSize / 8;
    if (Offset.urem(ScalarSizeInBytes) != 0)
      return false;

    // If we load MinVecNumElts, will our target element still be loaded?
    OffsetEltIndex = Offset.udiv(ScalarSizeInBytes).getZExtValue();
    if (OffsetEltIndex >= MinVecNumElts)
      return false;

    if (!isSafeToLoadUnconditionally(SrcPtr, MinVecTy, Align(1), *DL, Load, &AC,
                                     &DT))
      return false;

    // Update alignment with offset value. Note that the offset could be negated
    // to more accurately represent "(new) SrcPtr - Offset = (old) SrcPtr", but
    // negation does not change the result of the alignment calculation.
    Alignment = commonAlignment(Alignment, Offset.getZExtValue());
  }

  // Original pattern: insertelt undef, load [free casts of] PtrOp, 0
  // Use the greater of the alignment on the load or its source pointer.
  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);
  APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0);
  OldCost +=
      TTI.getScalarizationOverhead(MinVecTy, DemandedElts,
                                   /* Insert */ true, HasExtract, CostKind);

  // New pattern: load VecPtr
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS, CostKind);
  // Optionally, we are shuffling the loaded vector element(s) into place.
  // For the mask set everything but element 0 to undef to prevent poison from
  // propagating from the extra loaded memory. This will also optionally
  // shrink/grow the vector from the loaded size to the output size.
  // We assume this operation has no cost in codegen if there was no offset.
  // Note that we could use freeze to avoid poison problems, but then we might
  // still need a shuffle to change the vector size.
  auto *Ty = cast<FixedVectorType>(I.getType());
  unsigned OutputNumElts = Ty->getNumElements();
  SmallVector<int, 16> Mask(OutputNumElts, PoisonMaskElem);
  assert(OffsetEltIndex < MinVecNumElts && "Address offset too big");
  Mask[0] = OffsetEltIndex;
  if (OffsetEltIndex)
    NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, Ty, MinVecTy, Mask,
                                  CostKind);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  // It is safe and potentially profitable to load a vector directly:
  // inselt undef, load Scalar, 0 --> load VecPtr
  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(MinVecTy, CastedPtr, Alignment);
  VecLd = Builder.CreateShuffleVector(VecLd, Mask);

  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// If we are loading a vector and then inserting it into a larger vector with
/// undefined elements, try to load the larger vector and eliminate the insert.
/// This removes a shuffle in IR and may allow combining of other loaded values.
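/// For example (illustrative): a 'load <2 x float>' whose only use is an
/// identity-with-padding shuffle to <4 x float> may become a single
/// 'load <4 x float>' when the wider load is known to be safe and is not more
/// expensive per the cost model.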
bool VectorCombine::widenSubvectorLoad(Instruction &I) {
  // Match subvector insert of fixed vector.
  auto *Shuf = cast<ShuffleVectorInst>(&I);
  if (!Shuf->isIdentityWithPadding())
    return false;

  // Allow a non-canonical shuffle mask that is choosing elements from op1.
  unsigned NumOpElts =
      cast<FixedVectorType>(Shuf->getOperand(0)->getType())->getNumElements();
  unsigned OpIndex = any_of(Shuf->getShuffleMask(), [&NumOpElts](int M) {
    return M >= (int)(NumOpElts);
  });
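  // Note: any_of returns a bool, so OpIndex is 0 when every mask element
  // chooses from op0 and 1 when any element chooses from op1 (the non-canonical
  // operand order).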

  auto *Load = dyn_cast<LoadInst>(Shuf->getOperand(OpIndex));
  if (!canWidenLoad(Load, TTI))
    return false;

  // We use minimal alignment (maximum flexibility) because we only care about
  // the dereferenceable region. When calculating cost and creating a new op,
  // we may use a larger value based on alignment attributes.
  auto *Ty = cast<FixedVectorType>(I.getType());
  Value *SrcPtr = Load->getPointerOperand()->stripPointerCasts();
  assert(isa<PointerType>(SrcPtr->getType()) && "Expected a pointer type");
  Align Alignment = Load->getAlign();
  if (!isSafeToLoadUnconditionally(SrcPtr, Ty, Align(1), *DL, Load, &AC, &DT))
    return false;

  Alignment = std::max(SrcPtr->getPointerAlignment(*DL), Alignment);
  Type *LoadTy = Load->getType();
  unsigned AS = Load->getPointerAddressSpace();

  // Original pattern: insert_subvector (load PtrOp)
  // This conservatively assumes that the cost of a subvector insert into an
  // undef value is 0. We could add that cost if the cost model accurately
  // reflects the real cost of that operation.
  InstructionCost OldCost =
      TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS, CostKind);

  // New pattern: load PtrOp
  InstructionCost NewCost =
      TTI.getMemoryOpCost(Instruction::Load, Ty, Alignment, AS, CostKind);

  // We can aggressively convert to the vector form because the backend can
  // invert this transform if it does not result in a performance win.
  if (OldCost < NewCost || !NewCost.isValid())
    return false;

  IRBuilder<> Builder(Load);
  Value *CastedPtr =
      Builder.CreatePointerBitCastOrAddrSpaceCast(SrcPtr, Builder.getPtrTy(AS));
  Value *VecLd = Builder.CreateAlignedLoad(Ty, CastedPtr, Alignment);
  replaceValue(I, *VecLd);
  ++NumVecLoad;
  return true;
}

/// Determine which, if any, of the inputs should be replaced by a shuffle
/// followed by extract from a different index.
ExtractElementInst *VectorCombine::getShuffleExtract(
    ExtractElementInst *Ext0, ExtractElementInst *Ext1,
    unsigned PreferredExtractIndex = InvalidIndex) const {
  auto *Index0C = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
  auto *Index1C = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
  assert(Index0C && Index1C && "Expected constant extract indexes");

  unsigned Index0 = Index0C->getZExtValue();
  unsigned Index1 = Index1C->getZExtValue();

  // If the extract indexes are identical, no shuffle is needed.
  if (Index0 == Index1)
    return nullptr;

  Type *VecTy = Ext0->getVectorOperand()->getType();
  assert(VecTy == Ext1->getVectorOperand()->getType() && "Need matching types");
  InstructionCost Cost0 =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Index0);
  InstructionCost Cost1 =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Index1);

  // If both costs are invalid no shuffle is needed
  if (!Cost0.isValid() && !Cost1.isValid())
    return nullptr;

  // We are extracting from 2 different indexes, so one operand must be shuffled
  // before performing a vector operation and/or extract. The more expensive
  // extract will be replaced by a shuffle.
  if (Cost0 > Cost1)
    return Ext0;
  if (Cost1 > Cost0)
    return Ext1;

  // If the costs are equal and there is a preferred extract index, shuffle the
  // opposite operand.
  if (PreferredExtractIndex == Index0)
    return Ext1;
  if (PreferredExtractIndex == Index1)
    return Ext0;

  // Otherwise, replace the extract with the higher index.
  return Index0 > Index1 ? Ext0 : Ext1;
}

/// Compare the relative costs of 2 extracts followed by scalar operation vs.
/// vector operation(s) followed by extract. Return true if the existing
/// instructions are cheaper than a vector alternative. Otherwise, return false
/// and if one of the extracts should be transformed to a shufflevector, set
/// \p ConvertToShuffle to that extract instruction.
bool VectorCombine::isExtractExtractCheap(ExtractElementInst *Ext0,
                                          ExtractElementInst *Ext1,
                                          const Instruction &I,
                                          ExtractElementInst *&ConvertToShuffle,
                                          unsigned PreferredExtractIndex) {
  auto *Ext0IndexC = dyn_cast<ConstantInt>(Ext0->getIndexOperand());
  auto *Ext1IndexC = dyn_cast<ConstantInt>(Ext1->getIndexOperand());
  assert(Ext0IndexC && Ext1IndexC && "Expected constant extract indexes");

  unsigned Opcode = I.getOpcode();
  Value *Ext0Src = Ext0->getVectorOperand();
  Value *Ext1Src = Ext1->getVectorOperand();
  Type *ScalarTy = Ext0->getType();
  auto *VecTy = cast<VectorType>(Ext0Src->getType());
  InstructionCost ScalarOpCost, VectorOpCost;

  // Get cost estimates for scalar and vector versions of the operation.
  bool IsBinOp = Instruction::isBinaryOp(Opcode);
  if (IsBinOp) {
    ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, ScalarTy, CostKind);
    VectorOpCost = TTI.getArithmeticInstrCost(Opcode, VecTy, CostKind);
  } else {
    assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) &&
           "Expected a compare");
    CmpInst::Predicate Pred = cast<CmpInst>(I).getPredicate();
    ScalarOpCost = TTI.getCmpSelInstrCost(
        Opcode, ScalarTy, CmpInst::makeCmpResultType(ScalarTy), Pred, CostKind);
    VectorOpCost = TTI.getCmpSelInstrCost(
        Opcode, VecTy, CmpInst::makeCmpResultType(VecTy), Pred, CostKind);
  }

  // Get cost estimates for the extract elements. These costs will factor into
  // both sequences.
  unsigned Ext0Index = Ext0IndexC->getZExtValue();
  unsigned Ext1Index = Ext1IndexC->getZExtValue();

  InstructionCost Extract0Cost =
      TTI.getVectorInstrCost(*Ext0, VecTy, CostKind, Ext0Index);
  InstructionCost Extract1Cost =
      TTI.getVectorInstrCost(*Ext1, VecTy, CostKind, Ext1Index);

  // A more expensive extract will always be replaced by a splat shuffle.
  // For example, if Ext0 is more expensive:
  // opcode (extelt V0, Ext0), (ext V1, Ext1) -->
  // extelt (opcode (splat V0, Ext0), V1), Ext1
  // TODO: Evaluate whether that always results in lowest cost. Alternatively,
  // check the cost of creating a broadcast shuffle and shuffling both
  // operands to element 0.
  unsigned BestExtIndex = Extract0Cost > Extract1Cost ? Ext0Index : Ext1Index;
  unsigned BestInsIndex = Extract0Cost > Extract1Cost ? Ext1Index : Ext0Index;
  InstructionCost CheapExtractCost = std::min(Extract0Cost, Extract1Cost);

  // Extra uses of the extracts mean that we include those costs in the
  // vector total because those instructions will not be eliminated.
  InstructionCost OldCost, NewCost;
  if (Ext0Src == Ext1Src && Ext0Index == Ext1Index) {
    // Handle a special case. If the 2 extracts are identical, adjust the
    // formulas to account for that. The extra use charge allows for either the
    // CSE'd pattern or an unoptimized form with identical values:
    // opcode (extelt V, C), (extelt V, C) --> extelt (opcode V, V), C
    bool HasUseTax = Ext0 == Ext1 ? !Ext0->hasNUses(2)
                                  : !Ext0->hasOneUse() || !Ext1->hasOneUse();
    OldCost = CheapExtractCost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost + HasUseTax * CheapExtractCost;
  } else {
    // Handle the general case. Each extract is actually a different value:
    // opcode (extelt V0, C0), (extelt V1, C1) --> extelt (opcode V0, V1), C
    OldCost = Extract0Cost + Extract1Cost + ScalarOpCost;
    NewCost = VectorOpCost + CheapExtractCost +
              !Ext0->hasOneUse() * Extract0Cost +
              !Ext1->hasOneUse() * Extract1Cost;
  }

  ConvertToShuffle = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex);
  if (ConvertToShuffle) {
    if (IsBinOp && DisableBinopExtractShuffle)
      return true;

    // If we are extracting from 2 different indexes, then one operand must be
    // shuffled before performing the vector operation. The shuffle mask is
    // poison except for 1 lane that is being translated to the remaining
    // extraction lane. Therefore, it is a splat shuffle. Ex:
    // ShufMask = { poison, poison, 0, poison }
    // TODO: The cost model has an option for a "broadcast" shuffle
    // (splat-from-element-0), but no option for a more general splat.
    if (auto *FixedVecTy = dyn_cast<FixedVectorType>(VecTy)) {
      SmallVector<int> ShuffleMask(FixedVecTy->getNumElements(),
                                   PoisonMaskElem);
      ShuffleMask[BestInsIndex] = BestExtIndex;
      NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                    VecTy, VecTy, ShuffleMask, CostKind, 0,
                                    nullptr, {ConvertToShuffle});
    } else {
      NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
                                    VecTy, VecTy, {}, CostKind, 0, nullptr,
                                    {ConvertToShuffle});
    }
  }

  // Aggressively form a vector op if the cost is equal because the transform
  // may enable further optimization.
  // Codegen can reverse this transform (scalarize) if it was not profitable.
  return OldCost < NewCost;
}

/// Create a shuffle that translates (shifts) 1 element from the input vector
/// to a new element location.
static Value *createShiftShuffle(Value *Vec, unsigned OldIndex,
                                 unsigned NewIndex, IRBuilderBase &Builder) {
  // The shuffle mask is poison except for 1 lane that is being translated
  // to the new element index. Example for OldIndex == 2 and NewIndex == 0:
  // ShufMask = { 2, poison, poison, poison }
  auto *VecTy = cast<FixedVectorType>(Vec->getType());
  SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
  ShufMask[NewIndex] = OldIndex;
  return Builder.CreateShuffleVector(Vec, ShufMask, "shift");
}

/// Given an extract element instruction with constant index operand, shuffle
/// the source vector (shift the scalar element) to a NewIndex for extraction.
/// Return null if the input can be constant folded, so that we are not creating
/// unnecessary instructions.
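/// For example (illustrative), with NewIndex == 0:
///   extractelement <4 x i32> %x, i64 3
/// becomes
///   extractelement (shufflevector %x, poison, <3, poison, poison, poison>), 0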
static ExtractElementInst *translateExtract(ExtractElementInst *ExtElt,
                                            unsigned NewIndex,
                                            IRBuilderBase &Builder) {
  // Shufflevectors can only be created for fixed-width vectors.
  Value *X = ExtElt->getVectorOperand();
  if (!isa<FixedVectorType>(X->getType()))
    return nullptr;

  // If the extract can be constant-folded, this code is unsimplified. Defer
  // to other passes to handle that.
  Value *C = ExtElt->getIndexOperand();
  assert(isa<ConstantInt>(C) && "Expected a constant index operand");
  if (isa<Constant>(X))
    return nullptr;

  Value *Shuf = createShiftShuffle(X, cast<ConstantInt>(C)->getZExtValue(),
                                   NewIndex, Builder);
  return cast<ExtractElementInst>(Builder.CreateExtractElement(Shuf, NewIndex));
}

/// Try to reduce extract element costs by converting scalar compares to vector
/// compares followed by extract.
/// cmp (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtCmp(ExtractElementInst *Ext0,
                                  ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<CmpInst>(&I) && "Expected a compare");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // cmp Pred (extelt V0, C), (extelt V1, C) --> extelt (cmp Pred V0, V1), C
  ++NumVecCmp;
  CmpInst::Predicate Pred = cast<CmpInst>(&I)->getPredicate();
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecCmp = Builder.CreateCmp(Pred, V0, V1);
  Value *NewExt = Builder.CreateExtractElement(VecCmp, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Try to reduce extract element costs by converting scalar binops to vector
/// binops followed by extract.
/// bo (ext0 V0, C), (ext1 V1, C)
void VectorCombine::foldExtExtBinop(ExtractElementInst *Ext0,
                                    ExtractElementInst *Ext1, Instruction &I) {
  assert(isa<BinaryOperator>(&I) && "Expected a binary operator");
  assert(cast<ConstantInt>(Ext0->getIndexOperand())->getZExtValue() ==
             cast<ConstantInt>(Ext1->getIndexOperand())->getZExtValue() &&
         "Expected matching constant extract indexes");

  // bo (extelt V0, C), (extelt V1, C) --> extelt (bo V0, V1), C
  ++NumVecBO;
  Value *V0 = Ext0->getVectorOperand(), *V1 = Ext1->getVectorOperand();
  Value *VecBO =
      Builder.CreateBinOp(cast<BinaryOperator>(&I)->getOpcode(), V0, V1);

  // All IR flags are safe to back-propagate because any potential poison
  // created in unused vector elements is discarded by the extract.
  if (auto *VecBOInst = dyn_cast<Instruction>(VecBO))
    VecBOInst->copyIRFlags(&I);

  Value *NewExt = Builder.CreateExtractElement(VecBO, Ext0->getIndexOperand());
  replaceValue(I, *NewExt);
}

/// Match an instruction with extracted vector operands.
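/// For example (illustrative):
///   %a = extractelement <4 x float> %x, i32 0
///   %b = extractelement <4 x float> %y, i32 0
///   %r = fadd float %a, %b
/// may become 'extractelement (fadd <4 x float> %x, %y), 0' when the cost
/// model says the vector form is no more expensive.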
bool VectorCombine::foldExtractExtract(Instruction &I) {
  // It is not safe to transform things like div, urem, etc. because we may
  // create undefined behavior when executing those on unknown vector elements.
  if (!isSafeToSpeculativelyExecute(&I))
    return false;

  Instruction *I0, *I1;
  CmpPredicate Pred = CmpInst::BAD_ICMP_PREDICATE;
  if (!match(&I, m_Cmp(Pred, m_Instruction(I0), m_Instruction(I1))) &&
      !match(&I, m_BinOp(m_Instruction(I0), m_Instruction(I1))))
    return false;

  Value *V0, *V1;
  uint64_t C0, C1;
  if (!match(I0, m_ExtractElt(m_Value(V0), m_ConstantInt(C0))) ||
      !match(I1, m_ExtractElt(m_Value(V1), m_ConstantInt(C1))) ||
      V0->getType() != V1->getType())
    return false;

  // If the scalar value 'I' is going to be re-inserted into a vector, then try
  // to create an extract to that same element. The extract/insert can be
  // reduced to a "select shuffle".
  // TODO: If we add a larger pattern match that starts from an insert, this
  // probably becomes unnecessary.
  auto *Ext0 = cast<ExtractElementInst>(I0);
  auto *Ext1 = cast<ExtractElementInst>(I1);
  uint64_t InsertIndex = InvalidIndex;
  if (I.hasOneUse())
    match(I.user_back(),
          m_InsertElt(m_Value(), m_Value(), m_ConstantInt(InsertIndex)));

  ExtractElementInst *ExtractToChange;
  if (isExtractExtractCheap(Ext0, Ext1, I, ExtractToChange, InsertIndex))
    return false;

  if (ExtractToChange) {
    unsigned CheapExtractIdx = ExtractToChange == Ext0 ? C1 : C0;
    ExtractElementInst *NewExtract =
        translateExtract(ExtractToChange, CheapExtractIdx, Builder);
    if (!NewExtract)
      return false;
    if (ExtractToChange == Ext0)
      Ext0 = NewExtract;
    else
      Ext1 = NewExtract;
  }

  if (Pred != CmpInst::BAD_ICMP_PREDICATE)
    foldExtExtCmp(Ext0, Ext1, I);
  else
    foldExtExtBinop(Ext0, Ext1, I);

  Worklist.push(Ext0);
  Worklist.push(Ext1);
  return true;
}

/// Try to replace an extract + scalar fneg + insert with a vector fneg +
/// shuffle.
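/// For example (illustrative), with <4 x float> operands and Index == 2:
///   insertelement %dst, (fneg (extractelement %src, 2)), 2
/// becomes a vector fneg of %src followed by a select-style shuffle of %dst
/// and the negated vector with mask <0, 1, 6, 3>.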
673bool VectorCombine::foldInsExtFNeg(Instruction &I) {
674 // Match an insert (op (extract)) pattern.
675 Value *DestVec;
676 uint64_t Index;
677 Instruction *FNeg;
678 if (!match(V: &I, P: m_InsertElt(Val: m_Value(V&: DestVec), Elt: m_OneUse(SubPattern: m_Instruction(I&: FNeg)),
679 Idx: m_ConstantInt(V&: Index))))
680 return false;
681
682 // Note: This handles the canonical fneg instruction and "fsub -0.0, X".
683 Value *SrcVec;
684 Instruction *Extract;
685 if (!match(V: FNeg, P: m_FNeg(X: m_CombineAnd(
686 L: m_Instruction(I&: Extract),
687 R: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_SpecificInt(V: Index))))))
688 return false;
689
690 auto *VecTy = cast<FixedVectorType>(Val: I.getType());
691 auto *ScalarTy = VecTy->getScalarType();
692 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
693 if (!SrcVecTy || ScalarTy != SrcVecTy->getScalarType())
694 return false;
695
696 // Ignore bogus insert/extract index.
697 unsigned NumElts = VecTy->getNumElements();
698 if (Index >= NumElts)
699 return false;
700
701 // We are inserting the negated element into the same lane that we extracted
702 // from. This is equivalent to a select-shuffle that chooses all but the
703 // negated element from the destination vector.
704 SmallVector<int> Mask(NumElts);
705 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
706 Mask[Index] = Index + NumElts;
707 InstructionCost OldCost =
708 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: ScalarTy, CostKind) +
709 TTI.getVectorInstrCost(I, Val: VecTy, CostKind, Index);
710
711 // If the extract has one use, it will be eliminated, so count it in the
712 // original cost. If it has more than one use, ignore the cost because it will
713 // be the same before/after.
714 if (Extract->hasOneUse())
715 OldCost += TTI.getVectorInstrCost(I: *Extract, Val: VecTy, CostKind, Index);
716
717 InstructionCost NewCost =
718 TTI.getArithmeticInstrCost(Opcode: Instruction::FNeg, Ty: VecTy, CostKind) +
719 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: VecTy, SrcTy: VecTy,
720 Mask, CostKind);
721
722 bool NeedLenChg = SrcVecTy->getNumElements() != NumElts;
723 // If the lengths of the two vectors are not equal,
724 // we need to add a length-change vector. Add this cost.
725 SmallVector<int> SrcMask;
726 if (NeedLenChg) {
727 SrcMask.assign(NumElts, Elt: PoisonMaskElem);
728 SrcMask[Index] = Index;
729 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
730 DstTy: VecTy, SrcTy: SrcVecTy, Mask: SrcMask, CostKind);
731 }
732
733 if (NewCost > OldCost)
734 return false;
735
736 Value *NewShuf;
737 // insertelt DestVec, (fneg (extractelt SrcVec, Index)), Index
738 Value *VecFNeg = Builder.CreateFNegFMF(V: SrcVec, FMFSource: FNeg);
739 if (NeedLenChg) {
740 // shuffle DestVec, (shuffle (fneg SrcVec), poison, SrcMask), Mask
741 Value *LenChgShuf = Builder.CreateShuffleVector(V: VecFNeg, Mask: SrcMask);
742 NewShuf = Builder.CreateShuffleVector(V1: DestVec, V2: LenChgShuf, Mask);
743 } else {
744 // shuffle DestVec, (fneg SrcVec), Mask
745 NewShuf = Builder.CreateShuffleVector(V1: DestVec, V2: VecFNeg, Mask);
746 }
747
748 replaceValue(Old&: I, New&: *NewShuf);
749 return true;
750}
751
752/// Try to fold insert(binop(x,y),binop(a,b),idx)
753/// --> binop(insert(x,a,idx),insert(y,b,idx))
754bool VectorCombine::foldInsExtBinop(Instruction &I) {
755 BinaryOperator *VecBinOp, *SclBinOp;
756 uint64_t Index;
757 if (!match(V: &I,
758 P: m_InsertElt(Val: m_OneUse(SubPattern: m_BinOp(I&: VecBinOp)),
759 Elt: m_OneUse(SubPattern: m_BinOp(I&: SclBinOp)), Idx: m_ConstantInt(V&: Index))))
760 return false;
761
762 // TODO: Add support for addlike etc.
763 Instruction::BinaryOps BinOpcode = VecBinOp->getOpcode();
764 if (BinOpcode != SclBinOp->getOpcode())
765 return false;
766
767 auto *ResultTy = dyn_cast<FixedVectorType>(Val: I.getType());
768 if (!ResultTy)
769 return false;
770
771 // TODO: Attempt to detect m_ExtractElt for scalar operands and convert to
772 // shuffle?
773
774 InstructionCost OldCost = TTI.getInstructionCost(U: &I, CostKind) +
775 TTI.getInstructionCost(U: VecBinOp, CostKind) +
776 TTI.getInstructionCost(U: SclBinOp, CostKind);
777 InstructionCost NewCost =
778 TTI.getArithmeticInstrCost(Opcode: BinOpcode, Ty: ResultTy, CostKind) +
779 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
780 Index, Op0: VecBinOp->getOperand(i_nocapture: 0),
781 Op1: SclBinOp->getOperand(i_nocapture: 0)) +
782 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: ResultTy, CostKind,
783 Index, Op0: VecBinOp->getOperand(i_nocapture: 1),
784 Op1: SclBinOp->getOperand(i_nocapture: 1));
785
786 LLVM_DEBUG(dbgs() << "Found an insertion of two binops: " << I
787 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
788 << "\n");
789 if (NewCost > OldCost)
790 return false;
791
792 Value *NewIns0 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 0),
793 NewElt: SclBinOp->getOperand(i_nocapture: 0), Idx: Index);
794 Value *NewIns1 = Builder.CreateInsertElement(Vec: VecBinOp->getOperand(i_nocapture: 1),
795 NewElt: SclBinOp->getOperand(i_nocapture: 1), Idx: Index);
796 Value *NewBO = Builder.CreateBinOp(Opc: BinOpcode, LHS: NewIns0, RHS: NewIns1);
797
798 // Intersect flags from the old binops.
799 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
800 NewInst->copyIRFlags(V: VecBinOp);
801 NewInst->andIRFlags(V: SclBinOp);
802 }
803
804 Worklist.pushValue(V: NewIns0);
805 Worklist.pushValue(V: NewIns1);
806 replaceValue(Old&: I, New&: *NewBO);
807 return true;
808}
809
810bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
811 // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
812 Value *LHSSrc, *RHSSrc;
813 if (!match(V: &I, P: m_BitwiseLogic(L: m_BitCast(Op: m_Value(V&: LHSSrc)),
814 R: m_BitCast(Op: m_Value(V&: RHSSrc)))))
815 return false;
816
817 // Source types must match
818 if (LHSSrc->getType() != RHSSrc->getType())
819 return false;
820 if (!LHSSrc->getType()->getScalarType()->isIntegerTy())
821 return false;
822
823 // Only handle vector types
824 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: LHSSrc->getType());
825 auto *DstVecTy = dyn_cast<FixedVectorType>(Val: I.getType());
826 if (!SrcVecTy || !DstVecTy)
827 return false;
828
829 // Same total bit width
830 assert(SrcVecTy->getPrimitiveSizeInBits() ==
831 DstVecTy->getPrimitiveSizeInBits() &&
832 "Bitcast should preserve total bit width");
833
834 // Cost Check :
835 // OldCost = bitlogic + 2*bitcasts
836 // NewCost = bitlogic + bitcast
837 auto *BinOp = cast<BinaryOperator>(Val: &I);
838 InstructionCost OldCost =
839 TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: DstVecTy) +
840 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DstVecTy, Src: LHSSrc->getType(),
841 CCH: TTI::CastContextHint::None) +
842 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DstVecTy, Src: RHSSrc->getType(),
843 CCH: TTI::CastContextHint::None);
844 InstructionCost NewCost =
845 TTI.getArithmeticInstrCost(Opcode: BinOp->getOpcode(), Ty: SrcVecTy) +
846 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DstVecTy, Src: SrcVecTy,
847 CCH: TTI::CastContextHint::None);
848
849 LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
850 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
851 << "\n");
852
853 if (NewCost > OldCost)
854 return false;
855
856 // Create the operation on the source type
857 Value *NewOp = Builder.CreateBinOp(Opc: BinOp->getOpcode(), LHS: LHSSrc, RHS: RHSSrc,
858 Name: BinOp->getName() + ".inner");
859 if (auto *NewBinOp = dyn_cast<BinaryOperator>(Val: NewOp))
860 NewBinOp->copyIRFlags(V: BinOp);
861
862 Worklist.pushValue(V: NewOp);
863
864 // Bitcast the result back
865 Value *Result = Builder.CreateBitCast(V: NewOp, DestTy: I.getType());
866 replaceValue(Old&: I, New&: *Result);
867 return true;
868}
869
/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
/// destination type followed by shuffle. This can enable further transforms by
/// moving bitcasts or shuffles together.
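/// For example (illustrative):
///   bitcast (shufflevector <4 x i32> %x, poison, <3, 2, 1, 0>) to <8 x i16>
/// can become a shufflevector of 'bitcast %x to <8 x i16>' using the narrowed
/// mask <6, 7, 4, 5, 2, 3, 0, 1>, subject to the cost model.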
873bool VectorCombine::foldBitcastShuffle(Instruction &I) {
874 Value *V0, *V1;
875 ArrayRef<int> Mask;
876 if (!match(V: &I, P: m_BitCast(Op: m_OneUse(
877 SubPattern: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(Mask))))))
878 return false;
879
880 // 1) Do not fold bitcast shuffle for scalable type. First, shuffle cost for
881 // scalable type is unknown; Second, we cannot reason if the narrowed shuffle
882 // mask for scalable type is a splat or not.
883 // 2) Disallow non-vector casts.
884 // TODO: We could allow any shuffle.
885 auto *DestTy = dyn_cast<FixedVectorType>(Val: I.getType());
886 auto *SrcTy = dyn_cast<FixedVectorType>(Val: V0->getType());
887 if (!DestTy || !SrcTy)
888 return false;
889
890 unsigned DestEltSize = DestTy->getScalarSizeInBits();
891 unsigned SrcEltSize = SrcTy->getScalarSizeInBits();
892 if (SrcTy->getPrimitiveSizeInBits() % DestEltSize != 0)
893 return false;
894
895 bool IsUnary = isa<UndefValue>(Val: V1);
896
897 // For binary shuffles, only fold bitcast(shuffle(X,Y))
898 // if it won't increase the number of bitcasts.
899 if (!IsUnary) {
900 auto *BCTy0 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V0)->getType());
901 auto *BCTy1 = dyn_cast<FixedVectorType>(Val: peekThroughBitcasts(V: V1)->getType());
902 if (!(BCTy0 && BCTy0->getElementType() == DestTy->getElementType()) &&
903 !(BCTy1 && BCTy1->getElementType() == DestTy->getElementType()))
904 return false;
905 }
906
907 SmallVector<int, 16> NewMask;
908 if (DestEltSize <= SrcEltSize) {
909 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
910 // always be expanded to the equivalent form choosing narrower elements.
911 assert(SrcEltSize % DestEltSize == 0 && "Unexpected shuffle mask");
912 unsigned ScaleFactor = SrcEltSize / DestEltSize;
913 narrowShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask);
914 } else {
915 // The bitcast is from narrow elements to wide elements. The shuffle mask
916 // must choose consecutive elements to allow casting first.
917 assert(DestEltSize % SrcEltSize == 0 && "Unexpected shuffle mask");
918 unsigned ScaleFactor = DestEltSize / SrcEltSize;
919 if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask, ScaledMask&: NewMask))
920 return false;
921 }
922
923 // Bitcast the shuffle src - keep its original width but using the destination
924 // scalar type.
925 unsigned NumSrcElts = SrcTy->getPrimitiveSizeInBits() / DestEltSize;
926 auto *NewShuffleTy =
927 FixedVectorType::get(ElementType: DestTy->getScalarType(), NumElts: NumSrcElts);
928 auto *OldShuffleTy =
929 FixedVectorType::get(ElementType: SrcTy->getScalarType(), NumElts: Mask.size());
930 unsigned NumOps = IsUnary ? 1 : 2;
931
932 // The new shuffle must not cost more than the old shuffle.
933 TargetTransformInfo::ShuffleKind SK =
934 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
935 : TargetTransformInfo::SK_PermuteTwoSrc;
936
937 InstructionCost NewCost =
938 TTI.getShuffleCost(Kind: SK, DstTy: DestTy, SrcTy: NewShuffleTy, Mask: NewMask, CostKind) +
939 (NumOps * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: NewShuffleTy, Src: SrcTy,
940 CCH: TargetTransformInfo::CastContextHint::None,
941 CostKind));
942 InstructionCost OldCost =
943 TTI.getShuffleCost(Kind: SK, DstTy: OldShuffleTy, SrcTy, Mask, CostKind) +
944 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: DestTy, Src: OldShuffleTy,
945 CCH: TargetTransformInfo::CastContextHint::None,
946 CostKind);
947
948 LLVM_DEBUG(dbgs() << "Found a bitcasted shuffle: " << I << "\n OldCost: "
949 << OldCost << " vs NewCost: " << NewCost << "\n");
950
951 if (NewCost > OldCost || !NewCost.isValid())
952 return false;
953
954 // bitcast (shuf V0, V1, MaskC) --> shuf (bitcast V0), (bitcast V1), MaskC'
955 ++NumShufOfBitcast;
956 Value *CastV0 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V0), DestTy: NewShuffleTy);
957 Value *CastV1 = Builder.CreateBitCast(V: peekThroughBitcasts(V: V1), DestTy: NewShuffleTy);
958 Value *Shuf = Builder.CreateShuffleVector(V1: CastV0, V2: CastV1, Mask: NewMask);
959 replaceValue(Old&: I, New&: *Shuf);
960 return true;
961}
962
/// VP Intrinsics whose vector operands are both splat values may be simplified
/// into the scalar version of the operation and the result splatted. This
/// can lead to scalarization down the line.
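/// For example (illustrative, assuming an all-true mask):
///   vp.add(splat(%x), splat(%y), <all-true>, %evl)
/// can become 'add %x, %y' followed by a vector splat of the scalar result,
/// provided the scalar op cannot introduce UB/poison or EVL is known nonzero.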
966bool VectorCombine::scalarizeVPIntrinsic(Instruction &I) {
967 if (!isa<VPIntrinsic>(Val: I))
968 return false;
969 VPIntrinsic &VPI = cast<VPIntrinsic>(Val&: I);
970 Value *Op0 = VPI.getArgOperand(i: 0);
971 Value *Op1 = VPI.getArgOperand(i: 1);
972
973 if (!isSplatValue(V: Op0) || !isSplatValue(V: Op1))
974 return false;
975
976 // Check getSplatValue early in this function, to avoid doing unnecessary
977 // work.
978 Value *ScalarOp0 = getSplatValue(V: Op0);
979 Value *ScalarOp1 = getSplatValue(V: Op1);
980 if (!ScalarOp0 || !ScalarOp1)
981 return false;
982
983 // For the binary VP intrinsics supported here, the result on disabled lanes
984 // is a poison value. For now, only do this simplification if all lanes
985 // are active.
986 // TODO: Relax the condition that all lanes are active by using insertelement
987 // on inactive lanes.
988 auto IsAllTrueMask = [](Value *MaskVal) {
989 if (Value *SplattedVal = getSplatValue(V: MaskVal))
990 if (auto *ConstValue = dyn_cast<Constant>(Val: SplattedVal))
991 return ConstValue->isAllOnesValue();
992 return false;
993 };
994 if (!IsAllTrueMask(VPI.getArgOperand(i: 2)))
995 return false;
996
997 // Check to make sure we support scalarization of the intrinsic
998 Intrinsic::ID IntrID = VPI.getIntrinsicID();
999 if (!VPBinOpIntrinsic::isVPBinOp(ID: IntrID))
1000 return false;
1001
1002 // Calculate cost of splatting both operands into vectors and the vector
1003 // intrinsic
1004 VectorType *VecTy = cast<VectorType>(Val: VPI.getType());
1005 SmallVector<int> Mask;
1006 if (auto *FVTy = dyn_cast<FixedVectorType>(Val: VecTy))
1007 Mask.resize(N: FVTy->getNumElements(), NV: 0);
1008 InstructionCost SplatCost =
1009 TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: 0) +
1010 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_Broadcast, DstTy: VecTy, SrcTy: VecTy, Mask,
1011 CostKind);
1012
1013 // Calculate the cost of the VP Intrinsic
1014 SmallVector<Type *, 4> Args;
1015 for (Value *V : VPI.args())
1016 Args.push_back(Elt: V->getType());
1017 IntrinsicCostAttributes Attrs(IntrID, VecTy, Args);
1018 InstructionCost VectorOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
1019 InstructionCost OldCost = 2 * SplatCost + VectorOpCost;
1020
1021 // Determine scalar opcode
1022 std::optional<unsigned> FunctionalOpcode =
1023 VPI.getFunctionalOpcode();
1024 std::optional<Intrinsic::ID> ScalarIntrID = std::nullopt;
1025 if (!FunctionalOpcode) {
1026 ScalarIntrID = VPI.getFunctionalIntrinsicID();
1027 if (!ScalarIntrID)
1028 return false;
1029 }
1030
1031 // Calculate cost of scalarizing
1032 InstructionCost ScalarOpCost = 0;
1033 if (ScalarIntrID) {
1034 IntrinsicCostAttributes Attrs(*ScalarIntrID, VecTy->getScalarType(), Args);
1035 ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: Attrs, CostKind);
1036 } else {
1037 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode: *FunctionalOpcode,
1038 Ty: VecTy->getScalarType(), CostKind);
1039 }
1040
1041 // The existing splats may be kept around if other instructions use them.
1042 InstructionCost CostToKeepSplats =
1043 (SplatCost * !Op0->hasOneUse()) + (SplatCost * !Op1->hasOneUse());
1044 InstructionCost NewCost = ScalarOpCost + SplatCost + CostToKeepSplats;
1045
1046 LLVM_DEBUG(dbgs() << "Found a VP Intrinsic to scalarize: " << VPI
1047 << "\n");
1048 LLVM_DEBUG(dbgs() << "Cost of Intrinsic: " << OldCost
1049 << ", Cost of scalarizing:" << NewCost << "\n");
1050
1051 // We want to scalarize unless the vector variant actually has lower cost.
1052 if (OldCost < NewCost || !NewCost.isValid())
1053 return false;
1054
1055 // Scalarize the intrinsic
1056 ElementCount EC = cast<VectorType>(Val: Op0->getType())->getElementCount();
1057 Value *EVL = VPI.getArgOperand(i: 3);
1058
1059 // If the VP op might introduce UB or poison, we can scalarize it provided
1060 // that we know the EVL > 0: If the EVL is zero, then the original VP op
1061 // becomes a no-op and thus won't be UB, so make sure we don't introduce UB by
1062 // scalarizing it.
1063 bool SafeToSpeculate;
1064 if (ScalarIntrID)
1065 SafeToSpeculate = Intrinsic::getFnAttributes(C&: I.getContext(), id: *ScalarIntrID)
1066 .hasAttribute(Kind: Attribute::AttrKind::Speculatable);
1067 else
1068 SafeToSpeculate = isSafeToSpeculativelyExecuteWithOpcode(
1069 Opcode: *FunctionalOpcode, Inst: &VPI, CtxI: nullptr, AC: &AC, DT: &DT);
1070 if (!SafeToSpeculate &&
1071 !isKnownNonZero(V: EVL, Q: SimplifyQuery(*DL, &DT, &AC, &VPI)))
1072 return false;
1073
1074 Value *ScalarVal =
1075 ScalarIntrID
1076 ? Builder.CreateIntrinsic(RetTy: VecTy->getScalarType(), ID: *ScalarIntrID,
1077 Args: {ScalarOp0, ScalarOp1})
1078 : Builder.CreateBinOp(Opc: (Instruction::BinaryOps)(*FunctionalOpcode),
1079 LHS: ScalarOp0, RHS: ScalarOp1);
1080
1081 replaceValue(Old&: VPI, New&: *Builder.CreateVectorSplat(EC, V: ScalarVal));
1082 return true;
1083}
1084
/// Match a vector op/compare/intrinsic with at least one inserted scalar
/// operand and convert to scalar op/cmp/intrinsic followed by insertelement.
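/// For example (illustrative):
///   add <4 x i32> (insertelement %C0, %x, 0), (insertelement %C1, %y, 0)
/// can become 'insertelement (add %C0, %C1), (add %x, %y), 0', where the
/// constant base vector is folded separately from the scalar op.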
1088bool VectorCombine::scalarizeOpOrCmp(Instruction &I) {
1089 auto *UO = dyn_cast<UnaryOperator>(Val: &I);
1090 auto *BO = dyn_cast<BinaryOperator>(Val: &I);
1091 auto *CI = dyn_cast<CmpInst>(Val: &I);
1092 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
1093 if (!UO && !BO && !CI && !II)
1094 return false;
1095
1096 // TODO: Allow intrinsics with different argument types
1097 if (II) {
1098 if (!isTriviallyVectorizable(ID: II->getIntrinsicID()))
1099 return false;
1100 for (auto [Idx, Arg] : enumerate(First: II->args()))
1101 if (Arg->getType() != II->getType() &&
1102 !isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI))
1103 return false;
1104 }
1105
1106 // Do not convert the vector condition of a vector select into a scalar
1107 // condition. That may cause problems for codegen because of differences in
1108 // boolean formats and register-file transfers.
1109 // TODO: Can we account for that in the cost model?
1110 if (CI)
1111 for (User *U : I.users())
1112 if (match(V: U, P: m_Select(C: m_Specific(V: &I), L: m_Value(), R: m_Value())))
1113 return false;
1114
1115 // Match constant vectors or scalars being inserted into constant vectors:
1116 // vec_op [VecC0 | (inselt VecC0, V0, Index)], ...
1117 SmallVector<Value *> VecCs, ScalarOps;
1118 std::optional<uint64_t> Index;
1119
1120 auto Ops = II ? II->args() : I.operands();
1121 for (auto [OpNum, Op] : enumerate(First&: Ops)) {
1122 Constant *VecC;
1123 Value *V;
1124 uint64_t InsIdx = 0;
1125 if (match(V: Op.get(), P: m_InsertElt(Val: m_Constant(C&: VecC), Elt: m_Value(V),
1126 Idx: m_ConstantInt(V&: InsIdx)))) {
1127 // Bail if any inserts are out of bounds.
1128 VectorType *OpTy = cast<VectorType>(Val: Op->getType());
1129 if (OpTy->getElementCount().getKnownMinValue() <= InsIdx)
1130 return false;
1131 // All inserts must have the same index.
1132 // TODO: Deal with mismatched index constants and variable indexes?
1133 if (!Index)
1134 Index = InsIdx;
1135 else if (InsIdx != *Index)
1136 return false;
1137 VecCs.push_back(Elt: VecC);
1138 ScalarOps.push_back(Elt: V);
1139 } else if (II && isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(),
1140 ScalarOpdIdx: OpNum, TTI: &TTI)) {
1141 VecCs.push_back(Elt: Op.get());
1142 ScalarOps.push_back(Elt: Op.get());
1143 } else if (match(V: Op.get(), P: m_Constant(C&: VecC))) {
1144 VecCs.push_back(Elt: VecC);
1145 ScalarOps.push_back(Elt: nullptr);
1146 } else {
1147 return false;
1148 }
1149 }
1150
1151 // Bail if all operands are constant.
1152 if (!Index.has_value())
1153 return false;
1154
1155 VectorType *VecTy = cast<VectorType>(Val: I.getType());
1156 Type *ScalarTy = VecTy->getScalarType();
1157 assert(VecTy->isVectorTy() &&
1158 (ScalarTy->isIntegerTy() || ScalarTy->isFloatingPointTy() ||
1159 ScalarTy->isPointerTy()) &&
1160 "Unexpected types for insert element into binop or cmp");
1161
1162 unsigned Opcode = I.getOpcode();
1163 InstructionCost ScalarOpCost, VectorOpCost;
1164 if (CI) {
1165 CmpInst::Predicate Pred = CI->getPredicate();
1166 ScalarOpCost = TTI.getCmpSelInstrCost(
1167 Opcode, ValTy: ScalarTy, CondTy: CmpInst::makeCmpResultType(opnd_type: ScalarTy), VecPred: Pred, CostKind);
1168 VectorOpCost = TTI.getCmpSelInstrCost(
1169 Opcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
1170 } else if (UO || BO) {
1171 ScalarOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: ScalarTy, CostKind);
1172 VectorOpCost = TTI.getArithmeticInstrCost(Opcode, Ty: VecTy, CostKind);
1173 } else {
1174 IntrinsicCostAttributes ScalarICA(
1175 II->getIntrinsicID(), ScalarTy,
1176 SmallVector<Type *>(II->arg_size(), ScalarTy));
1177 ScalarOpCost = TTI.getIntrinsicInstrCost(ICA: ScalarICA, CostKind);
1178 IntrinsicCostAttributes VectorICA(
1179 II->getIntrinsicID(), VecTy,
1180 SmallVector<Type *>(II->arg_size(), VecTy));
1181 VectorOpCost = TTI.getIntrinsicInstrCost(ICA: VectorICA, CostKind);
1182 }
1183
1184 // Fold the vector constants in the original vectors into a new base vector to
1185 // get more accurate cost modelling.
1186 Value *NewVecC = nullptr;
1187 TargetFolder Folder(*DL);
1188 if (CI)
1189 NewVecC = Folder.FoldCmp(P: CI->getPredicate(), LHS: VecCs[0], RHS: VecCs[1]);
1190 else if (UO)
1191 NewVecC =
1192 Folder.FoldUnOpFMF(Opc: UO->getOpcode(), V: VecCs[0], FMF: UO->getFastMathFlags());
1193 else if (BO)
1194 NewVecC = Folder.FoldBinOp(Opc: BO->getOpcode(), LHS: VecCs[0], RHS: VecCs[1]);
1195 else if (II->arg_size() == 2)
1196 NewVecC = Folder.FoldBinaryIntrinsic(ID: II->getIntrinsicID(), LHS: VecCs[0],
1197 RHS: VecCs[1], Ty: II->getType(), FMFSource: &I);
1198
1199 // Get cost estimate for the insert element. This cost will factor into
1200 // both sequences.
1201 InstructionCost OldCost = VectorOpCost;
1202 InstructionCost NewCost =
1203 ScalarOpCost + TTI.getVectorInstrCost(Opcode: Instruction::InsertElement, Val: VecTy,
1204 CostKind, Index: *Index, Op0: NewVecC);
1205 for (auto [Idx, Op, VecC, Scalar] : enumerate(First&: Ops, Rest&: VecCs, Rest&: ScalarOps)) {
1206 if (!Scalar || (II && isVectorIntrinsicWithScalarOpAtArg(
1207 ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI: &TTI)))
1208 continue;
1209 InstructionCost InsertCost = TTI.getVectorInstrCost(
1210 Opcode: Instruction::InsertElement, Val: VecTy, CostKind, Index: *Index, Op0: VecC, Op1: Scalar);
1211 OldCost += InsertCost;
1212 NewCost += !Op->hasOneUse() * InsertCost;
1213 }
1214
1215 // We want to scalarize unless the vector variant actually has lower cost.
1216 if (OldCost < NewCost || !NewCost.isValid())
1217 return false;
1218
1219 // vec_op (inselt VecC0, V0, Index), (inselt VecC1, V1, Index) -->
1220 // inselt NewVecC, (scalar_op V0, V1), Index
1221 if (CI)
1222 ++NumScalarCmp;
1223 else if (UO || BO)
1224 ++NumScalarOps;
1225 else
1226 ++NumScalarIntrinsic;
1227
1228 // For constant cases, extract the scalar element, this should constant fold.
1229 for (auto [OpIdx, Scalar, VecC] : enumerate(First&: ScalarOps, Rest&: VecCs))
1230 if (!Scalar)
1231 ScalarOps[OpIdx] = ConstantExpr::getExtractElement(
1232 Vec: cast<Constant>(Val: VecC), Idx: Builder.getInt64(C: *Index));
1233
1234 Value *Scalar;
1235 if (CI)
1236 Scalar = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: ScalarOps[0], RHS: ScalarOps[1]);
1237 else if (UO || BO)
1238 Scalar = Builder.CreateNAryOp(Opc: Opcode, Ops: ScalarOps);
1239 else
1240 Scalar = Builder.CreateIntrinsic(RetTy: ScalarTy, ID: II->getIntrinsicID(), Args: ScalarOps);
1241
1242 Scalar->setName(I.getName() + ".scalar");
1243
1244 // All IR flags are safe to back-propagate. There is no potential for extra
1245 // poison to be created by the scalar instruction.
1246 if (auto *ScalarInst = dyn_cast<Instruction>(Val: Scalar))
1247 ScalarInst->copyIRFlags(V: &I);
1248
1249 // Create a new base vector if the constant folding failed.
1250 if (!NewVecC) {
1251 if (CI)
1252 NewVecC = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: VecCs[0], RHS: VecCs[1]);
1253 else if (UO || BO)
1254 NewVecC = Builder.CreateNAryOp(Opc: Opcode, Ops: VecCs);
1255 else
1256 NewVecC = Builder.CreateIntrinsic(RetTy: VecTy, ID: II->getIntrinsicID(), Args: VecCs);
1257 }
1258 Value *Insert = Builder.CreateInsertElement(Vec: NewVecC, NewElt: Scalar, Idx: *Index);
1259 replaceValue(Old&: I, New&: *Insert);
1260 return true;
1261}
1262
/// Try to combine a scalar binop + 2 scalar compares of extracted elements of
/// a vector into vector operations followed by extract. Note: The SLP pass
/// may miss this pattern because of implementation problems.
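/// For example (illustrative):
///   or (icmp eq (extractelement %x, 0), 42),
///      (icmp eq (extractelement %x, 3), 7)
/// can become a vector icmp of %x against <42, poison, poison, 7>, a shuffle
/// that moves one compared lane next to the other, a vector 'or', and a single
/// extractelement, when the cost model agrees.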
1266bool VectorCombine::foldExtractedCmps(Instruction &I) {
1267 auto *BI = dyn_cast<BinaryOperator>(Val: &I);
1268
1269 // We are looking for a scalar binop of booleans.
1270 // binop i1 (cmp Pred I0, C0), (cmp Pred I1, C1)
1271 if (!BI || !I.getType()->isIntegerTy(Bitwidth: 1))
1272 return false;
1273
1274 // The compare predicates should match, and each compare should have a
1275 // constant operand.
1276 Value *B0 = I.getOperand(i: 0), *B1 = I.getOperand(i: 1);
1277 Instruction *I0, *I1;
1278 Constant *C0, *C1;
1279 CmpPredicate P0, P1;
1280 if (!match(V: B0, P: m_Cmp(Pred&: P0, L: m_Instruction(I&: I0), R: m_Constant(C&: C0))) ||
1281 !match(V: B1, P: m_Cmp(Pred&: P1, L: m_Instruction(I&: I1), R: m_Constant(C&: C1))))
1282 return false;
1283
1284 auto MatchingPred = CmpPredicate::getMatching(A: P0, B: P1);
1285 if (!MatchingPred)
1286 return false;
1287
1288 // The compare operands must be extracts of the same vector with constant
1289 // extract indexes.
1290 Value *X;
1291 uint64_t Index0, Index1;
1292 if (!match(V: I0, P: m_ExtractElt(Val: m_Value(V&: X), Idx: m_ConstantInt(V&: Index0))) ||
1293 !match(V: I1, P: m_ExtractElt(Val: m_Specific(V: X), Idx: m_ConstantInt(V&: Index1))))
1294 return false;
1295
1296 auto *Ext0 = cast<ExtractElementInst>(Val: I0);
1297 auto *Ext1 = cast<ExtractElementInst>(Val: I1);
1298 ExtractElementInst *ConvertToShuf = getShuffleExtract(Ext0, Ext1, PreferredExtractIndex: CostKind);
1299 if (!ConvertToShuf)
1300 return false;
1301 assert((ConvertToShuf == Ext0 || ConvertToShuf == Ext1) &&
1302 "Unknown ExtractElementInst");
1303
1304 // The original scalar pattern is:
1305 // binop i1 (cmp Pred (ext X, Index0), C0), (cmp Pred (ext X, Index1), C1)
1306 CmpInst::Predicate Pred = *MatchingPred;
1307 unsigned CmpOpcode =
1308 CmpInst::isFPPredicate(P: Pred) ? Instruction::FCmp : Instruction::ICmp;
1309 auto *VecTy = dyn_cast<FixedVectorType>(Val: X->getType());
1310 if (!VecTy)
1311 return false;
1312
1313 InstructionCost Ext0Cost =
1314 TTI.getVectorInstrCost(I: *Ext0, Val: VecTy, CostKind, Index: Index0);
1315 InstructionCost Ext1Cost =
1316 TTI.getVectorInstrCost(I: *Ext1, Val: VecTy, CostKind, Index: Index1);
1317 InstructionCost CmpCost = TTI.getCmpSelInstrCost(
1318 Opcode: CmpOpcode, ValTy: I0->getType(), CondTy: CmpInst::makeCmpResultType(opnd_type: I0->getType()), VecPred: Pred,
1319 CostKind);
1320
1321 InstructionCost OldCost =
1322 Ext0Cost + Ext1Cost + CmpCost * 2 +
1323 TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: I.getType(), CostKind);
1324
1325 // The proposed vector pattern is:
1326 // vcmp = cmp Pred X, VecC
1327 // ext (binop vNi1 vcmp, (shuffle vcmp, Index1)), Index0
1328 int CheapIndex = ConvertToShuf == Ext0 ? Index1 : Index0;
1329 int ExpensiveIndex = ConvertToShuf == Ext0 ? Index0 : Index1;
1330 auto *CmpTy = cast<FixedVectorType>(Val: CmpInst::makeCmpResultType(opnd_type: VecTy));
1331 InstructionCost NewCost = TTI.getCmpSelInstrCost(
1332 Opcode: CmpOpcode, ValTy: VecTy, CondTy: CmpInst::makeCmpResultType(opnd_type: VecTy), VecPred: Pred, CostKind);
1333 SmallVector<int, 32> ShufMask(VecTy->getNumElements(), PoisonMaskElem);
1334 ShufMask[CheapIndex] = ExpensiveIndex;
1335 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: CmpTy,
1336 SrcTy: CmpTy, Mask: ShufMask, CostKind);
1337 NewCost += TTI.getArithmeticInstrCost(Opcode: I.getOpcode(), Ty: CmpTy, CostKind);
1338 NewCost += TTI.getVectorInstrCost(I: *Ext0, Val: CmpTy, CostKind, Index: CheapIndex);
1339 NewCost += Ext0->hasOneUse() ? 0 : Ext0Cost;
1340 NewCost += Ext1->hasOneUse() ? 0 : Ext1Cost;
1341
1342 // Aggressively form vector ops if the cost is equal because the transform
1343 // may enable further optimization.
1344 // Codegen can reverse this transform (scalarize) if it was not profitable.
1345 if (OldCost < NewCost || !NewCost.isValid())
1346 return false;
1347
1348 // Create a vector constant from the 2 scalar constants.
1349 SmallVector<Constant *, 32> CmpC(VecTy->getNumElements(),
1350 PoisonValue::get(T: VecTy->getElementType()));
1351 CmpC[Index0] = C0;
1352 CmpC[Index1] = C1;
1353 Value *VCmp = Builder.CreateCmp(Pred, LHS: X, RHS: ConstantVector::get(V: CmpC));
1354 Value *Shuf = createShiftShuffle(Vec: VCmp, OldIndex: ExpensiveIndex, NewIndex: CheapIndex, Builder);
1355 Value *LHS = ConvertToShuf == Ext0 ? Shuf : VCmp;
1356 Value *RHS = ConvertToShuf == Ext0 ? VCmp : Shuf;
1357 Value *VecLogic = Builder.CreateBinOp(Opc: BI->getOpcode(), LHS, RHS);
1358 Value *NewExt = Builder.CreateExtractElement(Vec: VecLogic, Idx: CheapIndex);
1359 replaceValue(Old&: I, New&: *NewExt);
1360 ++NumVecCmpBO;
1361 return true;
1362}
1363
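/// Estimate the cost of computing the vector operand of the reduction \p II
/// (\p CostBeforeReduction) and of the reduction itself
/// (\p CostAfterReduction), recognizing extended reductions such as
/// reduce.add(ext(A)) and multiply-accumulate patterns such as
/// reduce.add(ext(mul(ext(A), ext(B)))).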
1364static void analyzeCostOfVecReduction(const IntrinsicInst &II,
1365 TTI::TargetCostKind CostKind,
1366 const TargetTransformInfo &TTI,
1367 InstructionCost &CostBeforeReduction,
1368 InstructionCost &CostAfterReduction) {
1369 Instruction *Op0, *Op1;
1370 auto *RedOp = dyn_cast<Instruction>(Val: II.getOperand(i_nocapture: 0));
1371 auto *VecRedTy = cast<VectorType>(Val: II.getOperand(i_nocapture: 0)->getType());
1372 unsigned ReductionOpc =
1373 getArithmeticReductionInstruction(RdxID: II.getIntrinsicID());
1374 if (RedOp && match(V: RedOp, P: m_ZExtOrSExt(Op: m_Value()))) {
1375 bool IsUnsigned = isa<ZExtInst>(Val: RedOp);
1376 auto *ExtType = cast<VectorType>(Val: RedOp->getOperand(i: 0)->getType());
1377
1378 CostBeforeReduction =
1379 TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: ExtType,
1380 CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
1381 CostAfterReduction =
1382 TTI.getExtendedReductionCost(Opcode: ReductionOpc, IsUnsigned, ResTy: II.getType(),
1383 Ty: ExtType, FMF: FastMathFlags(), CostKind);
1384 return;
1385 }
1386 if (RedOp && II.getIntrinsicID() == Intrinsic::vector_reduce_add &&
1387 match(V: RedOp,
1388 P: m_ZExtOrSExt(Op: m_Mul(L: m_Instruction(I&: Op0), R: m_Instruction(I&: Op1)))) &&
1389 match(V: Op0, P: m_ZExtOrSExt(Op: m_Value())) &&
1390 Op0->getOpcode() == Op1->getOpcode() &&
1391 Op0->getOperand(i: 0)->getType() == Op1->getOperand(i: 0)->getType() &&
1392 (Op0->getOpcode() == RedOp->getOpcode() || Op0 == Op1)) {
    // Matched reduce.add(ext(mul(ext(A), ext(B))))
1394 bool IsUnsigned = isa<ZExtInst>(Val: Op0);
1395 auto *ExtType = cast<VectorType>(Val: Op0->getOperand(i: 0)->getType());
1396 VectorType *MulType = VectorType::get(ElementType: Op0->getType(), Other: VecRedTy);
1397
1398 InstructionCost ExtCost =
1399 TTI.getCastInstrCost(Opcode: Op0->getOpcode(), Dst: MulType, Src: ExtType,
1400 CCH: TTI::CastContextHint::None, CostKind, I: Op0);
1401 InstructionCost MulCost =
1402 TTI.getArithmeticInstrCost(Opcode: Instruction::Mul, Ty: MulType, CostKind);
1403 InstructionCost Ext2Cost =
1404 TTI.getCastInstrCost(Opcode: RedOp->getOpcode(), Dst: VecRedTy, Src: MulType,
1405 CCH: TTI::CastContextHint::None, CostKind, I: RedOp);
1406
1407 CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
1408 CostAfterReduction =
1409 TTI.getMulAccReductionCost(IsUnsigned, ResTy: II.getType(), Ty: ExtType, CostKind);
1410 return;
1411 }
1412 CostAfterReduction = TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VecRedTy,
1413 FMF: std::nullopt, CostKind);
1414}
1415
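/// Try to fold a scalar binop of two one-use vector reductions into a single
/// reduction of a vector binop, e.g.
///   add (vector.reduce.add(X)), (vector.reduce.add(Y))
/// --> vector.reduce.add (add X, Y)
/// A scalar sub of two add-reductions is likewise turned into a reduction of
/// a vector sub.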
1416bool VectorCombine::foldBinopOfReductions(Instruction &I) {
1417 Instruction::BinaryOps BinOpOpc = cast<BinaryOperator>(Val: &I)->getOpcode();
1418 Intrinsic::ID ReductionIID = getReductionForBinop(Opc: BinOpOpc);
1419 if (BinOpOpc == Instruction::Sub)
1420 ReductionIID = Intrinsic::vector_reduce_add;
1421 if (ReductionIID == Intrinsic::not_intrinsic)
1422 return false;
1423
1424 auto checkIntrinsicAndGetItsArgument = [](Value *V,
1425 Intrinsic::ID IID) -> Value * {
1426 auto *II = dyn_cast<IntrinsicInst>(Val: V);
1427 if (!II)
1428 return nullptr;
1429 if (II->getIntrinsicID() == IID && II->hasOneUse())
1430 return II->getArgOperand(i: 0);
1431 return nullptr;
1432 };
1433
1434 Value *V0 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 0), ReductionIID);
1435 if (!V0)
1436 return false;
1437 Value *V1 = checkIntrinsicAndGetItsArgument(I.getOperand(i: 1), ReductionIID);
1438 if (!V1)
1439 return false;
1440
1441 auto *VTy = cast<VectorType>(Val: V0->getType());
1442 if (V1->getType() != VTy)
1443 return false;
1444 const auto &II0 = *cast<IntrinsicInst>(Val: I.getOperand(i: 0));
1445 const auto &II1 = *cast<IntrinsicInst>(Val: I.getOperand(i: 1));
1446 unsigned ReductionOpc =
1447 getArithmeticReductionInstruction(RdxID: II0.getIntrinsicID());
1448
1449 InstructionCost OldCost = 0;
1450 InstructionCost NewCost = 0;
1451 InstructionCost CostOfRedOperand0 = 0;
1452 InstructionCost CostOfRed0 = 0;
1453 InstructionCost CostOfRedOperand1 = 0;
1454 InstructionCost CostOfRed1 = 0;
1455 analyzeCostOfVecReduction(II: II0, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand0, CostAfterReduction&: CostOfRed0);
1456 analyzeCostOfVecReduction(II: II1, CostKind, TTI, CostBeforeReduction&: CostOfRedOperand1, CostAfterReduction&: CostOfRed1);
1457 OldCost = CostOfRed0 + CostOfRed1 + TTI.getInstructionCost(U: &I, CostKind);
1458 NewCost =
1459 CostOfRedOperand0 + CostOfRedOperand1 +
1460 TTI.getArithmeticInstrCost(Opcode: BinOpOpc, Ty: VTy, CostKind) +
1461 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: VTy, FMF: std::nullopt, CostKind);
1462 if (NewCost >= OldCost || !NewCost.isValid())
1463 return false;
1464
1465 LLVM_DEBUG(dbgs() << "Found two mergeable reductions: " << I
1466 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1467 << "\n");
1468 Value *VectorBO;
1469 if (BinOpOpc == Instruction::Or)
1470 VectorBO = Builder.CreateOr(LHS: V0, RHS: V1, Name: "",
1471 IsDisjoint: cast<PossiblyDisjointInst>(Val&: I).isDisjoint());
1472 else
1473 VectorBO = Builder.CreateBinOp(Opc: BinOpOpc, LHS: V0, RHS: V1);
1474
1475 Instruction *Rdx = Builder.CreateIntrinsic(ID: ReductionIID, Types: {VTy}, Args: {VectorBO});
1476 replaceValue(Old&: I, New&: *Rdx);
1477 return true;
1478}
1479
// Check if a memory location is modified between two instructions in the same
// basic block.
1481static bool isMemModifiedBetween(BasicBlock::iterator Begin,
1482 BasicBlock::iterator End,
1483 const MemoryLocation &Loc, AAResults &AA) {
1484 unsigned NumScanned = 0;
1485 return std::any_of(first: Begin, last: End, pred: [&](const Instruction &Instr) {
1486 return isModSet(MRI: AA.getModRefInfo(I: &Instr, OptLoc: Loc)) ||
1487 ++NumScanned > MaxInstrsToScan;
1488 });
1489}
1490
1491namespace {
/// Helper class to indicate whether a vector index can be safely scalarized
/// and whether a freeze needs to be inserted.
1494class ScalarizationResult {
1495 enum class StatusTy { Unsafe, Safe, SafeWithFreeze };
1496
1497 StatusTy Status;
1498 Value *ToFreeze;
1499
1500 ScalarizationResult(StatusTy Status, Value *ToFreeze = nullptr)
1501 : Status(Status), ToFreeze(ToFreeze) {}
1502
1503public:
1504 ScalarizationResult(const ScalarizationResult &Other) = default;
1505 ~ScalarizationResult() {
1506 assert(!ToFreeze && "freeze() not called with ToFreeze being set");
1507 }
1508
1509 static ScalarizationResult unsafe() { return {StatusTy::Unsafe}; }
1510 static ScalarizationResult safe() { return {StatusTy::Safe}; }
1511 static ScalarizationResult safeWithFreeze(Value *ToFreeze) {
1512 return {StatusTy::SafeWithFreeze, ToFreeze};
1513 }
1514
  /// Returns true if the index can be scalarized without requiring a freeze.
1516 bool isSafe() const { return Status == StatusTy::Safe; }
1517 /// Returns true if the index cannot be scalarized.
1518 bool isUnsafe() const { return Status == StatusTy::Unsafe; }
  /// Returns true if the index can be scalarized, but requires inserting a
  /// freeze.
1521 bool isSafeWithFreeze() const { return Status == StatusTy::SafeWithFreeze; }
1522
  /// Reset the state to Unsafe and clear ToFreeze if set.
1524 void discard() {
1525 ToFreeze = nullptr;
1526 Status = StatusTy::Unsafe;
1527 }
1528
  /// Freeze ToFreeze and update the use in \p UserI to use the frozen value.
1530 void freeze(IRBuilderBase &Builder, Instruction &UserI) {
1531 assert(isSafeWithFreeze() &&
1532 "should only be used when freezing is required");
1533 assert(is_contained(ToFreeze->users(), &UserI) &&
1534 "UserI must be a user of ToFreeze");
1535 IRBuilder<>::InsertPointGuard Guard(Builder);
1536 Builder.SetInsertPoint(cast<Instruction>(Val: &UserI));
1537 Value *Frozen =
1538 Builder.CreateFreeze(V: ToFreeze, Name: ToFreeze->getName() + ".frozen");
1539 for (Use &U : make_early_inc_range(Range: (UserI.operands())))
1540 if (U.get() == ToFreeze)
1541 U.set(Frozen);
1542
1543 ToFreeze = nullptr;
1544 }
1545};
1546} // namespace
1547
1548/// Check if it is legal to scalarize a memory access to \p VecTy at index \p
1549/// Idx. \p Idx must access a valid vector element.
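/// For example (illustrative): for a <4 x i32> access, a constant index below
/// 4 is safe, and an index of the form "and i64 %i, 3" is always in bounds but
/// may be poison, so it is reported as safe-with-freeze of %i.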
1550static ScalarizationResult canScalarizeAccess(VectorType *VecTy, Value *Idx,
1551 Instruction *CtxI,
1552 AssumptionCache &AC,
1553 const DominatorTree &DT) {
1554 // We do checks for both fixed vector types and scalable vector types.
1555 // This is the number of elements of fixed vector types,
1556 // or the minimum number of elements of scalable vector types.
1557 uint64_t NumElements = VecTy->getElementCount().getKnownMinValue();
1558 unsigned IntWidth = Idx->getType()->getScalarSizeInBits();
1559
1560 if (auto *C = dyn_cast<ConstantInt>(Val: Idx)) {
1561 if (C->getValue().ult(RHS: NumElements))
1562 return ScalarizationResult::safe();
1563 return ScalarizationResult::unsafe();
1564 }
1565
  // Always unsafe if the index type can't handle all in-bounds values.
1567 if (!llvm::isUIntN(N: IntWidth, x: NumElements))
1568 return ScalarizationResult::unsafe();
1569
1570 APInt Zero(IntWidth, 0);
1571 APInt MaxElts(IntWidth, NumElements);
1572 ConstantRange ValidIndices(Zero, MaxElts);
1573 ConstantRange IdxRange(IntWidth, true);
1574
1575 if (isGuaranteedNotToBePoison(V: Idx, AC: &AC)) {
1576 if (ValidIndices.contains(CR: computeConstantRange(V: Idx, /* ForSigned */ false,
1577 UseInstrInfo: true, AC: &AC, CtxI, DT: &DT)))
1578 return ScalarizationResult::safe();
1579 return ScalarizationResult::unsafe();
1580 }
1581
1582 // If the index may be poison, check if we can insert a freeze before the
1583 // range of the index is restricted.
1584 Value *IdxBase;
1585 ConstantInt *CI;
1586 if (match(V: Idx, P: m_And(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
1587 IdxRange = IdxRange.binaryAnd(Other: CI->getValue());
1588 } else if (match(V: Idx, P: m_URem(L: m_Value(V&: IdxBase), R: m_ConstantInt(CI)))) {
1589 IdxRange = IdxRange.urem(Other: CI->getValue());
1590 }
1591
1592 if (ValidIndices.contains(CR: IdxRange))
1593 return ScalarizationResult::safeWithFreeze(ToFreeze: IdxBase);
1594 return ScalarizationResult::unsafe();
1595}
1596
1597/// The memory operation on a vector of \p ScalarType had alignment of
1598/// \p VectorAlignment. Compute the maximal, but conservatively correct,
1599/// alignment that will be valid for the memory operation on a single scalar
1600/// element of the same type with index \p Idx.
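/// For example (illustrative): scalarizing a <4 x i32> access with vector
/// alignment 16 at constant index 2 gives commonAlignment(16, 2 * 4) == 8,
/// while an unknown index conservatively gives commonAlignment(16, 4) == 4.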
1601static Align computeAlignmentAfterScalarization(Align VectorAlignment,
1602 Type *ScalarType, Value *Idx,
1603 const DataLayout &DL) {
1604 if (auto *C = dyn_cast<ConstantInt>(Val: Idx))
1605 return commonAlignment(A: VectorAlignment,
1606 Offset: C->getZExtValue() * DL.getTypeStoreSize(Ty: ScalarType));
1607 return commonAlignment(A: VectorAlignment, Offset: DL.getTypeStoreSize(Ty: ScalarType));
1608}
1609
1610// Combine patterns like:
1611// %0 = load <4 x i32>, <4 x i32>* %a
1612// %1 = insertelement <4 x i32> %0, i32 %b, i32 1
1613// store <4 x i32> %1, <4 x i32>* %a
1614// to:
1615// %0 = bitcast <4 x i32>* %a to i32*
1616// %1 = getelementptr inbounds i32, i32* %0, i64 0, i64 1
1617// store i32 %b, i32* %1
1618bool VectorCombine::foldSingleElementStore(Instruction &I) {
1619 auto *SI = cast<StoreInst>(Val: &I);
1620 if (!SI->isSimple() || !isa<VectorType>(Val: SI->getValueOperand()->getType()))
1621 return false;
1622
1623 // TODO: Combine more complicated patterns (multiple insert) by referencing
1624 // TargetTransformInfo.
1625 Instruction *Source;
1626 Value *NewElement;
1627 Value *Idx;
1628 if (!match(V: SI->getValueOperand(),
1629 P: m_InsertElt(Val: m_Instruction(I&: Source), Elt: m_Value(V&: NewElement),
1630 Idx: m_Value(V&: Idx))))
1631 return false;
1632
1633 if (auto *Load = dyn_cast<LoadInst>(Val: Source)) {
1634 auto VecTy = cast<VectorType>(Val: SI->getValueOperand()->getType());
1635 Value *SrcAddr = Load->getPointerOperand()->stripPointerCasts();
    // Don't optimize for atomic/volatile loads or stores. Ensure the memory is
    // not modified between the load and the store, the element type's size
    // matches its store size, and the index is in bounds.
1638 if (!Load->isSimple() || Load->getParent() != SI->getParent() ||
1639 !DL->typeSizeEqualsStoreSize(Ty: Load->getType()->getScalarType()) ||
1640 SrcAddr != SI->getPointerOperand()->stripPointerCasts())
1641 return false;
1642
1643 auto ScalarizableIdx = canScalarizeAccess(VecTy, Idx, CtxI: Load, AC, DT);
1644 if (ScalarizableIdx.isUnsafe() ||
1645 isMemModifiedBetween(Begin: Load->getIterator(), End: SI->getIterator(),
1646 Loc: MemoryLocation::get(SI), AA))
1647 return false;
1648
    // Ensure we add the load back to the worklist BEFORE its users so they can
    // be erased in the correct order.
1651 Worklist.push(I: Load);
1652
1653 if (ScalarizableIdx.isSafeWithFreeze())
1654 ScalarizableIdx.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
1655 Value *GEP = Builder.CreateInBoundsGEP(
1656 Ty: SI->getValueOperand()->getType(), Ptr: SI->getPointerOperand(),
1657 IdxList: {ConstantInt::get(Ty: Idx->getType(), V: 0), Idx});
1658 StoreInst *NSI = Builder.CreateStore(Val: NewElement, Ptr: GEP);
1659 NSI->copyMetadata(SrcInst: *SI);
1660 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1661 VectorAlignment: std::max(a: SI->getAlign(), b: Load->getAlign()), ScalarType: NewElement->getType(), Idx,
1662 DL: *DL);
1663 NSI->setAlignment(ScalarOpAlignment);
1664 replaceValue(Old&: I, New&: *NSI);
1665 eraseInstruction(I);
1666 return true;
1667 }
1668
1669 return false;
1670}
1671
1672/// Try to scalarize vector loads feeding extractelement instructions.
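/// For example (illustrative):
///   %v = load <4 x i32>, ptr %p
///   %e = extractelement <4 x i32> %v, i64 1
/// becomes
///   %gep = getelementptr inbounds <4 x i32>, ptr %p, i32 0, i64 1
///   %e.scalar = load i32, ptr %gep
/// when every user of the load is such an extract and the cost model favours
/// the scalar loads.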
1673bool VectorCombine::scalarizeLoadExtract(Instruction &I) {
1674 Value *Ptr;
1675 if (!match(V: &I, P: m_Load(Op: m_Value(V&: Ptr))))
1676 return false;
1677
1678 auto *LI = cast<LoadInst>(Val: &I);
1679 auto *VecTy = cast<VectorType>(Val: LI->getType());
1680 if (LI->isVolatile() || !DL->typeSizeEqualsStoreSize(Ty: VecTy->getScalarType()))
1681 return false;
1682
1683 InstructionCost OriginalCost =
1684 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy, Alignment: LI->getAlign(),
1685 AddressSpace: LI->getPointerAddressSpace(), CostKind);
1686 InstructionCost ScalarizedCost = 0;
1687
1688 Instruction *LastCheckedInst = LI;
1689 unsigned NumInstChecked = 0;
1690 DenseMap<ExtractElementInst *, ScalarizationResult> NeedFreeze;
1691 auto FailureGuard = make_scope_exit(F: [&]() {
1692 // If the transform is aborted, discard the ScalarizationResults.
1693 for (auto &Pair : NeedFreeze)
1694 Pair.second.discard();
1695 });
1696
1697 // Check if all users of the load are extracts with no memory modifications
1698 // between the load and the extract. Compute the cost of both the original
1699 // code and the scalarized version.
1700 for (User *U : LI->users()) {
1701 auto *UI = dyn_cast<ExtractElementInst>(Val: U);
1702 if (!UI || UI->getParent() != LI->getParent())
1703 return false;
1704
1705 // If any extract is waiting to be erased, then bail out as this will
1706 // distort the cost calculation and possibly lead to infinite loops.
1707 if (UI->use_empty())
1708 return false;
1709
1710 // Check if any instruction between the load and the extract may modify
1711 // memory.
1712 if (LastCheckedInst->comesBefore(Other: UI)) {
1713 for (Instruction &I :
1714 make_range(x: std::next(x: LI->getIterator()), y: UI->getIterator())) {
1715 // Bail out if we reached the check limit or the instruction may write
1716 // to memory.
1717 if (NumInstChecked == MaxInstrsToScan || I.mayWriteToMemory())
1718 return false;
1719 NumInstChecked++;
1720 }
1721 LastCheckedInst = UI;
1722 }
1723
1724 auto ScalarIdx =
1725 canScalarizeAccess(VecTy, Idx: UI->getIndexOperand(), CtxI: LI, AC, DT);
1726 if (ScalarIdx.isUnsafe())
1727 return false;
1728 if (ScalarIdx.isSafeWithFreeze()) {
1729 NeedFreeze.try_emplace(Key: UI, Args&: ScalarIdx);
1730 ScalarIdx.discard();
1731 }
1732
1733 auto *Index = dyn_cast<ConstantInt>(Val: UI->getIndexOperand());
1734 OriginalCost +=
1735 TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: VecTy, CostKind,
1736 Index: Index ? Index->getZExtValue() : -1);
1737 ScalarizedCost +=
1738 TTI.getMemoryOpCost(Opcode: Instruction::Load, Src: VecTy->getElementType(),
1739 Alignment: Align(1), AddressSpace: LI->getPointerAddressSpace(), CostKind);
1740 ScalarizedCost += TTI.getAddressComputationCost(Ty: VecTy->getElementType());
1741 }
1742
1743 LLVM_DEBUG(dbgs() << "Found all extractions of a vector load: " << I
1744 << "\n LoadExtractCost: " << OriginalCost
1745 << " vs ScalarizedCost: " << ScalarizedCost << "\n");
1746
1747 if (ScalarizedCost >= OriginalCost)
1748 return false;
1749
  // Ensure we add the load back to the worklist BEFORE its users so they can
  // be erased in the correct order.
1752 Worklist.push(I: LI);
1753
1754 // Replace extracts with narrow scalar loads.
1755 for (User *U : LI->users()) {
1756 auto *EI = cast<ExtractElementInst>(Val: U);
1757 Value *Idx = EI->getIndexOperand();
1758
1759 // Insert 'freeze' for poison indexes.
1760 auto It = NeedFreeze.find(Val: EI);
1761 if (It != NeedFreeze.end())
1762 It->second.freeze(Builder, UserI&: *cast<Instruction>(Val: Idx));
1763
1764 Builder.SetInsertPoint(EI);
1765 Value *GEP =
1766 Builder.CreateInBoundsGEP(Ty: VecTy, Ptr, IdxList: {Builder.getInt32(C: 0), Idx});
1767 auto *NewLoad = cast<LoadInst>(Val: Builder.CreateLoad(
1768 Ty: VecTy->getElementType(), Ptr: GEP, Name: EI->getName() + ".scalar"));
1769
1770 Align ScalarOpAlignment = computeAlignmentAfterScalarization(
1771 VectorAlignment: LI->getAlign(), ScalarType: VecTy->getElementType(), Idx, DL: *DL);
1772 NewLoad->setAlignment(ScalarOpAlignment);
1773
1774 replaceValue(Old&: *EI, New&: *NewLoad);
1775 }
1776
1777 FailureGuard.release();
1778 return true;
1779}
1780
1781bool VectorCombine::scalarizeExtExtract(Instruction &I) {
1782 auto *Ext = dyn_cast<ZExtInst>(Val: &I);
1783 if (!Ext)
1784 return false;
1785
  // Try to convert a vector zext feeding only extracts to a set of scalar
  //   (Src >> (ExtIdx * EltBitWidth)) & ((1 << EltBitWidth) - 1)
  // operations, if profitable.
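  // For example (illustrative): "zext <4 x i8> %src to <4 x i32>" used only by
  // extracts of lanes 0 and 2 becomes a bitcast of %src to i32 (plus a freeze
  // if %src may be poison), followed by
  //   (%bc >> 0) & 255  and  (%bc >> 16) & 255.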
1789 auto *SrcTy = dyn_cast<FixedVectorType>(Val: Ext->getOperand(i_nocapture: 0)->getType());
1790 if (!SrcTy)
1791 return false;
1792 auto *DstTy = cast<FixedVectorType>(Val: Ext->getType());
1793
1794 Type *ScalarDstTy = DstTy->getElementType();
1795 if (DL->getTypeSizeInBits(Ty: SrcTy) != DL->getTypeSizeInBits(Ty: ScalarDstTy))
1796 return false;
1797
1798 InstructionCost VectorCost =
1799 TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: DstTy, Src: SrcTy,
1800 CCH: TTI::CastContextHint::None, CostKind, I: Ext);
1801 unsigned ExtCnt = 0;
1802 bool ExtLane0 = false;
1803 for (User *U : Ext->users()) {
1804 const APInt *Idx;
1805 if (!match(V: U, P: m_ExtractElt(Val: m_Value(), Idx: m_APInt(Res&: Idx))))
1806 return false;
1807 if (cast<Instruction>(Val: U)->use_empty())
1808 continue;
1809 ExtCnt += 1;
1810 ExtLane0 |= Idx->isZero();
1811 VectorCost += TTI.getVectorInstrCost(Opcode: Instruction::ExtractElement, Val: DstTy,
1812 CostKind, Index: Idx->getZExtValue(), Op0: U);
1813 }
1814
1815 InstructionCost ScalarCost =
1816 ExtCnt * TTI.getArithmeticInstrCost(
1817 Opcode: Instruction::And, Ty: ScalarDstTy, CostKind,
1818 Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
1819 Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None}) +
1820 (ExtCnt - ExtLane0) *
1821 TTI.getArithmeticInstrCost(
1822 Opcode: Instruction::LShr, Ty: ScalarDstTy, CostKind,
1823 Opd1Info: {.Kind: TTI::OK_AnyValue, .Properties: TTI::OP_None},
1824 Opd2Info: {.Kind: TTI::OK_NonUniformConstantValue, .Properties: TTI::OP_None});
1825 if (ScalarCost > VectorCost)
1826 return false;
1827
1828 Value *ScalarV = Ext->getOperand(i_nocapture: 0);
1829 if (!isGuaranteedNotToBePoison(V: ScalarV, AC: &AC, CtxI: dyn_cast<Instruction>(Val: ScalarV),
1830 DT: &DT))
1831 ScalarV = Builder.CreateFreeze(V: ScalarV);
1832 ScalarV = Builder.CreateBitCast(
1833 V: ScalarV,
1834 DestTy: IntegerType::get(C&: SrcTy->getContext(), NumBits: DL->getTypeSizeInBits(Ty: SrcTy)));
1835 uint64_t SrcEltSizeInBits = DL->getTypeSizeInBits(Ty: SrcTy->getElementType());
1836 uint64_t EltBitMask = (1ull << SrcEltSizeInBits) - 1;
1837 for (User *U : Ext->users()) {
1838 auto *Extract = cast<ExtractElementInst>(Val: U);
1839 uint64_t Idx =
1840 cast<ConstantInt>(Val: Extract->getIndexOperand())->getZExtValue();
1841 Value *LShr = Builder.CreateLShr(LHS: ScalarV, RHS: Idx * SrcEltSizeInBits);
1842 Value *And = Builder.CreateAnd(LHS: LShr, RHS: EltBitMask);
1843 U->replaceAllUsesWith(V: And);
1844 }
1845 return true;
1846}
1847
1848/// Try to fold "(or (zext (bitcast X)), (shl (zext (bitcast Y)), C))"
1849/// to "(bitcast (concat X, Y))"
1850/// where X/Y are bitcasted from i1 mask vectors.
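/// For example (illustrative): with X and Y bitcast from <8 x i1> to i8,
/// zero-extended to i16, and Y's half shifted left by 8, the disjoint 'or'
/// becomes a bitcast to i16 of the <16 x i1> concatenation (a two-source
/// shuffle) of X and Y.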
1851bool VectorCombine::foldConcatOfBoolMasks(Instruction &I) {
1852 Type *Ty = I.getType();
1853 if (!Ty->isIntegerTy())
1854 return false;
1855
1856 // TODO: Add big endian test coverage
1857 if (DL->isBigEndian())
1858 return false;
1859
1860 // Restrict to disjoint cases so the mask vectors aren't overlapping.
1861 Instruction *X, *Y;
1862 if (!match(V: &I, P: m_DisjointOr(L: m_Instruction(I&: X), R: m_Instruction(I&: Y))))
1863 return false;
1864
1865 // Allow both sources to contain shl, to handle more generic pattern:
1866 // "(or (shl (zext (bitcast X)), C1), (shl (zext (bitcast Y)), C2))"
1867 Value *SrcX;
1868 uint64_t ShAmtX = 0;
1869 if (!match(V: X, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX)))))) &&
1870 !match(V: X, P: m_OneUse(
1871 SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcX))))),
1872 R: m_ConstantInt(V&: ShAmtX)))))
1873 return false;
1874
1875 Value *SrcY;
1876 uint64_t ShAmtY = 0;
1877 if (!match(V: Y, P: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY)))))) &&
1878 !match(V: Y, P: m_OneUse(
1879 SubPattern: m_Shl(L: m_OneUse(SubPattern: m_ZExt(Op: m_OneUse(SubPattern: m_BitCast(Op: m_Value(V&: SrcY))))),
1880 R: m_ConstantInt(V&: ShAmtY)))))
1881 return false;
1882
1883 // Canonicalize larger shift to the RHS.
1884 if (ShAmtX > ShAmtY) {
1885 std::swap(a&: X, b&: Y);
1886 std::swap(a&: SrcX, b&: SrcY);
1887 std::swap(a&: ShAmtX, b&: ShAmtY);
1888 }
1889
1890 // Ensure both sources are matching vXi1 bool mask types, and that the shift
1891 // difference is the mask width so they can be easily concatenated together.
1892 uint64_t ShAmtDiff = ShAmtY - ShAmtX;
1893 unsigned NumSHL = (ShAmtX > 0) + (ShAmtY > 0);
1894 unsigned BitWidth = Ty->getPrimitiveSizeInBits();
1895 auto *MaskTy = dyn_cast<FixedVectorType>(Val: SrcX->getType());
1896 if (!MaskTy || SrcX->getType() != SrcY->getType() ||
1897 !MaskTy->getElementType()->isIntegerTy(Bitwidth: 1) ||
1898 MaskTy->getNumElements() != ShAmtDiff ||
1899 MaskTy->getNumElements() > (BitWidth / 2))
1900 return false;
1901
1902 auto *ConcatTy = FixedVectorType::getDoubleElementsVectorType(VTy: MaskTy);
1903 auto *ConcatIntTy =
1904 Type::getIntNTy(C&: Ty->getContext(), N: ConcatTy->getNumElements());
1905 auto *MaskIntTy = Type::getIntNTy(C&: Ty->getContext(), N: ShAmtDiff);
1906
1907 SmallVector<int, 32> ConcatMask(ConcatTy->getNumElements());
1908 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
1909
  // TODO: Is it worth supporting multi-use cases?
1911 InstructionCost OldCost = 0;
1912 OldCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Or, Ty, CostKind);
1913 OldCost +=
1914 NumSHL * TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
1915 OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: MaskIntTy,
1916 CCH: TTI::CastContextHint::None, CostKind);
1917 OldCost += 2 * TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: MaskIntTy, Src: MaskTy,
1918 CCH: TTI::CastContextHint::None, CostKind);
1919
1920 InstructionCost NewCost = 0;
1921 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ConcatTy,
1922 SrcTy: MaskTy, Mask: ConcatMask, CostKind);
1923 NewCost += TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: ConcatIntTy, Src: ConcatTy,
1924 CCH: TTI::CastContextHint::None, CostKind);
1925 if (Ty != ConcatIntTy)
1926 NewCost += TTI.getCastInstrCost(Opcode: Instruction::ZExt, Dst: Ty, Src: ConcatIntTy,
1927 CCH: TTI::CastContextHint::None, CostKind);
1928 if (ShAmtX > 0)
1929 NewCost += TTI.getArithmeticInstrCost(Opcode: Instruction::Shl, Ty, CostKind);
1930
1931 LLVM_DEBUG(dbgs() << "Found a concatenation of bitcasted bool masks: " << I
1932 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
1933 << "\n");
1934
1935 if (NewCost > OldCost)
1936 return false;
1937
1938 // Build bool mask concatenation, bitcast back to scalar integer, and perform
1939 // any residual zero-extension or shifting.
1940 Value *Concat = Builder.CreateShuffleVector(V1: SrcX, V2: SrcY, Mask: ConcatMask);
1941 Worklist.pushValue(V: Concat);
1942
1943 Value *Result = Builder.CreateBitCast(V: Concat, DestTy: ConcatIntTy);
1944
1945 if (Ty != ConcatIntTy) {
1946 Worklist.pushValue(V: Result);
1947 Result = Builder.CreateZExt(V: Result, DestTy: Ty);
1948 }
1949
1950 if (ShAmtX > 0) {
1951 Worklist.pushValue(V: Result);
1952 Result = Builder.CreateShl(LHS: Result, RHS: ShAmtX);
1953 }
1954
1955 replaceValue(Old&: I, New&: *Result);
1956 return true;
1957}
1958
1959/// Try to convert "shuffle (binop (shuffle, shuffle)), undef"
1960/// --> "binop (shuffle), (shuffle)".
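/// For example (illustrative): a permute of "add (shuffle A,B,M0), (shuffle
/// C,D,M1)" becomes "add (shuffle A,B,M0'), (shuffle C,D,M1')", where each new
/// mask is the outer permute composed with the corresponding inner mask.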
1961bool VectorCombine::foldPermuteOfBinops(Instruction &I) {
1962 BinaryOperator *BinOp;
1963 ArrayRef<int> OuterMask;
1964 if (!match(V: &I,
1965 P: m_Shuffle(v1: m_OneUse(SubPattern: m_BinOp(I&: BinOp)), v2: m_Undef(), mask: m_Mask(OuterMask))))
1966 return false;
1967
1968 // Don't introduce poison into div/rem.
1969 if (BinOp->isIntDivRem() && llvm::is_contained(Range&: OuterMask, Element: PoisonMaskElem))
1970 return false;
1971
1972 Value *Op00, *Op01, *Op10, *Op11;
1973 ArrayRef<int> Mask0, Mask1;
1974 bool Match0 =
1975 match(V: BinOp->getOperand(i_nocapture: 0),
1976 P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: Op00), v2: m_Value(V&: Op01), mask: m_Mask(Mask0))));
1977 bool Match1 =
1978 match(V: BinOp->getOperand(i_nocapture: 1),
1979 P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: Op10), v2: m_Value(V&: Op11), mask: m_Mask(Mask1))));
1980 if (!Match0 && !Match1)
1981 return false;
1982
1983 Op00 = Match0 ? Op00 : BinOp->getOperand(i_nocapture: 0);
1984 Op01 = Match0 ? Op01 : BinOp->getOperand(i_nocapture: 0);
1985 Op10 = Match1 ? Op10 : BinOp->getOperand(i_nocapture: 1);
1986 Op11 = Match1 ? Op11 : BinOp->getOperand(i_nocapture: 1);
1987
1988 Instruction::BinaryOps Opcode = BinOp->getOpcode();
1989 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
1990 auto *BinOpTy = dyn_cast<FixedVectorType>(Val: BinOp->getType());
1991 auto *Op0Ty = dyn_cast<FixedVectorType>(Val: Op00->getType());
1992 auto *Op1Ty = dyn_cast<FixedVectorType>(Val: Op10->getType());
1993 if (!ShuffleDstTy || !BinOpTy || !Op0Ty || !Op1Ty)
1994 return false;
1995
1996 unsigned NumSrcElts = BinOpTy->getNumElements();
1997
  // Don't accept shuffles that reference the second operand in
  // div/rem or if it's an undef arg.
2000 if ((BinOp->isIntDivRem() || !isa<PoisonValue>(Val: I.getOperand(i: 1))) &&
2001 any_of(Range&: OuterMask, P: [NumSrcElts](int M) { return M >= (int)NumSrcElts; }))
2002 return false;
2003
2004 // Merge outer / inner (or identity if no match) shuffles.
2005 SmallVector<int> NewMask0, NewMask1;
2006 for (int M : OuterMask) {
2007 if (M < 0 || M >= (int)NumSrcElts) {
2008 NewMask0.push_back(Elt: PoisonMaskElem);
2009 NewMask1.push_back(Elt: PoisonMaskElem);
2010 } else {
2011 NewMask0.push_back(Elt: Match0 ? Mask0[M] : M);
2012 NewMask1.push_back(Elt: Match1 ? Mask1[M] : M);
2013 }
2014 }
2015
2016 unsigned NumOpElts = Op0Ty->getNumElements();
2017 bool IsIdentity0 = ShuffleDstTy == Op0Ty &&
2018 all_of(Range&: NewMask0, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2019 ShuffleVectorInst::isIdentityMask(Mask: NewMask0, NumSrcElts: NumOpElts);
2020 bool IsIdentity1 = ShuffleDstTy == Op1Ty &&
2021 all_of(Range&: NewMask1, P: [NumOpElts](int M) { return M < (int)NumOpElts; }) &&
2022 ShuffleVectorInst::isIdentityMask(Mask: NewMask1, NumSrcElts: NumOpElts);
2023
2024 // Try to merge shuffles across the binop if the new shuffles are not costly.
2025 InstructionCost OldCost =
2026 TTI.getArithmeticInstrCost(Opcode, Ty: BinOpTy, CostKind) +
2027 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc, DstTy: ShuffleDstTy,
2028 SrcTy: BinOpTy, Mask: OuterMask, CostKind, Index: 0, SubTp: nullptr, Args: {BinOp}, CxtI: &I);
2029 if (Match0)
2030 OldCost += TTI.getShuffleCost(
2031 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op0Ty, Mask: Mask0, CostKind,
2032 Index: 0, SubTp: nullptr, Args: {Op00, Op01}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 0)));
2033 if (Match1)
2034 OldCost += TTI.getShuffleCost(
2035 Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: BinOpTy, SrcTy: Op1Ty, Mask: Mask1, CostKind,
2036 Index: 0, SubTp: nullptr, Args: {Op10, Op11}, CxtI: cast<Instruction>(Val: BinOp->getOperand(i_nocapture: 1)));
2037
2038 InstructionCost NewCost =
2039 TTI.getArithmeticInstrCost(Opcode, Ty: ShuffleDstTy, CostKind);
2040
2041 if (!IsIdentity0)
2042 NewCost +=
2043 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2044 SrcTy: Op0Ty, Mask: NewMask0, CostKind, Index: 0, SubTp: nullptr, Args: {Op00, Op01});
2045 if (!IsIdentity1)
2046 NewCost +=
2047 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2048 SrcTy: Op1Ty, Mask: NewMask1, CostKind, Index: 0, SubTp: nullptr, Args: {Op10, Op11});
2049
2050 LLVM_DEBUG(dbgs() << "Found a shuffle feeding a shuffled binop: " << I
2051 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2052 << "\n");
2053
2054 // If costs are equal, still fold as we reduce instruction count.
2055 if (NewCost > OldCost)
2056 return false;
2057
2058 Value *LHS =
2059 IsIdentity0 ? Op00 : Builder.CreateShuffleVector(V1: Op00, V2: Op01, Mask: NewMask0);
2060 Value *RHS =
2061 IsIdentity1 ? Op10 : Builder.CreateShuffleVector(V1: Op10, V2: Op11, Mask: NewMask1);
2062 Value *NewBO = Builder.CreateBinOp(Opc: Opcode, LHS, RHS);
2063
  // Copy flags from the original binop.
2065 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO))
2066 NewInst->copyIRFlags(V: BinOp);
2067
2068 Worklist.pushValue(V: LHS);
2069 Worklist.pushValue(V: RHS);
2070 replaceValue(Old&: I, New&: *NewBO);
2071 return true;
2072}
2073
2074/// Try to convert "shuffle (binop), (binop)" into "binop (shuffle), (shuffle)".
2075/// Try to convert "shuffle (cmpop), (cmpop)" into "cmpop (shuffle), (shuffle)".
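/// For example (illustrative): "shuffle (add X, Y), (add Z, W), M" becomes
/// "add (shuffle X, Z, M), (shuffle Y, W, M)" when the cost model favours the
/// two shuffles plus one binop over the original form.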
2076bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
2077 ArrayRef<int> OldMask;
2078 Instruction *LHS, *RHS;
2079 if (!match(V: &I, P: m_Shuffle(v1: m_OneUse(SubPattern: m_Instruction(I&: LHS)),
2080 v2: m_OneUse(SubPattern: m_Instruction(I&: RHS)), mask: m_Mask(OldMask))))
2081 return false;
2082
2083 // TODO: Add support for addlike etc.
2084 if (LHS->getOpcode() != RHS->getOpcode())
2085 return false;
2086
2087 Value *X, *Y, *Z, *W;
2088 bool IsCommutative = false;
2089 CmpPredicate PredLHS = CmpInst::BAD_ICMP_PREDICATE;
2090 CmpPredicate PredRHS = CmpInst::BAD_ICMP_PREDICATE;
2091 if (match(V: LHS, P: m_BinOp(L: m_Value(V&: X), R: m_Value(V&: Y))) &&
2092 match(V: RHS, P: m_BinOp(L: m_Value(V&: Z), R: m_Value(V&: W)))) {
2093 auto *BO = cast<BinaryOperator>(Val: LHS);
2094 // Don't introduce poison into div/rem.
2095 if (llvm::is_contained(Range&: OldMask, Element: PoisonMaskElem) && BO->isIntDivRem())
2096 return false;
2097 IsCommutative = BinaryOperator::isCommutative(Opcode: BO->getOpcode());
2098 } else if (match(V: LHS, P: m_Cmp(Pred&: PredLHS, L: m_Value(V&: X), R: m_Value(V&: Y))) &&
2099 match(V: RHS, P: m_Cmp(Pred&: PredRHS, L: m_Value(V&: Z), R: m_Value(V&: W))) &&
2100 (CmpInst::Predicate)PredLHS == (CmpInst::Predicate)PredRHS) {
2101 IsCommutative = cast<CmpInst>(Val: LHS)->isCommutative();
2102 } else
2103 return false;
2104
2105 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2106 auto *BinResTy = dyn_cast<FixedVectorType>(Val: LHS->getType());
2107 auto *BinOpTy = dyn_cast<FixedVectorType>(Val: X->getType());
2108 if (!ShuffleDstTy || !BinResTy || !BinOpTy || X->getType() != Z->getType())
2109 return false;
2110
2111 unsigned NumSrcElts = BinOpTy->getNumElements();
2112
2113 // If we have something like "add X, Y" and "add Z, X", swap ops to match.
2114 if (IsCommutative && X != Z && Y != W && (X == W || Y == Z))
2115 std::swap(a&: X, b&: Y);
2116
2117 auto ConvertToUnary = [NumSrcElts](int &M) {
2118 if (M >= (int)NumSrcElts)
2119 M -= NumSrcElts;
2120 };
2121
2122 SmallVector<int> NewMask0(OldMask);
2123 TargetTransformInfo::ShuffleKind SK0 = TargetTransformInfo::SK_PermuteTwoSrc;
2124 if (X == Z) {
2125 llvm::for_each(Range&: NewMask0, F: ConvertToUnary);
2126 SK0 = TargetTransformInfo::SK_PermuteSingleSrc;
2127 Z = PoisonValue::get(T: BinOpTy);
2128 }
2129
2130 SmallVector<int> NewMask1(OldMask);
2131 TargetTransformInfo::ShuffleKind SK1 = TargetTransformInfo::SK_PermuteTwoSrc;
2132 if (Y == W) {
2133 llvm::for_each(Range&: NewMask1, F: ConvertToUnary);
2134 SK1 = TargetTransformInfo::SK_PermuteSingleSrc;
2135 W = PoisonValue::get(T: BinOpTy);
2136 }
2137
2138 // Try to replace a binop with a shuffle if the shuffle is not costly.
2139 InstructionCost OldCost =
2140 TTI.getInstructionCost(U: LHS, CostKind) +
2141 TTI.getInstructionCost(U: RHS, CostKind) +
2142 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2143 SrcTy: BinResTy, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {LHS, RHS},
2144 CxtI: &I);
2145
  // Handle shuffle(binop(shuffle(x),y),binop(z,shuffle(w))) style patterns
  // where one-use shuffles have gotten split across the binop/cmp. These
  // often allow a major reduction in total cost that wouldn't happen as
  // individual folds.
2150 auto MergeInner = [&](Value *&Op, int Offset, MutableArrayRef<int> Mask,
2151 TTI::TargetCostKind CostKind) -> bool {
2152 Value *InnerOp;
2153 ArrayRef<int> InnerMask;
2154 if (match(V: Op, P: m_OneUse(SubPattern: m_Shuffle(v1: m_Value(V&: InnerOp), v2: m_Undef(),
2155 mask: m_Mask(InnerMask)))) &&
2156 InnerOp->getType() == Op->getType() &&
2157 all_of(Range&: InnerMask,
2158 P: [NumSrcElts](int M) { return M < (int)NumSrcElts; })) {
2159 for (int &M : Mask)
2160 if (Offset <= M && M < (int)(Offset + NumSrcElts)) {
2161 M = InnerMask[M - Offset];
2162 M = 0 <= M ? M + Offset : M;
2163 }
2164 OldCost += TTI.getInstructionCost(U: cast<Instruction>(Val: Op), CostKind);
2165 Op = InnerOp;
2166 return true;
2167 }
2168 return false;
2169 };
2170 bool ReducedInstCount = false;
2171 ReducedInstCount |= MergeInner(X, 0, NewMask0, CostKind);
2172 ReducedInstCount |= MergeInner(Y, 0, NewMask1, CostKind);
2173 ReducedInstCount |= MergeInner(Z, NumSrcElts, NewMask0, CostKind);
2174 ReducedInstCount |= MergeInner(W, NumSrcElts, NewMask1, CostKind);
2175
2176 auto *ShuffleCmpTy =
2177 FixedVectorType::get(ElementType: BinOpTy->getElementType(), FVTy: ShuffleDstTy);
2178 InstructionCost NewCost =
2179 TTI.getShuffleCost(Kind: SK0, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask0, CostKind, Index: 0,
2180 SubTp: nullptr, Args: {X, Z}) +
2181 TTI.getShuffleCost(Kind: SK1, DstTy: ShuffleCmpTy, SrcTy: BinOpTy, Mask: NewMask1, CostKind, Index: 0,
2182 SubTp: nullptr, Args: {Y, W});
2183
2184 if (PredLHS == CmpInst::BAD_ICMP_PREDICATE) {
2185 NewCost +=
2186 TTI.getArithmeticInstrCost(Opcode: LHS->getOpcode(), Ty: ShuffleDstTy, CostKind);
2187 } else {
2188 NewCost += TTI.getCmpSelInstrCost(Opcode: LHS->getOpcode(), ValTy: ShuffleCmpTy,
2189 CondTy: ShuffleDstTy, VecPred: PredLHS, CostKind);
2190 }
2191
2192 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two binops: " << I
2193 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2194 << "\n");
2195
2196 // If either shuffle will constant fold away, then fold for the same cost as
2197 // we will reduce the instruction count.
2198 ReducedInstCount |= (isa<Constant>(Val: X) && isa<Constant>(Val: Z)) ||
2199 (isa<Constant>(Val: Y) && isa<Constant>(Val: W));
2200 if (ReducedInstCount ? (NewCost > OldCost) : (NewCost >= OldCost))
2201 return false;
2202
2203 Value *Shuf0 = Builder.CreateShuffleVector(V1: X, V2: Z, Mask: NewMask0);
2204 Value *Shuf1 = Builder.CreateShuffleVector(V1: Y, V2: W, Mask: NewMask1);
2205 Value *NewBO = PredLHS == CmpInst::BAD_ICMP_PREDICATE
2206 ? Builder.CreateBinOp(
2207 Opc: cast<BinaryOperator>(Val: LHS)->getOpcode(), LHS: Shuf0, RHS: Shuf1)
2208 : Builder.CreateCmp(Pred: PredLHS, LHS: Shuf0, RHS: Shuf1);
2209
2210 // Intersect flags from the old binops.
2211 if (auto *NewInst = dyn_cast<Instruction>(Val: NewBO)) {
2212 NewInst->copyIRFlags(V: LHS);
2213 NewInst->andIRFlags(V: RHS);
2214 }
2215
2216 Worklist.pushValue(V: Shuf0);
2217 Worklist.pushValue(V: Shuf1);
2218 replaceValue(Old&: I, New&: *NewBO);
2219 return true;
2220}
2221
/// Try to convert
///   "shuffle (select c1,t1,f1), (select c2,t2,f2), m"
/// into
///   "select (shuffle c1,c2,m), (shuffle t1,t2,m), (shuffle f1,f2,m)".
2225bool VectorCombine::foldShuffleOfSelects(Instruction &I) {
2226 ArrayRef<int> Mask;
2227 Value *C1, *T1, *F1, *C2, *T2, *F2;
2228 if (!match(V: &I, P: m_Shuffle(
2229 v1: m_OneUse(SubPattern: m_Select(C: m_Value(V&: C1), L: m_Value(V&: T1), R: m_Value(V&: F1))),
2230 v2: m_OneUse(SubPattern: m_Select(C: m_Value(V&: C2), L: m_Value(V&: T2), R: m_Value(V&: F2))),
2231 mask: m_Mask(Mask))))
2232 return false;
2233
2234 auto *C1VecTy = dyn_cast<FixedVectorType>(Val: C1->getType());
2235 auto *C2VecTy = dyn_cast<FixedVectorType>(Val: C2->getType());
2236 if (!C1VecTy || !C2VecTy || C1VecTy != C2VecTy)
2237 return false;
2238
2239 auto *SI0FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 0));
2240 auto *SI1FOp = dyn_cast<FPMathOperator>(Val: I.getOperand(i: 1));
2241 // SelectInsts must have the same FMF.
2242 if (((SI0FOp == nullptr) != (SI1FOp == nullptr)) ||
2243 ((SI0FOp != nullptr) &&
2244 (SI0FOp->getFastMathFlags() != SI1FOp->getFastMathFlags())))
2245 return false;
2246
2247 auto *SrcVecTy = cast<FixedVectorType>(Val: T1->getType());
2248 auto *DstVecTy = cast<FixedVectorType>(Val: I.getType());
2249 auto SK = TargetTransformInfo::SK_PermuteTwoSrc;
2250 auto SelOp = Instruction::Select;
2251 InstructionCost OldCost = TTI.getCmpSelInstrCost(
2252 Opcode: SelOp, ValTy: SrcVecTy, CondTy: C1VecTy, VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2253 OldCost += TTI.getCmpSelInstrCost(Opcode: SelOp, ValTy: SrcVecTy, CondTy: C2VecTy,
2254 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2255 OldCost +=
2256 TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0, SubTp: nullptr,
2257 Args: {I.getOperand(i: 0), I.getOperand(i: 1)}, CxtI: &I);
2258
2259 InstructionCost NewCost = TTI.getShuffleCost(
2260 Kind: SK, DstTy: FixedVectorType::get(ElementType: C1VecTy->getScalarType(), NumElts: Mask.size()), SrcTy: C1VecTy,
2261 Mask, CostKind, Index: 0, SubTp: nullptr, Args: {C1, C2});
2262 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2263 SubTp: nullptr, Args: {T1, T2});
2264 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: SrcVecTy, Mask, CostKind, Index: 0,
2265 SubTp: nullptr, Args: {F1, F2});
2266 auto *C1C2ShuffledVecTy = cast<FixedVectorType>(
2267 Val: toVectorTy(Scalar: Type::getInt1Ty(C&: I.getContext()), VF: DstVecTy->getNumElements()));
2268 NewCost += TTI.getCmpSelInstrCost(Opcode: SelOp, ValTy: DstVecTy, CondTy: C1C2ShuffledVecTy,
2269 VecPred: CmpInst::BAD_ICMP_PREDICATE, CostKind);
2270
2271 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two selects: " << I
2272 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2273 << "\n");
2274 if (NewCost > OldCost)
2275 return false;
2276
2277 Value *ShuffleCmp = Builder.CreateShuffleVector(V1: C1, V2: C2, Mask);
2278 Value *ShuffleTrue = Builder.CreateShuffleVector(V1: T1, V2: T2, Mask);
2279 Value *ShuffleFalse = Builder.CreateShuffleVector(V1: F1, V2: F2, Mask);
2280 Value *NewSel;
2281 // We presuppose that the SelectInsts have the same FMF.
2282 if (SI0FOp)
2283 NewSel = Builder.CreateSelectFMF(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse,
2284 FMFSource: SI0FOp->getFastMathFlags());
2285 else
2286 NewSel = Builder.CreateSelect(C: ShuffleCmp, True: ShuffleTrue, False: ShuffleFalse);
2287
2288 Worklist.pushValue(V: ShuffleCmp);
2289 Worklist.pushValue(V: ShuffleTrue);
2290 Worklist.pushValue(V: ShuffleFalse);
2291 replaceValue(Old&: I, New&: *NewSel);
2292 return true;
2293}
2294
/// Try to convert "shuffle (castop), (castop)" where both casts share the same
/// opcode and source type into "castop (shuffle)".
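/// For example (illustrative):
///   shuffle (zext <4 x i16> %x to <4 x i32>), (zext <4 x i16> %y to <4 x i32>), M
/// becomes
///   zext (shuffle <4 x i16> %x, <4 x i16> %y, M) to the wide result type.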
2297bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
2298 Value *V0, *V1;
2299 ArrayRef<int> OldMask;
2300 if (!match(V: &I, P: m_Shuffle(v1: m_Value(V&: V0), v2: m_Value(V&: V1), mask: m_Mask(OldMask))))
2301 return false;
2302
2303 auto *C0 = dyn_cast<CastInst>(Val: V0);
2304 auto *C1 = dyn_cast<CastInst>(Val: V1);
2305 if (!C0 || !C1)
2306 return false;
2307
2308 Instruction::CastOps Opcode = C0->getOpcode();
2309 if (C0->getSrcTy() != C1->getSrcTy())
2310 return false;
2311
2312 // Handle shuffle(zext_nneg(x), sext(y)) -> sext(shuffle(x,y)) folds.
2313 if (Opcode != C1->getOpcode()) {
2314 if (match(V: C0, P: m_SExtLike(Op: m_Value())) && match(V: C1, P: m_SExtLike(Op: m_Value())))
2315 Opcode = Instruction::SExt;
2316 else
2317 return false;
2318 }
2319
2320 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2321 auto *CastDstTy = dyn_cast<FixedVectorType>(Val: C0->getDestTy());
2322 auto *CastSrcTy = dyn_cast<FixedVectorType>(Val: C0->getSrcTy());
2323 if (!ShuffleDstTy || !CastDstTy || !CastSrcTy)
2324 return false;
2325
2326 unsigned NumSrcElts = CastSrcTy->getNumElements();
2327 unsigned NumDstElts = CastDstTy->getNumElements();
2328 assert((NumDstElts == NumSrcElts || Opcode == Instruction::BitCast) &&
2329 "Only bitcasts expected to alter src/dst element counts");
2330
  // Bail on bitcasts where neither element count evenly divides the other,
  // e.g. <32 x i40> -> <40 x i32>.
2333 if (NumDstElts != NumSrcElts && (NumSrcElts % NumDstElts) != 0 &&
2334 (NumDstElts % NumSrcElts) != 0)
2335 return false;
2336
2337 SmallVector<int, 16> NewMask;
2338 if (NumSrcElts >= NumDstElts) {
2339 // The bitcast is from wide to narrow/equal elements. The shuffle mask can
2340 // always be expanded to the equivalent form choosing narrower elements.
2341 assert(NumSrcElts % NumDstElts == 0 && "Unexpected shuffle mask");
2342 unsigned ScaleFactor = NumSrcElts / NumDstElts;
2343 narrowShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask);
2344 } else {
2345 // The bitcast is from narrow elements to wide elements. The shuffle mask
2346 // must choose consecutive elements to allow casting first.
2347 assert(NumDstElts % NumSrcElts == 0 && "Unexpected shuffle mask");
2348 unsigned ScaleFactor = NumDstElts / NumSrcElts;
2349 if (!widenShuffleMaskElts(Scale: ScaleFactor, Mask: OldMask, ScaledMask&: NewMask))
2350 return false;
2351 }
2352
2353 auto *NewShuffleDstTy =
2354 FixedVectorType::get(ElementType: CastSrcTy->getScalarType(), NumElts: NewMask.size());
2355
2356 // Try to replace a castop with a shuffle if the shuffle is not costly.
2357 InstructionCost CostC0 =
2358 TTI.getCastInstrCost(Opcode: C0->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
2359 CCH: TTI::CastContextHint::None, CostKind);
2360 InstructionCost CostC1 =
2361 TTI.getCastInstrCost(Opcode: C1->getOpcode(), Dst: CastDstTy, Src: CastSrcTy,
2362 CCH: TTI::CastContextHint::None, CostKind);
2363 InstructionCost OldCost = CostC0 + CostC1;
2364 OldCost +=
2365 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2366 SrcTy: CastDstTy, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {}, CxtI: &I);
2367
2368 InstructionCost NewCost =
2369 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: NewShuffleDstTy,
2370 SrcTy: CastSrcTy, Mask: NewMask, CostKind);
2371 NewCost += TTI.getCastInstrCost(Opcode, Dst: ShuffleDstTy, Src: NewShuffleDstTy,
2372 CCH: TTI::CastContextHint::None, CostKind);
2373 if (!C0->hasOneUse())
2374 NewCost += CostC0;
2375 if (!C1->hasOneUse())
2376 NewCost += CostC1;
2377
2378 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two casts: " << I
2379 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2380 << "\n");
2381 if (NewCost > OldCost)
2382 return false;
2383
2384 Value *Shuf = Builder.CreateShuffleVector(V1: C0->getOperand(i_nocapture: 0),
2385 V2: C1->getOperand(i_nocapture: 0), Mask: NewMask);
2386 Value *Cast = Builder.CreateCast(Op: Opcode, V: Shuf, DestTy: ShuffleDstTy);
2387
2388 // Intersect flags from the old casts.
2389 if (auto *NewInst = dyn_cast<Instruction>(Val: Cast)) {
2390 NewInst->copyIRFlags(V: C0);
2391 NewInst->andIRFlags(V: C1);
2392 }
2393
2394 Worklist.pushValue(V: Shuf);
2395 replaceValue(Old&: I, New&: *Cast);
2396 return true;
2397}
2398
2399/// Try to convert any of:
2400/// "shuffle (shuffle x, y), (shuffle y, x)"
2401/// "shuffle (shuffle x, undef), (shuffle y, undef)"
2402/// "shuffle (shuffle x, undef), y"
2403/// "shuffle x, (shuffle y, undef)"
2404/// into "shuffle x, y".
2405bool VectorCombine::foldShuffleOfShuffles(Instruction &I) {
2406 ArrayRef<int> OuterMask;
2407 Value *OuterV0, *OuterV1;
2408 if (!match(V: &I,
2409 P: m_Shuffle(v1: m_Value(V&: OuterV0), v2: m_Value(V&: OuterV1), mask: m_Mask(OuterMask))))
2410 return false;
2411
2412 ArrayRef<int> InnerMask0, InnerMask1;
2413 Value *X0, *X1, *Y0, *Y1;
2414 bool Match0 =
2415 match(V: OuterV0, P: m_Shuffle(v1: m_Value(V&: X0), v2: m_Value(V&: Y0), mask: m_Mask(InnerMask0)));
2416 bool Match1 =
2417 match(V: OuterV1, P: m_Shuffle(v1: m_Value(V&: X1), v2: m_Value(V&: Y1), mask: m_Mask(InnerMask1)));
2418 if (!Match0 && !Match1)
2419 return false;
2420
2421 // If the outer shuffle is a permute, then create a fake inner all-poison
2422 // shuffle. This is easier than accounting for length-changing shuffles below.
2423 SmallVector<int, 16> PoisonMask1;
2424 if (!Match1 && isa<PoisonValue>(Val: OuterV1)) {
2425 X1 = X0;
2426 Y1 = Y0;
2427 PoisonMask1.append(NumInputs: InnerMask0.size(), Elt: PoisonMaskElem);
2428 InnerMask1 = PoisonMask1;
2429 Match1 = true; // fake match
2430 }
2431
2432 X0 = Match0 ? X0 : OuterV0;
2433 Y0 = Match0 ? Y0 : OuterV0;
2434 X1 = Match1 ? X1 : OuterV1;
2435 Y1 = Match1 ? Y1 : OuterV1;
2436 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2437 auto *ShuffleSrcTy = dyn_cast<FixedVectorType>(Val: X0->getType());
2438 auto *ShuffleImmTy = dyn_cast<FixedVectorType>(Val: OuterV0->getType());
2439 if (!ShuffleDstTy || !ShuffleSrcTy || !ShuffleImmTy ||
2440 X0->getType() != X1->getType())
2441 return false;
2442
2443 unsigned NumSrcElts = ShuffleSrcTy->getNumElements();
2444 unsigned NumImmElts = ShuffleImmTy->getNumElements();
2445
  // Attempt to merge shuffles, matching up to 2 source operands.
  // Replace any index into a poison arg with PoisonMaskElem.
  // Bail if either inner mask references an undef arg.
2449 SmallVector<int, 16> NewMask(OuterMask);
2450 Value *NewX = nullptr, *NewY = nullptr;
2451 for (int &M : NewMask) {
2452 Value *Src = nullptr;
2453 if (0 <= M && M < (int)NumImmElts) {
2454 Src = OuterV0;
2455 if (Match0) {
2456 M = InnerMask0[M];
2457 Src = M >= (int)NumSrcElts ? Y0 : X0;
2458 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2459 }
2460 } else if (M >= (int)NumImmElts) {
2461 Src = OuterV1;
2462 M -= NumImmElts;
2463 if (Match1) {
2464 M = InnerMask1[M];
2465 Src = M >= (int)NumSrcElts ? Y1 : X1;
2466 M = M >= (int)NumSrcElts ? (M - NumSrcElts) : M;
2467 }
2468 }
2469 if (Src && M != PoisonMaskElem) {
2470 assert(0 <= M && M < (int)NumSrcElts && "Unexpected shuffle mask index");
2471 if (isa<UndefValue>(Val: Src)) {
        // We've referenced an undef element: if it's poison, update the shuffle
        // mask, else bail.
2474 if (!isa<PoisonValue>(Val: Src))
2475 return false;
2476 M = PoisonMaskElem;
2477 continue;
2478 }
2479 if (!NewX || NewX == Src) {
2480 NewX = Src;
2481 continue;
2482 }
2483 if (!NewY || NewY == Src) {
2484 M += NumSrcElts;
2485 NewY = Src;
2486 continue;
2487 }
2488 return false;
2489 }
2490 }
2491
  // If every lane of the new mask is poison, the shuffle folds to poison.
  if (!NewX) {
    replaceValue(Old&: I, New&: *PoisonValue::get(T: ShuffleDstTy));
    return true;
  }
2494 if (!NewY)
2495 NewY = PoisonValue::get(T: ShuffleSrcTy);
2496
2497 // Have we folded to an Identity shuffle?
2498 if (ShuffleVectorInst::isIdentityMask(Mask: NewMask, NumSrcElts)) {
2499 replaceValue(Old&: I, New&: *NewX);
2500 return true;
2501 }
2502
2503 // Try to merge the shuffles if the new shuffle is not costly.
2504 InstructionCost InnerCost0 = 0;
2505 if (Match0)
2506 InnerCost0 = TTI.getInstructionCost(U: cast<User>(Val: OuterV0), CostKind);
2507
2508 InstructionCost InnerCost1 = 0;
2509 if (Match1)
2510 InnerCost1 = TTI.getInstructionCost(U: cast<User>(Val: OuterV1), CostKind);
2511
2512 InstructionCost OuterCost = TTI.getInstructionCost(U: &I, CostKind);
2513
2514 InstructionCost OldCost = InnerCost0 + InnerCost1 + OuterCost;
2515
2516 bool IsUnary = all_of(Range&: NewMask, P: [&](int M) { return M < (int)NumSrcElts; });
2517 TargetTransformInfo::ShuffleKind SK =
2518 IsUnary ? TargetTransformInfo::SK_PermuteSingleSrc
2519 : TargetTransformInfo::SK_PermuteTwoSrc;
2520 InstructionCost NewCost =
2521 TTI.getShuffleCost(Kind: SK, DstTy: ShuffleDstTy, SrcTy: ShuffleSrcTy, Mask: NewMask, CostKind, Index: 0,
2522 SubTp: nullptr, Args: {NewX, NewY});
2523 if (!OuterV0->hasOneUse())
2524 NewCost += InnerCost0;
2525 if (!OuterV1->hasOneUse())
2526 NewCost += InnerCost1;
2527
2528 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two shuffles: " << I
2529 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2530 << "\n");
2531 if (NewCost > OldCost)
2532 return false;
2533
2534 Value *Shuf = Builder.CreateShuffleVector(V1: NewX, V2: NewY, Mask: NewMask);
2535 replaceValue(Old&: I, New&: *Shuf);
2536 return true;
2537}
2538
2539/// Try to convert
2540/// "shuffle (intrinsic), (intrinsic)" into "intrinsic (shuffle), (shuffle)".
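/// For example (illustrative): "shuffle (llvm.smax A, B), (llvm.smax C, D), M"
/// becomes "llvm.smax (shuffle A, C, M), (shuffle B, D, M)".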
2541bool VectorCombine::foldShuffleOfIntrinsics(Instruction &I) {
2542 Value *V0, *V1;
2543 ArrayRef<int> OldMask;
2544 if (!match(V: &I, P: m_Shuffle(v1: m_OneUse(SubPattern: m_Value(V&: V0)), v2: m_OneUse(SubPattern: m_Value(V&: V1)),
2545 mask: m_Mask(OldMask))))
2546 return false;
2547
2548 auto *II0 = dyn_cast<IntrinsicInst>(Val: V0);
2549 auto *II1 = dyn_cast<IntrinsicInst>(Val: V1);
2550 if (!II0 || !II1)
2551 return false;
2552
2553 Intrinsic::ID IID = II0->getIntrinsicID();
2554 if (IID != II1->getIntrinsicID())
2555 return false;
2556
2557 auto *ShuffleDstTy = dyn_cast<FixedVectorType>(Val: I.getType());
2558 auto *II0Ty = dyn_cast<FixedVectorType>(Val: II0->getType());
2559 if (!ShuffleDstTy || !II0Ty)
2560 return false;
2561
2562 if (!isTriviallyVectorizable(ID: IID))
2563 return false;
2564
2565 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2566 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI) &&
2567 II0->getArgOperand(i: I) != II1->getArgOperand(i: I))
2568 return false;
2569
2570 InstructionCost OldCost =
2571 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II0), CostKind) +
2572 TTI.getIntrinsicInstrCost(ICA: IntrinsicCostAttributes(IID, *II1), CostKind) +
2573 TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc, DstTy: ShuffleDstTy,
2574 SrcTy: II0Ty, Mask: OldMask, CostKind, Index: 0, SubTp: nullptr, Args: {II0, II1}, CxtI: &I);
2575
2576 SmallVector<Type *> NewArgsTy;
2577 InstructionCost NewCost = 0;
2578 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I) {
2579 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
2580 NewArgsTy.push_back(Elt: II0->getArgOperand(i: I)->getType());
2581 } else {
2582 auto *VecTy = cast<FixedVectorType>(Val: II0->getArgOperand(i: I)->getType());
2583 auto *ArgTy = FixedVectorType::get(ElementType: VecTy->getElementType(),
2584 NumElts: ShuffleDstTy->getNumElements());
2585 NewArgsTy.push_back(Elt: ArgTy);
2586 NewCost += TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteTwoSrc,
2587 DstTy: ArgTy, SrcTy: VecTy, Mask: OldMask, CostKind);
2588 }
2589 }
2590 IntrinsicCostAttributes NewAttr(IID, ShuffleDstTy, NewArgsTy);
2591 NewCost += TTI.getIntrinsicInstrCost(ICA: NewAttr, CostKind);
2592
2593 LLVM_DEBUG(dbgs() << "Found a shuffle feeding two intrinsics: " << I
2594 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
2595 << "\n");
2596
2597 if (NewCost > OldCost)
2598 return false;
2599
2600 SmallVector<Value *> NewArgs;
2601 for (unsigned I = 0, E = II0->arg_size(); I != E; ++I)
2602 if (isVectorIntrinsicWithScalarOpAtArg(ID: IID, ScalarOpdIdx: I, TTI: &TTI)) {
2603 NewArgs.push_back(Elt: II0->getArgOperand(i: I));
2604 } else {
2605 Value *Shuf = Builder.CreateShuffleVector(V1: II0->getArgOperand(i: I),
2606 V2: II1->getArgOperand(i: I), Mask: OldMask);
2607 NewArgs.push_back(Elt: Shuf);
2608 Worklist.pushValue(V: Shuf);
2609 }
2610 Value *NewIntrinsic = Builder.CreateIntrinsic(RetTy: ShuffleDstTy, ID: IID, Args: NewArgs);
2611
2612 // Intersect flags from the old intrinsics.
2613 if (auto *NewInst = dyn_cast<Instruction>(Val: NewIntrinsic)) {
2614 NewInst->copyIRFlags(V: II0);
2615 NewInst->andIRFlags(V: II1);
2616 }
2617
2618 replaceValue(Old&: I, New&: *NewIntrinsic);
2619 return true;
2620}
2621
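// An InstLane identifies a single lane of a value: the Use holding the value
// and the lane index within it (PoisonMaskElem when the lane is poison).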
2622using InstLane = std::pair<Use *, int>;
2623
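// Walk up through a chain of shufflevectors, translating \p Lane through each
// mask, until a non-shuffle value is reached. Returns the use of that value
// together with the lane it supplies, or a null use with PoisonMaskElem if the
// lane is poison somewhere along the chain.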
2624static InstLane lookThroughShuffles(Use *U, int Lane) {
2625 while (auto *SV = dyn_cast<ShuffleVectorInst>(Val: U->get())) {
2626 unsigned NumElts =
2627 cast<FixedVectorType>(Val: SV->getOperand(i_nocapture: 0)->getType())->getNumElements();
2628 int M = SV->getMaskValue(Elt: Lane);
2629 if (M < 0)
2630 return {nullptr, PoisonMaskElem};
2631 if (static_cast<unsigned>(M) < NumElts) {
2632 U = &SV->getOperandUse(i: 0);
2633 Lane = M;
2634 } else {
2635 U = &SV->getOperandUse(i: 1);
2636 Lane = M - NumElts;
2637 }
2638 }
2639 return InstLane{U, Lane};
2640}
2641
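// Given the per-lane view in \p Item, produce the equivalent per-lane view of
// operand \p Op of each lane's instruction, looking through any shuffles on
// that operand. Poison lanes stay poison.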
2642static SmallVector<InstLane>
2643generateInstLaneVectorFromOperand(ArrayRef<InstLane> Item, int Op) {
2644 SmallVector<InstLane> NItem;
2645 for (InstLane IL : Item) {
2646 auto [U, Lane] = IL;
2647 InstLane OpLane =
2648 U ? lookThroughShuffles(U: &cast<Instruction>(Val: U->get())->getOperandUse(i: Op),
2649 Lane)
2650 : InstLane{nullptr, PoisonMaskElem};
2651 NItem.emplace_back(Args&: OpLane);
2652 }
2653 return NItem;
2654}
2655
/// Detect a concat of multiple values into a single vector.
2657static bool isFreeConcat(ArrayRef<InstLane> Item, TTI::TargetCostKind CostKind,
2658 const TargetTransformInfo &TTI) {
2659 auto *Ty = cast<FixedVectorType>(Val: Item.front().first->get()->getType());
2660 unsigned NumElts = Ty->getNumElements();
2661 if (Item.size() == NumElts || NumElts == 1 || Item.size() % NumElts != 0)
2662 return false;
2663
2664 // Check that the concat is free, usually meaning that the type will be split
2665 // during legalization.
2666 SmallVector<int, 16> ConcatMask(NumElts * 2);
2667 std::iota(first: ConcatMask.begin(), last: ConcatMask.end(), value: 0);
2668 if (TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc,
2669 DstTy: FixedVectorType::get(ElementType: Ty->getScalarType(), NumElts: NumElts * 2),
2670 SrcTy: Ty, Mask: ConcatMask, CostKind) != 0)
2671 return false;
2672
2673 unsigned NumSlices = Item.size() / NumElts;
  // Currently we generate a tree of shuffles for the concats, which limits us
  // to a power of 2.
2676 if (!isPowerOf2_32(Value: NumSlices))
2677 return false;
2678 for (unsigned Slice = 0; Slice < NumSlices; ++Slice) {
2679 Use *SliceV = Item[Slice * NumElts].first;
2680 if (!SliceV || SliceV->get()->getType() != Ty)
2681 return false;
2682 for (unsigned Elt = 0; Elt < NumElts; ++Elt) {
2683 auto [V, Lane] = Item[Slice * NumElts + Elt];
2684 if (Lane != static_cast<int>(Elt) || SliceV->get() != V->get())
2685 return false;
2686 }
2687 }
2688 return true;
2689}
2690
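// Rebuild the value described by \p Item without the interleaving shuffles:
// identity leaves are used directly, splat leaves become a single splat
// shuffle, concat leaves are reassembled with a tree of concatenating
// shuffles, and any other node is recreated by recursing into its operands.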
2691static Value *generateNewInstTree(ArrayRef<InstLane> Item, FixedVectorType *Ty,
2692 const SmallPtrSet<Use *, 4> &IdentityLeafs,
2693 const SmallPtrSet<Use *, 4> &SplatLeafs,
2694 const SmallPtrSet<Use *, 4> &ConcatLeafs,
2695 IRBuilderBase &Builder,
2696 const TargetTransformInfo *TTI) {
2697 auto [FrontU, FrontLane] = Item.front();
2698
2699 if (IdentityLeafs.contains(Ptr: FrontU)) {
2700 return FrontU->get();
2701 }
2702 if (SplatLeafs.contains(Ptr: FrontU)) {
2703 SmallVector<int, 16> Mask(Ty->getNumElements(), FrontLane);
2704 return Builder.CreateShuffleVector(V: FrontU->get(), Mask);
2705 }
2706 if (ConcatLeafs.contains(Ptr: FrontU)) {
2707 unsigned NumElts =
2708 cast<FixedVectorType>(Val: FrontU->get()->getType())->getNumElements();
2709 SmallVector<Value *> Values(Item.size() / NumElts, nullptr);
2710 for (unsigned S = 0; S < Values.size(); ++S)
2711 Values[S] = Item[S * NumElts].first->get();
2712
2713 while (Values.size() > 1) {
2714 NumElts *= 2;
2715 SmallVector<int, 16> Mask(NumElts, 0);
2716 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
2717 SmallVector<Value *> NewValues(Values.size() / 2, nullptr);
2718 for (unsigned S = 0; S < NewValues.size(); ++S)
2719 NewValues[S] =
2720 Builder.CreateShuffleVector(V1: Values[S * 2], V2: Values[S * 2 + 1], Mask);
2721 Values = NewValues;
2722 }
2723 return Values[0];
2724 }
2725
2726 auto *I = cast<Instruction>(Val: FrontU->get());
2727 auto *II = dyn_cast<IntrinsicInst>(Val: I);
2728 unsigned NumOps = I->getNumOperands() - (II ? 1 : 0);
2729 SmallVector<Value *> Ops(NumOps);
2730 for (unsigned Idx = 0; Idx < NumOps; Idx++) {
2731 if (II &&
2732 isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Idx, TTI)) {
2733 Ops[Idx] = II->getOperand(i_nocapture: Idx);
2734 continue;
2735 }
2736 Ops[Idx] = generateNewInstTree(Item: generateInstLaneVectorFromOperand(Item, Op: Idx),
2737 Ty, IdentityLeafs, SplatLeafs, ConcatLeafs,
2738 Builder, TTI);
2739 }
2740
2741 SmallVector<Value *, 8> ValueList;
2742 for (const auto &Lane : Item)
2743 if (Lane.first)
2744 ValueList.push_back(Elt: Lane.first->get());
2745
2746 Type *DstTy =
2747 FixedVectorType::get(ElementType: I->getType()->getScalarType(), NumElts: Ty->getNumElements());
2748 if (auto *BI = dyn_cast<BinaryOperator>(Val: I)) {
2749 auto *Value = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)BI->getOpcode(),
2750 LHS: Ops[0], RHS: Ops[1]);
2751 propagateIRFlags(I: Value, VL: ValueList);
2752 return Value;
2753 }
2754 if (auto *CI = dyn_cast<CmpInst>(Val: I)) {
2755 auto *Value = Builder.CreateCmp(Pred: CI->getPredicate(), LHS: Ops[0], RHS: Ops[1]);
2756 propagateIRFlags(I: Value, VL: ValueList);
2757 return Value;
2758 }
2759 if (auto *SI = dyn_cast<SelectInst>(Val: I)) {
2760 auto *Value = Builder.CreateSelect(C: Ops[0], True: Ops[1], False: Ops[2], Name: "", MDFrom: SI);
2761 propagateIRFlags(I: Value, VL: ValueList);
2762 return Value;
2763 }
2764 if (auto *CI = dyn_cast<CastInst>(Val: I)) {
2765 auto *Value = Builder.CreateCast(Op: (Instruction::CastOps)CI->getOpcode(),
2766 V: Ops[0], DestTy: DstTy);
2767 propagateIRFlags(I: Value, VL: ValueList);
2768 return Value;
2769 }
2770 if (II) {
2771 auto *Value = Builder.CreateIntrinsic(RetTy: DstTy, ID: II->getIntrinsicID(), Args: Ops);
2772 propagateIRFlags(I: Value, VL: ValueList);
2773 return Value;
2774 }
2775 assert(isa<UnaryInstruction>(I) && "Unexpected instruction type in Generate");
2776 auto *Value =
2777 Builder.CreateUnOp(Opc: (Instruction::UnaryOps)I->getOpcode(), V: Ops[0]);
2778 propagateIRFlags(I: Value, VL: ValueList);
2779 return Value;
2780}
2781
// Starting from a shuffle, look up through the operands, tracking the shuffled
// index of each lane. If we can simplify away the shuffles to identities then
// do so.
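// A minimal illustrative case (types, ops and masks chosen arbitrarily):
//   %rev = shufflevector <4 x i32> %a, <4 x i32> poison,
//                        <4 x i32> <i32 3, i32 2, i32 1, i32 0>
//   %add = add <4 x i32> %rev, %rev
//   %r   = shufflevector <4 x i32> %add, <4 x i32> poison,
//                        <4 x i32> <i32 3, i32 2, i32 1, i32 0>
// Every lane of %r is a[i] + a[i] in its original position, so the shuffles
// are superfluous and the tree can be rebuilt as "%r = add %a, %a".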
2785bool VectorCombine::foldShuffleToIdentity(Instruction &I) {
2786 auto *Ty = dyn_cast<FixedVectorType>(Val: I.getType());
2787 if (!Ty || I.use_empty())
2788 return false;
2789
2790 SmallVector<InstLane> Start(Ty->getNumElements());
2791 for (unsigned M = 0, E = Ty->getNumElements(); M < E; ++M)
2792 Start[M] = lookThroughShuffles(U: &*I.use_begin(), Lane: M);
2793
2794 SmallVector<SmallVector<InstLane>> Worklist;
2795 Worklist.push_back(Elt: Start);
2796 SmallPtrSet<Use *, 4> IdentityLeafs, SplatLeafs, ConcatLeafs;
2797 unsigned NumVisited = 0;
2798
2799 while (!Worklist.empty()) {
2800 if (++NumVisited > MaxInstrsToScan)
2801 return false;
2802
2803 SmallVector<InstLane> Item = Worklist.pop_back_val();
2804 auto [FrontU, FrontLane] = Item.front();
2805
2806 // If we found an undef first lane then bail out to keep things simple.
2807 if (!FrontU)
2808 return false;
2809
2810 // Helper to peek through bitcasts to the same value.
2811 auto IsEquiv = [&](Value *X, Value *Y) {
2812 return X->getType() == Y->getType() &&
2813 peekThroughBitcasts(V: X) == peekThroughBitcasts(V: Y);
2814 };
2815
2816 // Look for an identity value.
2817 if (FrontLane == 0 &&
2818 cast<FixedVectorType>(Val: FrontU->get()->getType())->getNumElements() ==
2819 Ty->getNumElements() &&
2820 all_of(Range: drop_begin(RangeOrContainer: enumerate(First&: Item)), P: [IsEquiv, Item](const auto &E) {
2821 Value *FrontV = Item.front().first->get();
2822 return !E.value().first || (IsEquiv(E.value().first->get(), FrontV) &&
2823 E.value().second == (int)E.index());
2824 })) {
2825 IdentityLeafs.insert(Ptr: FrontU);
2826 continue;
2827 }
2828 // Look for constants, for the moment only supporting constant splats.
2829 if (auto *C = dyn_cast<Constant>(Val: FrontU);
2830 C && C->getSplatValue() &&
2831 all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
2832 Value *FrontV = Item.front().first->get();
2833 Use *U = IL.first;
2834 return !U || (isa<Constant>(Val: U->get()) &&
2835 cast<Constant>(Val: U->get())->getSplatValue() ==
2836 cast<Constant>(Val: FrontV)->getSplatValue());
2837 })) {
2838 SplatLeafs.insert(Ptr: FrontU);
2839 continue;
2840 }
2841 // Look for a splat value.
2842 if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item](InstLane &IL) {
2843 auto [FrontU, FrontLane] = Item.front();
2844 auto [U, Lane] = IL;
2845 return !U || (U->get() == FrontU->get() && Lane == FrontLane);
2846 })) {
2847 SplatLeafs.insert(Ptr: FrontU);
2848 continue;
2849 }
2850
    // We need each element to be the same kind of value, and we check that
    // each element has a single use.
2853 auto CheckLaneIsEquivalentToFirst = [Item](InstLane IL) {
2854 Value *FrontV = Item.front().first->get();
2855 if (!IL.first)
2856 return true;
2857 Value *V = IL.first->get();
2858 if (auto *I = dyn_cast<Instruction>(Val: V); I && !I->hasOneUse())
2859 return false;
2860 if (V->getValueID() != FrontV->getValueID())
2861 return false;
2862 if (auto *CI = dyn_cast<CmpInst>(Val: V))
2863 if (CI->getPredicate() != cast<CmpInst>(Val: FrontV)->getPredicate())
2864 return false;
2865 if (auto *CI = dyn_cast<CastInst>(Val: V))
2866 if (CI->getSrcTy()->getScalarType() !=
2867 cast<CastInst>(Val: FrontV)->getSrcTy()->getScalarType())
2868 return false;
2869 if (auto *SI = dyn_cast<SelectInst>(Val: V))
2870 if (!isa<VectorType>(Val: SI->getOperand(i_nocapture: 0)->getType()) ||
2871 SI->getOperand(i_nocapture: 0)->getType() !=
2872 cast<SelectInst>(Val: FrontV)->getOperand(i_nocapture: 0)->getType())
2873 return false;
2874 if (isa<CallInst>(Val: V) && !isa<IntrinsicInst>(Val: V))
2875 return false;
2876 auto *II = dyn_cast<IntrinsicInst>(Val: V);
2877 return !II || (isa<IntrinsicInst>(Val: FrontV) &&
2878 II->getIntrinsicID() ==
2879 cast<IntrinsicInst>(Val: FrontV)->getIntrinsicID() &&
2880 !II->hasOperandBundles());
2881 };
2882 if (all_of(Range: drop_begin(RangeOrContainer&: Item), P: CheckLaneIsEquivalentToFirst)) {
2883 // Check the operator is one that we support.
2884 if (isa<BinaryOperator, CmpInst>(Val: FrontU)) {
2885 // We exclude div/rem in case they hit UB from poison lanes.
2886 if (auto *BO = dyn_cast<BinaryOperator>(Val: FrontU);
2887 BO && BO->isIntDivRem())
2888 return false;
2889 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
2890 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 1));
2891 continue;
2892 } else if (isa<UnaryOperator, TruncInst, ZExtInst, SExtInst, FPToSIInst,
2893 FPToUIInst, SIToFPInst, UIToFPInst>(Val: FrontU)) {
2894 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
2895 continue;
2896 } else if (auto *BitCast = dyn_cast<BitCastInst>(Val: FrontU)) {
2897 // TODO: Handle vector widening/narrowing bitcasts.
2898 auto *DstTy = dyn_cast<FixedVectorType>(Val: BitCast->getDestTy());
2899 auto *SrcTy = dyn_cast<FixedVectorType>(Val: BitCast->getSrcTy());
2900 if (DstTy && SrcTy &&
2901 SrcTy->getNumElements() == DstTy->getNumElements()) {
2902 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
2903 continue;
2904 }
2905 } else if (isa<SelectInst>(Val: FrontU)) {
2906 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 0));
2907 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 1));
2908 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op: 2));
2909 continue;
2910 } else if (auto *II = dyn_cast<IntrinsicInst>(Val: FrontU);
2911 II && isTriviallyVectorizable(ID: II->getIntrinsicID()) &&
2912 !II->hasOperandBundles()) {
2913 for (unsigned Op = 0, E = II->getNumOperands() - 1; Op < E; Op++) {
2914 if (isVectorIntrinsicWithScalarOpAtArg(ID: II->getIntrinsicID(), ScalarOpdIdx: Op,
2915 TTI: &TTI)) {
2916 if (!all_of(Range: drop_begin(RangeOrContainer&: Item), P: [Item, Op](InstLane &IL) {
2917 Value *FrontV = Item.front().first->get();
2918 Use *U = IL.first;
2919 return !U || (cast<Instruction>(Val: U->get())->getOperand(i: Op) ==
2920 cast<Instruction>(Val: FrontV)->getOperand(i: Op));
2921 }))
2922 return false;
2923 continue;
2924 }
2925 Worklist.push_back(Elt: generateInstLaneVectorFromOperand(Item, Op));
2926 }
2927 continue;
2928 }
2929 }
2930
2931 if (isFreeConcat(Item, CostKind, TTI)) {
2932 ConcatLeafs.insert(Ptr: FrontU);
2933 continue;
2934 }
2935
2936 return false;
2937 }
2938
2939 if (NumVisited <= 1)
2940 return false;
2941
2942 LLVM_DEBUG(dbgs() << "Found a superfluous identity shuffle: " << I << "\n");
2943
2944 // If we got this far, we know the shuffles are superfluous and can be
2945 // removed. Scan through again and generate the new tree of instructions.
2946 Builder.SetInsertPoint(&I);
2947 Value *V = generateNewInstTree(Item: Start, Ty, IdentityLeafs, SplatLeafs,
2948 ConcatLeafs, Builder, TTI: &TTI);
2949 replaceValue(Old&: I, New&: *V);
2950 return true;
2951}
2952
2953/// Given a commutative reduction, the order of the input lanes does not alter
2954/// the results. We can use this to remove certain shuffles feeding the
2955/// reduction, removing the need to shuffle at all.
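/// For example (illustrative):
///   %s = shufflevector <4 x i32> %x, <4 x i32> poison,
///                      <4 x i32> <i32 3, i32 2, i32 1, i32 0>
///   %r = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %s)
/// The reduction sums all lanes regardless of their order, so the shuffle mask
/// can be sorted into an identity (and later removed) when that is cheaper.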
2956bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2957 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
2958 if (!II)
2959 return false;
2960 switch (II->getIntrinsicID()) {
2961 case Intrinsic::vector_reduce_add:
2962 case Intrinsic::vector_reduce_mul:
2963 case Intrinsic::vector_reduce_and:
2964 case Intrinsic::vector_reduce_or:
2965 case Intrinsic::vector_reduce_xor:
2966 case Intrinsic::vector_reduce_smin:
2967 case Intrinsic::vector_reduce_smax:
2968 case Intrinsic::vector_reduce_umin:
2969 case Intrinsic::vector_reduce_umax:
2970 break;
2971 default:
2972 return false;
2973 }
2974
2975 // Find all the inputs when looking through operations that do not alter the
2976 // lane order (binops, for example). Currently we look for a single shuffle,
2977 // and can ignore splat values.
2978 std::queue<Value *> Worklist;
2979 SmallPtrSet<Value *, 4> Visited;
2980 ShuffleVectorInst *Shuffle = nullptr;
2981 if (auto *Op = dyn_cast<Instruction>(Val: I.getOperand(i: 0)))
2982 Worklist.push(x: Op);
2983
2984 while (!Worklist.empty()) {
2985 Value *CV = Worklist.front();
2986 Worklist.pop();
2987 if (Visited.contains(Ptr: CV))
2988 continue;
2989
2990 // Splats don't change the order, so can be safely ignored.
2991 if (isSplatValue(V: CV))
2992 continue;
2993
2994 Visited.insert(Ptr: CV);
2995
2996 if (auto *CI = dyn_cast<Instruction>(Val: CV)) {
2997 if (CI->isBinaryOp()) {
2998 for (auto *Op : CI->operand_values())
2999 Worklist.push(x: Op);
3000 continue;
3001 } else if (auto *SV = dyn_cast<ShuffleVectorInst>(Val: CI)) {
3002 if (Shuffle && Shuffle != SV)
3003 return false;
3004 Shuffle = SV;
3005 continue;
3006 }
3007 }
3008
3009 // Anything else is currently an unknown node.
3010 return false;
3011 }
3012
3013 if (!Shuffle)
3014 return false;
3015
3016 // Check all uses of the binary ops and shuffles are also included in the
3017 // lane-invariant operations (Visited should be the list of lanewise
3018 // instructions, including the shuffle that we found).
3019 for (auto *V : Visited)
3020 for (auto *U : V->users())
3021 if (!Visited.contains(Ptr: U) && U != &I)
3022 return false;
3023
3024 FixedVectorType *VecType =
3025 dyn_cast<FixedVectorType>(Val: II->getOperand(i_nocapture: 0)->getType());
3026 if (!VecType)
3027 return false;
3028 FixedVectorType *ShuffleInputType =
3029 dyn_cast<FixedVectorType>(Val: Shuffle->getOperand(i_nocapture: 0)->getType());
3030 if (!ShuffleInputType)
3031 return false;
3032 unsigned NumInputElts = ShuffleInputType->getNumElements();
3033
  // Find the mask from sorting the lanes into order. This is most likely to
  // become an identity or concat mask. Undef elements are pushed to the end.
3036 SmallVector<int> ConcatMask;
3037 Shuffle->getShuffleMask(Result&: ConcatMask);
3038 sort(C&: ConcatMask, Comp: [](int X, int Y) { return (unsigned)X < (unsigned)Y; });
3039 bool UsesSecondVec =
3040 any_of(Range&: ConcatMask, P: [&](int M) { return M >= (int)NumInputElts; });
3041
3042 InstructionCost OldCost = TTI.getShuffleCost(
3043 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3044 SrcTy: ShuffleInputType, Mask: Shuffle->getShuffleMask(), CostKind);
3045 InstructionCost NewCost = TTI.getShuffleCost(
3046 Kind: UsesSecondVec ? TTI::SK_PermuteTwoSrc : TTI::SK_PermuteSingleSrc, DstTy: VecType,
3047 SrcTy: ShuffleInputType, Mask: ConcatMask, CostKind);
3048
3049 LLVM_DEBUG(dbgs() << "Found a reduction feeding from a shuffle: " << *Shuffle
3050 << "\n");
3051 LLVM_DEBUG(dbgs() << " OldCost: " << OldCost << " vs NewCost: " << NewCost
3052 << "\n");
3053 bool MadeChanges = false;
3054 if (NewCost < OldCost) {
3055 Builder.SetInsertPoint(Shuffle);
3056 Value *NewShuffle = Builder.CreateShuffleVector(
3057 V1: Shuffle->getOperand(i_nocapture: 0), V2: Shuffle->getOperand(i_nocapture: 1), Mask: ConcatMask);
3058 LLVM_DEBUG(dbgs() << "Created new shuffle: " << *NewShuffle << "\n");
3059 replaceValue(Old&: *Shuffle, New&: *NewShuffle);
3060 MadeChanges = true;
3061 }
3062
3063 // See if we can re-use foldSelectShuffle, getting it to reduce the size of
3064 // the shuffle into a nicer order, as it can ignore the order of the shuffles.
3065 MadeChanges |= foldSelectShuffle(I&: *Shuffle, FromReduction: true);
3066 return MadeChanges;
3067}
3068
/// Determine if it's more efficient to fold:
3070/// reduce(trunc(x)) -> trunc(reduce(x)).
3071/// reduce(sext(x)) -> sext(reduce(x)).
3072/// reduce(zext(x)) -> zext(reduce(x)).
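/// For example (illustrative):
///   %t = trunc <8 x i32> %x to <8 x i16>
///   %r = call i16 @llvm.vector.reduce.add.v8i16(<8 x i16> %t)
/// may become, when the cost model prefers it:
///   %w = call i32 @llvm.vector.reduce.add.v8i32(<8 x i32> %x)
///   %r = trunc i32 %w to i16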
3073bool VectorCombine::foldCastFromReductions(Instruction &I) {
3074 auto *II = dyn_cast<IntrinsicInst>(Val: &I);
3075 if (!II)
3076 return false;
3077
3078 bool TruncOnly = false;
3079 Intrinsic::ID IID = II->getIntrinsicID();
3080 switch (IID) {
3081 case Intrinsic::vector_reduce_add:
3082 case Intrinsic::vector_reduce_mul:
3083 TruncOnly = true;
3084 break;
3085 case Intrinsic::vector_reduce_and:
3086 case Intrinsic::vector_reduce_or:
3087 case Intrinsic::vector_reduce_xor:
3088 break;
3089 default:
3090 return false;
3091 }
3092
3093 unsigned ReductionOpc = getArithmeticReductionInstruction(RdxID: IID);
3094 Value *ReductionSrc = I.getOperand(i: 0);
3095
3096 Value *Src;
3097 if (!match(V: ReductionSrc, P: m_OneUse(SubPattern: m_Trunc(Op: m_Value(V&: Src)))) &&
3098 (TruncOnly || !match(V: ReductionSrc, P: m_OneUse(SubPattern: m_ZExtOrSExt(Op: m_Value(V&: Src))))))
3099 return false;
3100
3101 auto CastOpc =
3102 (Instruction::CastOps)cast<Instruction>(Val: ReductionSrc)->getOpcode();
3103
3104 auto *SrcTy = cast<VectorType>(Val: Src->getType());
3105 auto *ReductionSrcTy = cast<VectorType>(Val: ReductionSrc->getType());
3106 Type *ResultTy = I.getType();
3107
3108 InstructionCost OldCost = TTI.getArithmeticReductionCost(
3109 Opcode: ReductionOpc, Ty: ReductionSrcTy, FMF: std::nullopt, CostKind);
3110 OldCost += TTI.getCastInstrCost(Opcode: CastOpc, Dst: ReductionSrcTy, Src: SrcTy,
3111 CCH: TTI::CastContextHint::None, CostKind,
3112 I: cast<CastInst>(Val: ReductionSrc));
3113 InstructionCost NewCost =
3114 TTI.getArithmeticReductionCost(Opcode: ReductionOpc, Ty: SrcTy, FMF: std::nullopt,
3115 CostKind) +
3116 TTI.getCastInstrCost(Opcode: CastOpc, Dst: ResultTy, Src: ReductionSrcTy->getScalarType(),
3117 CCH: TTI::CastContextHint::None, CostKind);
3118
3119 if (OldCost <= NewCost || !NewCost.isValid())
3120 return false;
3121
3122 Value *NewReduction = Builder.CreateIntrinsic(RetTy: SrcTy->getScalarType(),
3123 ID: II->getIntrinsicID(), Args: {Src});
3124 Value *NewCast = Builder.CreateCast(Op: CastOpc, V: NewReduction, DestTy: ResultTy);
3125 replaceValue(Old&: I, New&: *NewCast);
3126 return true;
3127}
3128
3129/// This method looks for groups of shuffles acting on binops, of the form:
3130/// %x = shuffle ...
3131/// %y = shuffle ...
3132/// %a = binop %x, %y
3133/// %b = binop %x, %y
3134/// shuffle %a, %b, selectmask
3135/// We may, especially if the shuffle is wider than legal, be able to convert
3136/// the shuffle to a form where only parts of a and b need to be computed. On
3137/// architectures with no obvious "select" shuffle, this can reduce the total
3138/// number of operations if the target reports them as cheaper.
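/// As an illustrative sketch: if the selectmask only reads the first half of
/// its lanes from %a and the second half from %b, the input shuffles can be
/// packed so that %a only computes the lanes the mask reads from it (and
/// likewise %b), turning two wide binops into two narrower binops plus
/// cheaper shuffles when the cost model agrees.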
3139bool VectorCombine::foldSelectShuffle(Instruction &I, bool FromReduction) {
3140 auto *SVI = cast<ShuffleVectorInst>(Val: &I);
3141 auto *VT = cast<FixedVectorType>(Val: I.getType());
3142 auto *Op0 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 0));
3143 auto *Op1 = dyn_cast<Instruction>(Val: SVI->getOperand(i_nocapture: 1));
3144 if (!Op0 || !Op1 || Op0 == Op1 || !Op0->isBinaryOp() || !Op1->isBinaryOp() ||
3145 VT != Op0->getType())
3146 return false;
3147
3148 auto *SVI0A = dyn_cast<Instruction>(Val: Op0->getOperand(i: 0));
3149 auto *SVI0B = dyn_cast<Instruction>(Val: Op0->getOperand(i: 1));
3150 auto *SVI1A = dyn_cast<Instruction>(Val: Op1->getOperand(i: 0));
3151 auto *SVI1B = dyn_cast<Instruction>(Val: Op1->getOperand(i: 1));
3152 SmallPtrSet<Instruction *, 4> InputShuffles({SVI0A, SVI0B, SVI1A, SVI1B});
3153 auto checkSVNonOpUses = [&](Instruction *I) {
3154 if (!I || I->getOperand(i: 0)->getType() != VT)
3155 return true;
3156 return any_of(Range: I->users(), P: [&](User *U) {
3157 return U != Op0 && U != Op1 &&
3158 !(isa<ShuffleVectorInst>(Val: U) &&
3159 (InputShuffles.contains(Ptr: cast<Instruction>(Val: U)) ||
3160 isInstructionTriviallyDead(I: cast<Instruction>(Val: U))));
3161 });
3162 };
3163 if (checkSVNonOpUses(SVI0A) || checkSVNonOpUses(SVI0B) ||
3164 checkSVNonOpUses(SVI1A) || checkSVNonOpUses(SVI1B))
3165 return false;
3166
3167 // Collect all the uses that are shuffles that we can transform together. We
3168 // may not have a single shuffle, but a group that can all be transformed
3169 // together profitably.
3170 SmallVector<ShuffleVectorInst *> Shuffles;
3171 auto collectShuffles = [&](Instruction *I) {
3172 for (auto *U : I->users()) {
3173 auto *SV = dyn_cast<ShuffleVectorInst>(Val: U);
3174 if (!SV || SV->getType() != VT)
3175 return false;
3176 if ((SV->getOperand(i_nocapture: 0) != Op0 && SV->getOperand(i_nocapture: 0) != Op1) ||
3177 (SV->getOperand(i_nocapture: 1) != Op0 && SV->getOperand(i_nocapture: 1) != Op1))
3178 return false;
3179 if (!llvm::is_contained(Range&: Shuffles, Element: SV))
3180 Shuffles.push_back(Elt: SV);
3181 }
3182 return true;
3183 };
3184 if (!collectShuffles(Op0) || !collectShuffles(Op1))
3185 return false;
3186 // From a reduction, we need to be processing a single shuffle, otherwise the
3187 // other uses will not be lane-invariant.
3188 if (FromReduction && Shuffles.size() > 1)
3189 return false;
3190
3191 // Add any shuffle uses for the shuffles we have found, to include them in our
3192 // cost calculations.
3193 if (!FromReduction) {
3194 for (ShuffleVectorInst *SV : Shuffles) {
3195 for (auto *U : SV->users()) {
3196 ShuffleVectorInst *SSV = dyn_cast<ShuffleVectorInst>(Val: U);
3197 if (SSV && isa<UndefValue>(Val: SSV->getOperand(i_nocapture: 1)) && SSV->getType() == VT)
3198 Shuffles.push_back(Elt: SSV);
3199 }
3200 }
3201 }
3202
  // For each of the output shuffles, we try to sort all the first vector
  // elements to the beginning, followed by the second vector's elements at the
  // end. If the binops are legalized to smaller vectors, this may reduce the
  // total number of binops. We compute the ReconstructMask needed to convert
  // back to the original lane order.
3208 SmallVector<std::pair<int, int>> V1, V2;
3209 SmallVector<SmallVector<int>> OrigReconstructMasks;
3210 int MaxV1Elt = 0, MaxV2Elt = 0;
3211 unsigned NumElts = VT->getNumElements();
3212 for (ShuffleVectorInst *SVN : Shuffles) {
3213 SmallVector<int> Mask;
3214 SVN->getShuffleMask(Result&: Mask);
3215
3216 // Check the operands are the same as the original, or reversed (in which
3217 // case we need to commute the mask).
3218 Value *SVOp0 = SVN->getOperand(i_nocapture: 0);
3219 Value *SVOp1 = SVN->getOperand(i_nocapture: 1);
3220 if (isa<UndefValue>(Val: SVOp1)) {
3221 auto *SSV = cast<ShuffleVectorInst>(Val: SVOp0);
3222 SVOp0 = SSV->getOperand(i_nocapture: 0);
3223 SVOp1 = SSV->getOperand(i_nocapture: 1);
3224 for (int &Elem : Mask) {
3225 if (Elem >= static_cast<int>(SSV->getShuffleMask().size()))
3226 return false;
3227 Elem = Elem < 0 ? Elem : SSV->getMaskValue(Elt: Elem);
3228 }
3229 }
3230 if (SVOp0 == Op1 && SVOp1 == Op0) {
3231 std::swap(a&: SVOp0, b&: SVOp1);
3232 ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumElts);
3233 }
3234 if (SVOp0 != Op0 || SVOp1 != Op1)
3235 return false;
3236
    // Calculate the reconstruction mask for this shuffle, as the mask needed
    // to take the packed values from Op0/Op1 and reconstruct the original
    // lane order.
3240 SmallVector<int> ReconstructMask;
3241 for (unsigned I = 0; I < Mask.size(); I++) {
3242 if (Mask[I] < 0) {
3243 ReconstructMask.push_back(Elt: -1);
3244 } else if (Mask[I] < static_cast<int>(NumElts)) {
3245 MaxV1Elt = std::max(a: MaxV1Elt, b: Mask[I]);
3246 auto It = find_if(Range&: V1, P: [&](const std::pair<int, int> &A) {
3247 return Mask[I] == A.first;
3248 });
3249 if (It != V1.end())
3250 ReconstructMask.push_back(Elt: It - V1.begin());
3251 else {
3252 ReconstructMask.push_back(Elt: V1.size());
3253 V1.emplace_back(Args&: Mask[I], Args: V1.size());
3254 }
3255 } else {
3256 MaxV2Elt = std::max<int>(a: MaxV2Elt, b: Mask[I] - NumElts);
3257 auto It = find_if(Range&: V2, P: [&](const std::pair<int, int> &A) {
3258 return Mask[I] - static_cast<int>(NumElts) == A.first;
3259 });
3260 if (It != V2.end())
3261 ReconstructMask.push_back(Elt: NumElts + It - V2.begin());
3262 else {
3263 ReconstructMask.push_back(Elt: NumElts + V2.size());
3264 V2.emplace_back(Args: Mask[I] - NumElts, Args: NumElts + V2.size());
3265 }
3266 }
3267 }
3268
    // For reductions, we know that the output lane ordering doesn't alter the
    // result. Keeping the lanes in order can help simplify the shuffle away.
3271 if (FromReduction)
3272 sort(C&: ReconstructMask);
3273 OrigReconstructMasks.push_back(Elt: std::move(ReconstructMask));
3274 }
3275
  // If the maximum elements used from V1 and V2 are not larger than the new
  // vectors, the vectors are already packed and performing the optimization
  // again will likely not help any further. This also prevents us from getting
  // stuck in a cycle in case the costs do not also rule it out.
3280 if (V1.empty() || V2.empty() ||
3281 (MaxV1Elt == static_cast<int>(V1.size()) - 1 &&
3282 MaxV2Elt == static_cast<int>(V2.size()) - 1))
3283 return false;
3284
  // GetBaseMaskValue takes one of the inputs, which may either be a shuffle, a
  // shuffle of another shuffle, or not a shuffle (that is treated like an
  // identity shuffle).
3288 auto GetBaseMaskValue = [&](Instruction *I, int M) {
3289 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
3290 if (!SV)
3291 return M;
3292 if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
3293 if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
3294 if (InputShuffles.contains(Ptr: SSV))
3295 return SSV->getMaskValue(Elt: SV->getMaskValue(Elt: M));
3296 return SV->getMaskValue(Elt: M);
3297 };
3298
  // Attempt to sort the inputs by ascending mask values to make simpler input
  // shuffles and push complex shuffles down to the uses. We sort on the first
  // of the two input shuffle orders, to try and get at least one input into a
  // nice order.
3303 auto SortBase = [&](Instruction *A, std::pair<int, int> X,
3304 std::pair<int, int> Y) {
3305 int MXA = GetBaseMaskValue(A, X.first);
3306 int MYA = GetBaseMaskValue(A, Y.first);
3307 return MXA < MYA;
3308 };
3309 stable_sort(Range&: V1, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
3310 return SortBase(SVI0A, A, B);
3311 });
3312 stable_sort(Range&: V2, C: [&](std::pair<int, int> A, std::pair<int, int> B) {
3313 return SortBase(SVI1A, A, B);
3314 });
3315 // Calculate our ReconstructMasks from the OrigReconstructMasks and the
3316 // modified order of the input shuffles.
3317 SmallVector<SmallVector<int>> ReconstructMasks;
3318 for (const auto &Mask : OrigReconstructMasks) {
3319 SmallVector<int> ReconstructMask;
3320 for (int M : Mask) {
3321 auto FindIndex = [](const SmallVector<std::pair<int, int>> &V, int M) {
3322 auto It = find_if(Range: V, P: [M](auto A) { return A.second == M; });
3323 assert(It != V.end() && "Expected all entries in Mask");
3324 return std::distance(first: V.begin(), last: It);
3325 };
3326 if (M < 0)
3327 ReconstructMask.push_back(Elt: -1);
3328 else if (M < static_cast<int>(NumElts)) {
3329 ReconstructMask.push_back(Elt: FindIndex(V1, M));
3330 } else {
3331 ReconstructMask.push_back(Elt: NumElts + FindIndex(V2, M));
3332 }
3333 }
3334 ReconstructMasks.push_back(Elt: std::move(ReconstructMask));
3335 }
3336
  // Calculate the masks needed for the new input shuffles, which get padded
  // with poison.
3339 SmallVector<int> V1A, V1B, V2A, V2B;
3340 for (unsigned I = 0; I < V1.size(); I++) {
3341 V1A.push_back(Elt: GetBaseMaskValue(SVI0A, V1[I].first));
3342 V1B.push_back(Elt: GetBaseMaskValue(SVI0B, V1[I].first));
3343 }
3344 for (unsigned I = 0; I < V2.size(); I++) {
3345 V2A.push_back(Elt: GetBaseMaskValue(SVI1A, V2[I].first));
3346 V2B.push_back(Elt: GetBaseMaskValue(SVI1B, V2[I].first));
3347 }
3348 while (V1A.size() < NumElts) {
3349 V1A.push_back(Elt: PoisonMaskElem);
3350 V1B.push_back(Elt: PoisonMaskElem);
3351 }
3352 while (V2A.size() < NumElts) {
3353 V2A.push_back(Elt: PoisonMaskElem);
3354 V2B.push_back(Elt: PoisonMaskElem);
3355 }
3356
3357 auto AddShuffleCost = [&](InstructionCost C, Instruction *I) {
3358 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
3359 if (!SV)
3360 return C;
3361 return C + TTI.getShuffleCost(Kind: isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1))
3362 ? TTI::SK_PermuteSingleSrc
3363 : TTI::SK_PermuteTwoSrc,
3364 DstTy: VT, SrcTy: VT, Mask: SV->getShuffleMask(), CostKind);
3365 };
3366 auto AddShuffleMaskCost = [&](InstructionCost C, ArrayRef<int> Mask) {
3367 return C +
3368 TTI.getShuffleCost(Kind: TTI::SK_PermuteTwoSrc, DstTy: VT, SrcTy: VT, Mask, CostKind);
3369 };
3370
3371 // Get the costs of the shuffles + binops before and after with the new
3372 // shuffle masks.
3373 InstructionCost CostBefore =
3374 TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: VT, CostKind) +
3375 TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: VT, CostKind);
3376 CostBefore += std::accumulate(first: Shuffles.begin(), last: Shuffles.end(),
3377 init: InstructionCost(0), binary_op: AddShuffleCost);
3378 CostBefore += std::accumulate(first: InputShuffles.begin(), last: InputShuffles.end(),
3379 init: InstructionCost(0), binary_op: AddShuffleCost);
3380
3381 // The new binops will be unused for lanes past the used shuffle lengths.
3382 // These types attempt to get the correct cost for that from the target.
3383 FixedVectorType *Op0SmallVT =
3384 FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V1.size());
3385 FixedVectorType *Op1SmallVT =
3386 FixedVectorType::get(ElementType: VT->getScalarType(), NumElts: V2.size());
3387 InstructionCost CostAfter =
3388 TTI.getArithmeticInstrCost(Opcode: Op0->getOpcode(), Ty: Op0SmallVT, CostKind) +
3389 TTI.getArithmeticInstrCost(Opcode: Op1->getOpcode(), Ty: Op1SmallVT, CostKind);
3390 CostAfter += std::accumulate(first: ReconstructMasks.begin(), last: ReconstructMasks.end(),
3391 init: InstructionCost(0), binary_op: AddShuffleMaskCost);
3392 std::set<SmallVector<int>> OutputShuffleMasks({V1A, V1B, V2A, V2B});
3393 CostAfter +=
3394 std::accumulate(first: OutputShuffleMasks.begin(), last: OutputShuffleMasks.end(),
3395 init: InstructionCost(0), binary_op: AddShuffleMaskCost);
3396
3397 LLVM_DEBUG(dbgs() << "Found a binop select shuffle pattern: " << I << "\n");
3398 LLVM_DEBUG(dbgs() << " CostBefore: " << CostBefore
3399 << " vs CostAfter: " << CostAfter << "\n");
3400 if (CostBefore <= CostAfter)
3401 return false;
3402
3403 // The cost model has passed, create the new instructions.
3404 auto GetShuffleOperand = [&](Instruction *I, unsigned Op) -> Value * {
3405 auto *SV = dyn_cast<ShuffleVectorInst>(Val: I);
3406 if (!SV)
3407 return I;
3408 if (isa<UndefValue>(Val: SV->getOperand(i_nocapture: 1)))
3409 if (auto *SSV = dyn_cast<ShuffleVectorInst>(Val: SV->getOperand(i_nocapture: 0)))
3410 if (InputShuffles.contains(Ptr: SSV))
3411 return SSV->getOperand(i_nocapture: Op);
3412 return SV->getOperand(i_nocapture: Op);
3413 };
3414 Builder.SetInsertPoint(*SVI0A->getInsertionPointAfterDef());
3415 Value *NSV0A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0A, 0),
3416 V2: GetShuffleOperand(SVI0A, 1), Mask: V1A);
3417 Builder.SetInsertPoint(*SVI0B->getInsertionPointAfterDef());
3418 Value *NSV0B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI0B, 0),
3419 V2: GetShuffleOperand(SVI0B, 1), Mask: V1B);
3420 Builder.SetInsertPoint(*SVI1A->getInsertionPointAfterDef());
3421 Value *NSV1A = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1A, 0),
3422 V2: GetShuffleOperand(SVI1A, 1), Mask: V2A);
3423 Builder.SetInsertPoint(*SVI1B->getInsertionPointAfterDef());
3424 Value *NSV1B = Builder.CreateShuffleVector(V1: GetShuffleOperand(SVI1B, 0),
3425 V2: GetShuffleOperand(SVI1B, 1), Mask: V2B);
3426 Builder.SetInsertPoint(Op0);
3427 Value *NOp0 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op0->getOpcode(),
3428 LHS: NSV0A, RHS: NSV0B);
3429 if (auto *I = dyn_cast<Instruction>(Val: NOp0))
3430 I->copyIRFlags(V: Op0, IncludeWrapFlags: true);
3431 Builder.SetInsertPoint(Op1);
3432 Value *NOp1 = Builder.CreateBinOp(Opc: (Instruction::BinaryOps)Op1->getOpcode(),
3433 LHS: NSV1A, RHS: NSV1B);
3434 if (auto *I = dyn_cast<Instruction>(Val: NOp1))
3435 I->copyIRFlags(V: Op1, IncludeWrapFlags: true);
3436
3437 for (int S = 0, E = ReconstructMasks.size(); S != E; S++) {
3438 Builder.SetInsertPoint(Shuffles[S]);
3439 Value *NSV = Builder.CreateShuffleVector(V1: NOp0, V2: NOp1, Mask: ReconstructMasks[S]);
3440 replaceValue(Old&: *Shuffles[S], New&: *NSV);
3441 }
3442
3443 Worklist.pushValue(V: NSV0A);
3444 Worklist.pushValue(V: NSV0B);
3445 Worklist.pushValue(V: NSV1A);
3446 Worklist.pushValue(V: NSV1B);
3447 return true;
3448}
3449
/// Check if the instruction depends on a ZExt that can be moved after the
/// instruction, and move the ZExt if it is profitable. For example:
/// logic(zext(x),y) -> zext(logic(x,trunc(y)))
/// lshr(zext(x),y) -> zext(lshr(x,trunc(y)))
/// The cost model calculation takes into account whether zext(x) has other
/// users and whether the ZExt can be propagated through them too.
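/// A minimal illustrative sketch (types and constants are hypothetical):
///   %zx = zext <8 x i8> %x to <8 x i32>
///   %r  = and <8 x i32> %zx, splat (i32 15)
/// may become, when the cost model agrees:
///   %na = and <8 x i8> %x, splat (i8 15)
///   %r  = zext <8 x i8> %na to <8 x i32>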
3456bool VectorCombine::shrinkType(Instruction &I) {
3457 Value *ZExted, *OtherOperand;
3458 if (!match(V: &I, P: m_c_BitwiseLogic(L: m_ZExt(Op: m_Value(V&: ZExted)),
3459 R: m_Value(V&: OtherOperand))) &&
3460 !match(V: &I, P: m_LShr(L: m_ZExt(Op: m_Value(V&: ZExted)), R: m_Value(V&: OtherOperand))))
3461 return false;
3462
3463 Value *ZExtOperand = I.getOperand(i: I.getOperand(i: 0) == OtherOperand ? 1 : 0);
3464
3465 auto *BigTy = cast<FixedVectorType>(Val: I.getType());
3466 auto *SmallTy = cast<FixedVectorType>(Val: ZExted->getType());
3467 unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
3468
3469 if (I.getOpcode() == Instruction::LShr) {
3470 // Check that the shift amount is less than the number of bits in the
3471 // smaller type. Otherwise, the smaller lshr will return a poison value.
3472 KnownBits ShAmtKB = computeKnownBits(V: I.getOperand(i: 1), DL: *DL);
3473 if (ShAmtKB.getMaxValue().uge(RHS: BW))
3474 return false;
3475 } else {
    // Check that the expression overall uses at most the same number of bits
    // as ZExted does.
3478 KnownBits KB = computeKnownBits(V: &I, DL: *DL);
3479 if (KB.countMaxActiveBits() > BW)
3480 return false;
3481 }
3482
  // Calculate the costs of leaving the current IR as it is and of moving the
  // ZExt later, along with adding truncates if needed.
3485 InstructionCost ZExtCost = TTI.getCastInstrCost(
3486 Opcode: Instruction::ZExt, Dst: BigTy, Src: SmallTy,
3487 CCH: TargetTransformInfo::CastContextHint::None, CostKind);
3488 InstructionCost CurrentCost = ZExtCost;
3489 InstructionCost ShrinkCost = 0;
3490
  // Calculate the total cost and check that we can propagate the ZExt through
  // all of its users.
3492 for (User *U : ZExtOperand->users()) {
3493 auto *UI = cast<Instruction>(Val: U);
3494 if (UI == &I) {
3495 CurrentCost +=
3496 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
3497 ShrinkCost +=
3498 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
3499 ShrinkCost += ZExtCost;
3500 continue;
3501 }
3502
3503 if (!Instruction::isBinaryOp(Opcode: UI->getOpcode()))
3504 return false;
3505
3506 // Check if we can propagate ZExt through its other users
3507 KnownBits KB = computeKnownBits(V: UI, DL: *DL);
3508 if (KB.countMaxActiveBits() > BW)
3509 return false;
3510
3511 CurrentCost += TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: BigTy, CostKind);
3512 ShrinkCost +=
3513 TTI.getArithmeticInstrCost(Opcode: UI->getOpcode(), Ty: SmallTy, CostKind);
3514 ShrinkCost += ZExtCost;
3515 }
3516
  // If the other instruction operand is not a constant, we'll need to generate
  // a truncate instruction, so we have to adjust the cost.
3519 if (!isa<Constant>(Val: OtherOperand))
3520 ShrinkCost += TTI.getCastInstrCost(
3521 Opcode: Instruction::Trunc, Dst: SmallTy, Src: BigTy,
3522 CCH: TargetTransformInfo::CastContextHint::None, CostKind);
3523
  // If the cost of shrinking types is the same as the cost of leaving the IR
  // as it is, we lean towards modifying the IR because shrinking opens
  // opportunities for other shrinking optimizations.
3527 if (ShrinkCost > CurrentCost)
3528 return false;
3529
3530 Builder.SetInsertPoint(&I);
3531 Value *Op0 = ZExted;
3532 Value *Op1 = Builder.CreateTrunc(V: OtherOperand, DestTy: SmallTy);
3533 // Keep the order of operands the same
3534 if (I.getOperand(i: 0) == OtherOperand)
3535 std::swap(a&: Op0, b&: Op1);
3536 Value *NewBinOp =
3537 Builder.CreateBinOp(Opc: (Instruction::BinaryOps)I.getOpcode(), LHS: Op0, RHS: Op1);
3538 cast<Instruction>(Val: NewBinOp)->copyIRFlags(V: &I);
3539 cast<Instruction>(Val: NewBinOp)->copyMetadata(SrcInst: I);
3540 Value *NewZExtr = Builder.CreateZExt(V: NewBinOp, DestTy: BigTy);
3541 replaceValue(Old&: I, New&: *NewZExtr);
3542 return true;
3543}
3544
3545/// insert (DstVec, (extract SrcVec, ExtIdx), InsIdx) -->
3546/// shuffle (DstVec, SrcVec, Mask)
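/// For example (illustrative, with same-length 4-element vectors):
///   %e = extractelement <4 x i32> %src, i64 2
///   %r = insertelement <4 x i32> %dst, i32 %e, i64 0
/// may become:
///   %r = shufflevector <4 x i32> %dst, <4 x i32> %src,
///                      <4 x i32> <i32 6, i32 1, i32 2, i32 3>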
3547bool VectorCombine::foldInsExtVectorToShuffle(Instruction &I) {
3548 Value *DstVec, *SrcVec;
3549 uint64_t ExtIdx, InsIdx;
3550 if (!match(V: &I,
3551 P: m_InsertElt(Val: m_Value(V&: DstVec),
3552 Elt: m_ExtractElt(Val: m_Value(V&: SrcVec), Idx: m_ConstantInt(V&: ExtIdx)),
3553 Idx: m_ConstantInt(V&: InsIdx))))
3554 return false;
3555
3556 auto *DstVecTy = dyn_cast<FixedVectorType>(Val: I.getType());
3557 auto *SrcVecTy = dyn_cast<FixedVectorType>(Val: SrcVec->getType());
  // The vectors may have different numbers of elements, but their element
  // types must match.
3559 if (!DstVecTy || !SrcVecTy ||
3560 SrcVecTy->getElementType() != DstVecTy->getElementType())
3561 return false;
3562
3563 unsigned NumDstElts = DstVecTy->getNumElements();
3564 unsigned NumSrcElts = SrcVecTy->getNumElements();
3565 if (InsIdx >= NumDstElts || ExtIdx >= NumSrcElts || NumDstElts == 1)
3566 return false;
3567
3568 // Insertion into poison is a cheaper single operand shuffle.
3569 TargetTransformInfo::ShuffleKind SK;
3570 SmallVector<int> Mask(NumDstElts, PoisonMaskElem);
3571
3572 bool NeedExpOrNarrow = NumSrcElts != NumDstElts;
3573 bool IsExtIdxInBounds = ExtIdx < NumDstElts;
3574 bool NeedDstSrcSwap = isa<PoisonValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec);
3575 if (NeedDstSrcSwap) {
3576 SK = TargetTransformInfo::SK_PermuteSingleSrc;
3577 if (!IsExtIdxInBounds && NeedExpOrNarrow)
3578 Mask[InsIdx] = 0;
3579 else
3580 Mask[InsIdx] = ExtIdx;
3581 std::swap(a&: DstVec, b&: SrcVec);
3582 } else {
3583 SK = TargetTransformInfo::SK_PermuteTwoSrc;
3584 std::iota(first: Mask.begin(), last: Mask.end(), value: 0);
3585 if (!IsExtIdxInBounds && NeedExpOrNarrow)
3586 Mask[InsIdx] = NumDstElts;
3587 else
3588 Mask[InsIdx] = ExtIdx + NumDstElts;
3589 }
3590
3591 // Cost
3592 auto *Ins = cast<InsertElementInst>(Val: &I);
3593 auto *Ext = cast<ExtractElementInst>(Val: I.getOperand(i: 1));
3594 InstructionCost InsCost =
3595 TTI.getVectorInstrCost(I: *Ins, Val: DstVecTy, CostKind, Index: InsIdx);
3596 InstructionCost ExtCost =
3597 TTI.getVectorInstrCost(I: *Ext, Val: DstVecTy, CostKind, Index: ExtIdx);
3598 InstructionCost OldCost = ExtCost + InsCost;
3599
3600 InstructionCost NewCost = 0;
3601 SmallVector<int> ExtToVecMask;
3602 if (!NeedExpOrNarrow) {
3603 // Ignore 'free' identity insertion shuffle.
3604 // TODO: getShuffleCost should return TCC_Free for Identity shuffles.
3605 if (!ShuffleVectorInst::isIdentityMask(Mask, NumSrcElts))
3606 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind, Index: 0,
3607 SubTp: nullptr, Args: {DstVec, SrcVec});
3608 } else {
    // When creating a length-changing vector, build the mask so that it
    // carries ExtIdx (at lane ExtIdx if it is in bounds, otherwise at lane 0),
    // ensuring the element to be extracted is present in the new vector.
3612 ExtToVecMask.assign(NumElts: NumDstElts, Elt: PoisonMaskElem);
3613 if (IsExtIdxInBounds)
3614 ExtToVecMask[ExtIdx] = ExtIdx;
3615 else
3616 ExtToVecMask[0] = ExtIdx;
3617 // Add cost for expanding or narrowing
3618 NewCost = TTI.getShuffleCost(Kind: TargetTransformInfo::SK_PermuteSingleSrc,
3619 DstTy: DstVecTy, SrcTy: SrcVecTy, Mask: ExtToVecMask, CostKind);
3620 NewCost += TTI.getShuffleCost(Kind: SK, DstTy: DstVecTy, SrcTy: DstVecTy, Mask, CostKind);
3621 }
3622
3623 if (!Ext->hasOneUse())
3624 NewCost += ExtCost;
3625
  LLVM_DEBUG(dbgs() << "Found an insert/extract shuffle-like pair: " << I
3627 << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
3628 << "\n");
3629
3630 if (OldCost < NewCost)
3631 return false;
3632
3633 if (NeedExpOrNarrow) {
3634 if (!NeedDstSrcSwap)
3635 SrcVec = Builder.CreateShuffleVector(V: SrcVec, Mask: ExtToVecMask);
3636 else
3637 DstVec = Builder.CreateShuffleVector(V: DstVec, Mask: ExtToVecMask);
3638 }
3639
3640 // Canonicalize undef param to RHS to help further folds.
3641 if (isa<UndefValue>(Val: DstVec) && !isa<UndefValue>(Val: SrcVec)) {
3642 ShuffleVectorInst::commuteShuffleMask(Mask, InVecNumElts: NumDstElts);
3643 std::swap(a&: DstVec, b&: SrcVec);
3644 }
3645
3646 Value *Shuf = Builder.CreateShuffleVector(V1: DstVec, V2: SrcVec, Mask);
3647 replaceValue(Old&: I, New&: *Shuf);
3648
3649 return true;
3650}
3651
3652/// If we're interleaving 2 constant splats, for instance `<vscale x 8 x i32>
3653/// <splat of 666>` and `<vscale x 8 x i32> <splat of 777>`, we can create a
3654/// larger splat `<vscale x 8 x i64> <splat of ((777 << 32) | 666)>` first
3655/// before casting it back into `<vscale x 16 x i32>`.
3656bool VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
3657 const APInt *SplatVal0, *SplatVal1;
3658 if (!match(V: &I, P: m_Intrinsic<Intrinsic::vector_interleave2>(
3659 Op0: m_APInt(Res&: SplatVal0), Op1: m_APInt(Res&: SplatVal1))))
3660 return false;
3661
3662 LLVM_DEBUG(dbgs() << "VC: Folding interleave2 with two splats: " << I
3663 << "\n");
3664
3665 auto *VTy =
3666 cast<VectorType>(Val: cast<IntrinsicInst>(Val&: I).getArgOperand(i: 0)->getType());
3667 auto *ExtVTy = VectorType::getExtendedElementVectorType(VTy);
3668 unsigned Width = VTy->getElementType()->getIntegerBitWidth();
3669
  // Just in case the costs of the interleave2 intrinsic and the bitcast are
  // both invalid, in which case we want to bail out, we use <= rather than <
  // here. Even if they both have valid and equal costs, it's probably not a
  // good idea to emit a high-cost constant splat.
3674 if (TTI.getInstructionCost(U: &I, CostKind) <=
3675 TTI.getCastInstrCost(Opcode: Instruction::BitCast, Dst: I.getType(), Src: ExtVTy,
3676 CCH: TTI::CastContextHint::None, CostKind)) {
3677 LLVM_DEBUG(dbgs() << "VC: The cost to cast from " << *ExtVTy << " to "
3678 << *I.getType() << " is too high.\n");
3679 return false;
3680 }
3681
3682 APInt NewSplatVal = SplatVal1->zext(width: Width * 2);
3683 NewSplatVal <<= Width;
3684 NewSplatVal |= SplatVal0->zext(width: Width * 2);
3685 auto *NewSplat = ConstantVector::getSplat(
3686 EC: ExtVTy->getElementCount(), Elt: ConstantInt::get(Context&: F.getContext(), V: NewSplatVal));
3687
3688 IRBuilder<> Builder(&I);
3689 replaceValue(Old&: I, New&: *Builder.CreateBitCast(V: NewSplat, DestTy: I.getType()));
3690 return true;
3691}
3692
3693/// This is the entry point for all transforms. Pass manager differences are
3694/// handled in the callers of this function.
3695bool VectorCombine::run() {
3696 if (DisableVectorCombine)
3697 return false;
3698
3699 // Don't attempt vectorization if the target does not support vectors.
3700 if (!TTI.getNumberOfRegisters(ClassID: TTI.getRegisterClassForType(/*Vector*/ true)))
3701 return false;
3702
3703 LLVM_DEBUG(dbgs() << "\n\nVECTORCOMBINE on " << F.getName() << "\n");
3704
3705 bool MadeChange = false;
3706 auto FoldInst = [this, &MadeChange](Instruction &I) {
3707 Builder.SetInsertPoint(&I);
3708 bool IsVectorType = isa<VectorType>(Val: I.getType());
3709 bool IsFixedVectorType = isa<FixedVectorType>(Val: I.getType());
3710 auto Opcode = I.getOpcode();
3711
3712 LLVM_DEBUG(dbgs() << "VC: Visiting: " << I << '\n');
3713
3714 // These folds should be beneficial regardless of when this pass is run
3715 // in the optimization pipeline.
3716 // The type checking is for run-time efficiency. We can avoid wasting time
3717 // dispatching to folding functions if there's no chance of matching.
3718 if (IsFixedVectorType) {
3719 switch (Opcode) {
3720 case Instruction::InsertElement:
3721 MadeChange |= vectorizeLoadInsert(I);
3722 break;
3723 case Instruction::ShuffleVector:
3724 MadeChange |= widenSubvectorLoad(I);
3725 break;
3726 default:
3727 break;
3728 }
3729 }
3730
    // These transforms work with scalable and fixed vectors.
    // TODO: Identify and allow other scalable transforms.
3733 if (IsVectorType) {
3734 MadeChange |= scalarizeOpOrCmp(I);
3735 MadeChange |= scalarizeLoadExtract(I);
3736 MadeChange |= scalarizeExtExtract(I);
3737 MadeChange |= scalarizeVPIntrinsic(I);
3738 MadeChange |= foldInterleaveIntrinsics(I);
3739 }
3740
3741 if (Opcode == Instruction::Store)
3742 MadeChange |= foldSingleElementStore(I);
3743
3744 // If this is an early pipeline invocation of this pass, we are done.
3745 if (TryEarlyFoldsOnly)
3746 return;
3747
3748 // Otherwise, try folds that improve codegen but may interfere with
3749 // early IR canonicalizations.
3750 // The type checking is for run-time efficiency. We can avoid wasting time
3751 // dispatching to folding functions if there's no chance of matching.
3752 if (IsFixedVectorType) {
3753 switch (Opcode) {
3754 case Instruction::InsertElement:
3755 MadeChange |= foldInsExtFNeg(I);
3756 MadeChange |= foldInsExtBinop(I);
3757 MadeChange |= foldInsExtVectorToShuffle(I);
3758 break;
3759 case Instruction::ShuffleVector:
3760 MadeChange |= foldPermuteOfBinops(I);
3761 MadeChange |= foldShuffleOfBinops(I);
3762 MadeChange |= foldShuffleOfSelects(I);
3763 MadeChange |= foldShuffleOfCastops(I);
3764 MadeChange |= foldShuffleOfShuffles(I);
3765 MadeChange |= foldShuffleOfIntrinsics(I);
3766 MadeChange |= foldSelectShuffle(I);
3767 MadeChange |= foldShuffleToIdentity(I);
3768 break;
3769 case Instruction::BitCast:
3770 MadeChange |= foldBitcastShuffle(I);
3771 break;
3772 case Instruction::And:
3773 case Instruction::Or:
3774 case Instruction::Xor:
3775 MadeChange |= foldBitOpOfBitcasts(I);
3776 break;
3777 default:
3778 MadeChange |= shrinkType(I);
3779 break;
3780 }
3781 } else {
3782 switch (Opcode) {
3783 case Instruction::Call:
3784 MadeChange |= foldShuffleFromReductions(I);
3785 MadeChange |= foldCastFromReductions(I);
3786 break;
3787 case Instruction::ICmp:
3788 case Instruction::FCmp:
3789 MadeChange |= foldExtractExtract(I);
3790 break;
3791 case Instruction::Or:
3792 MadeChange |= foldConcatOfBoolMasks(I);
3793 [[fallthrough]];
3794 default:
3795 if (Instruction::isBinaryOp(Opcode)) {
3796 MadeChange |= foldExtractExtract(I);
3797 MadeChange |= foldExtractedCmps(I);
3798 MadeChange |= foldBinopOfReductions(I);
3799 }
3800 break;
3801 }
3802 }
3803 };
3804
3805 for (BasicBlock &BB : F) {
3806 // Ignore unreachable basic blocks.
3807 if (!DT.isReachableFromEntry(A: &BB))
3808 continue;
    // Use an early increment range so that we can erase instructions in the
    // loop.
3810 for (Instruction &I : make_early_inc_range(Range&: BB)) {
3811 if (I.isDebugOrPseudoInst())
3812 continue;
3813 FoldInst(I);
3814 }
3815 }
3816
3817 while (!Worklist.isEmpty()) {
3818 Instruction *I = Worklist.removeOne();
3819 if (!I)
3820 continue;
3821
3822 if (isInstructionTriviallyDead(I)) {
3823 eraseInstruction(I&: *I);
3824 continue;
3825 }
3826
3827 FoldInst(*I);
3828 }
3829
3830 return MadeChange;
3831}
3832
3833PreservedAnalyses VectorCombinePass::run(Function &F,
3834 FunctionAnalysisManager &FAM) {
3835 auto &AC = FAM.getResult<AssumptionAnalysis>(IR&: F);
3836 TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(IR&: F);
3837 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(IR&: F);
3838 AAResults &AA = FAM.getResult<AAManager>(IR&: F);
3839 const DataLayout *DL = &F.getDataLayout();
3840 VectorCombine Combiner(F, TTI, DT, AA, AC, DL, TTI::TCK_RecipThroughput,
3841 TryEarlyFoldsOnly);
3842 if (!Combiner.run())
3843 return PreservedAnalyses::all();
3844 PreservedAnalyses PA;
3845 PA.preserveSet<CFGAnalyses>();
3846 return PA;
3847}
3848