1//===- Scalarizer.cpp - Scalarize vector operations -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass converts vector operations into scalar operations (or, optionally,
10// operations on smaller vector widths), in order to expose optimization
11// opportunities on the individual scalar operations.
12// It is mainly intended for targets that do not have vector units, but it
13// may also be useful for revectorizing code to different vector widths.
14//
15//===----------------------------------------------------------------------===//
16
17#include "llvm/Transforms/Scalar/Scalarizer.h"
18#include "llvm/ADT/PostOrderIterator.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/Twine.h"
21#include "llvm/Analysis/TargetTransformInfo.h"
22#include "llvm/Analysis/VectorUtils.h"
23#include "llvm/IR/Argument.h"
24#include "llvm/IR/BasicBlock.h"
25#include "llvm/IR/Constants.h"
26#include "llvm/IR/DataLayout.h"
27#include "llvm/IR/DerivedTypes.h"
28#include "llvm/IR/Dominators.h"
29#include "llvm/IR/Function.h"
30#include "llvm/IR/IRBuilder.h"
31#include "llvm/IR/InstVisitor.h"
32#include "llvm/IR/InstrTypes.h"
33#include "llvm/IR/Instruction.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/IR/Intrinsics.h"
36#include "llvm/IR/LLVMContext.h"
37#include "llvm/IR/Module.h"
38#include "llvm/IR/Type.h"
39#include "llvm/IR/Value.h"
40#include "llvm/InitializePasses.h"
41#include "llvm/Support/Casting.h"
42#include "llvm/Transforms/Utils/Local.h"
43#include <cassert>
44#include <cstdint>
45#include <iterator>
46#include <map>
47#include <utility>
48
49using namespace llvm;
50
51#define DEBUG_TYPE "scalarizer"
52
53namespace {
54
55BasicBlock::iterator skipPastPhiNodesAndDbg(BasicBlock::iterator Itr) {
56 BasicBlock *BB = Itr->getParent();
57 if (isa<PHINode>(Val: Itr))
58 Itr = BB->getFirstInsertionPt();
59 if (Itr != BB->end())
60 Itr = skipDebugIntrinsics(It: Itr);
61 return Itr;
62}
63
64// Used to store the scattered form of a vector: one Value per fragment.
// Entries may be null until the corresponding fragment has been materialized
// (see Scatterer::operator[]).
65using ValueVector = SmallVector<Value *, 8>;
66
67// Used to map a vector Value and associated type to its scattered form.
68// In this file the associated type is the fragment type (VS.SplitTy) of the
69// split being used, so one Value may be cached under more than one split.
70//
71// We use std::map because we want iterators to persist across insertion and
72// because the values are relatively large.
// gather() records pointers to the mapped ValueVectors in the GatherList, so
// those addresses must stay valid while further entries are inserted.
73using ScatterMap = std::map<std::pair<Value *, Type *>, ValueVector>;
74
75// Lists Instructions that have been replaced with scalar implementations,
76// along with a pointer to their scattered forms.
77using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;
78
79struct VectorSplit {
80 // The type of the vector.
81 FixedVectorType *VecTy = nullptr;
82
83 // The number of elements packed in a fragment (other than the remainder).
84 unsigned NumPacked = 0;
85
86 // The number of fragments (scalars or smaller vectors) into which the vector
87 // shall be split.
88 unsigned NumFragments = 0;
89
90 // The type of each complete fragment.
91 Type *SplitTy = nullptr;
92
93 // The type of the remainder (last) fragment; null if all fragments are
94 // complete.
95 Type *RemainderTy = nullptr;
96
97 Type *getFragmentType(unsigned I) const {
98 return RemainderTy && I == NumFragments - 1 ? RemainderTy : SplitTy;
99 }
100};
101
102// Provides a very limited vector-like interface for lazily accessing one
103// component of a scattered vector or vector pointer.
104class Scatterer {
105public:
106 Scatterer() = default;
107
108 // Scatter V into Size components. If new instructions are needed,
109 // insert them before BBI in BB. If Cache is nonnull, use it to cache
110 // the results.
111 Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
112 const VectorSplit &VS, ValueVector *cachePtr = nullptr);
113
114 // Return component I, creating a new Value for it if necessary.
115 Value *operator[](unsigned I);
116
117 // Return the number of components.
118 unsigned size() const { return VS.NumFragments; }
119
120private:
121 BasicBlock *BB;
122 BasicBlock::iterator BBI;
123 Value *V;
124 VectorSplit VS;
125 bool IsPointer;
126 ValueVector *CachePtr;
127 ValueVector Tmp;
128};
129
130// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
131// called Name that compares X and Y in the same way as FCI.
132struct FCmpSplitter {
133 FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
134
135 Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
136 const Twine &Name) const {
137 return Builder.CreateFCmp(P: FCI.getPredicate(), LHS: Op0, RHS: Op1, Name);
138 }
139
140 FCmpInst &FCI;
141};
142
143// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
144// called Name that compares X and Y in the same way as ICI.
145struct ICmpSplitter {
146 ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
147
148 Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
149 const Twine &Name) const {
150 return Builder.CreateICmp(P: ICI.getPredicate(), LHS: Op0, RHS: Op1, Name);
151 }
152
153 ICmpInst &ICI;
154};
155
156// UnarySplitter(UO)(Builder, X, Name) uses Builder to create
157// a unary operator like UO called Name with operand X.
158struct UnarySplitter {
159 UnarySplitter(UnaryOperator &uo) : UO(uo) {}
160
161 Value *operator()(IRBuilder<> &Builder, Value *Op, const Twine &Name) const {
162 return Builder.CreateUnOp(Opc: UO.getOpcode(), V: Op, Name);
163 }
164
165 UnaryOperator &UO;
166};
167
168// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
169// a binary operator like BO called Name with operands X and Y.
170struct BinarySplitter {
171 BinarySplitter(BinaryOperator &bo) : BO(bo) {}
172
173 Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
174 const Twine &Name) const {
175 return Builder.CreateBinOp(Opc: BO.getOpcode(), LHS: Op0, RHS: Op1, Name);
176 }
177
178 BinaryOperator &BO;
179};
180
181// Information about a load or store that we're scalarizing.
182struct VectorLayout {
183 VectorLayout() = default;
184
185 // Return the alignment of fragment Frag.
186 Align getFragmentAlign(unsigned Frag) {
187 return commonAlignment(A: VecAlign, Offset: Frag * SplitSize);
188 }
189
190 // The split of the underlying vector type.
191 VectorSplit VS;
192
193 // The alignment of the vector.
194 Align VecAlign;
195
196 // The size of each (non-remainder) fragment in bytes.
197 uint64_t SplitSize = 0;
198};
199
200static bool isStructOfMatchingFixedVectors(Type *Ty) {
201 if (!isa<StructType>(Val: Ty))
202 return false;
203 unsigned StructSize = Ty->getNumContainedTypes();
204 if (StructSize < 1)
205 return false;
206 FixedVectorType *VecTy = dyn_cast<FixedVectorType>(Val: Ty->getContainedType(i: 0));
207 if (!VecTy)
208 return false;
209 unsigned VecSize = VecTy->getNumElements();
210 for (unsigned I = 1; I < StructSize; I++) {
211 VecTy = dyn_cast<FixedVectorType>(Val: Ty->getContainedType(i: I));
212 if (!VecTy || VecSize != VecTy->getNumElements())
213 return false;
214 }
215 return true;
216}
217
218/// Concatenate the given fragments to a single vector value of the type
219/// described in @p VS.
220static Value *concatenate(IRBuilder<> &Builder, ArrayRef<Value *> Fragments,
221 const VectorSplit &VS, Twine Name) {
222 unsigned NumElements = VS.VecTy->getNumElements();
223 SmallVector<int> ExtendMask;
224 SmallVector<int> InsertMask;
225
226 if (VS.NumPacked > 1) {
227 // Prepare the shufflevector masks once and re-use them for all
228 // fragments.
229 ExtendMask.resize(N: NumElements, NV: -1);
230 for (unsigned I = 0; I < VS.NumPacked; ++I)
231 ExtendMask[I] = I;
232
233 InsertMask.resize(N: NumElements);
234 for (unsigned I = 0; I < NumElements; ++I)
235 InsertMask[I] = I;
236 }
237
238 Value *Res = PoisonValue::get(T: VS.VecTy);
239 for (unsigned I = 0; I < VS.NumFragments; ++I) {
240 Value *Fragment = Fragments[I];
241
242 unsigned NumPacked = VS.NumPacked;
243 if (I == VS.NumFragments - 1 && VS.RemainderTy) {
244 if (auto *RemVecTy = dyn_cast<FixedVectorType>(Val: VS.RemainderTy))
245 NumPacked = RemVecTy->getNumElements();
246 else
247 NumPacked = 1;
248 }
249
250 if (NumPacked == 1) {
251 Res = Builder.CreateInsertElement(Vec: Res, NewElt: Fragment, Idx: I * VS.NumPacked,
252 Name: Name + ".upto" + Twine(I));
253 } else {
254 Fragment = Builder.CreateShuffleVector(V1: Fragment, V2: Fragment, Mask: ExtendMask);
255 if (I == 0) {
256 Res = Fragment;
257 } else {
258 for (unsigned J = 0; J < NumPacked; ++J)
259 InsertMask[I * VS.NumPacked + J] = NumElements + J;
260 Res = Builder.CreateShuffleVector(V1: Res, V2: Fragment, Mask: InsertMask,
261 Name: Name + ".upto" + Twine(I));
262 for (unsigned J = 0; J < NumPacked; ++J)
263 InsertMask[I * VS.NumPacked + J] = I * VS.NumPacked + J;
264 }
265 }
266 }
267
268 return Res;
269}
270
271class ScalarizerVisitor : public InstVisitor<ScalarizerVisitor, bool> {
272public:
273 ScalarizerVisitor(DominatorTree *DT, const TargetTransformInfo *TTI,
274 ScalarizerPassOptions Options)
275 : DT(DT), TTI(TTI),
276 ScalarizeVariableInsertExtract(Options.ScalarizeVariableInsertExtract),
277 ScalarizeLoadStore(Options.ScalarizeLoadStore),
278 ScalarizeMinBits(Options.ScalarizeMinBits) {}
279
280 bool visit(Function &F);
281
282 // InstVisitor methods. They return true if the instruction was scalarized,
283 // false if nothing changed.
284 bool visitInstruction(Instruction &I) { return false; }
285 bool visitSelectInst(SelectInst &SI);
286 bool visitICmpInst(ICmpInst &ICI);
287 bool visitFCmpInst(FCmpInst &FCI);
288 bool visitUnaryOperator(UnaryOperator &UO);
289 bool visitBinaryOperator(BinaryOperator &BO);
290 bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
291 bool visitCastInst(CastInst &CI);
292 bool visitBitCastInst(BitCastInst &BCI);
293 bool visitInsertElementInst(InsertElementInst &IEI);
294 bool visitExtractElementInst(ExtractElementInst &EEI);
295 bool visitExtractValueInst(ExtractValueInst &EVI);
296 bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
297 bool visitPHINode(PHINode &PHI);
298 bool visitLoadInst(LoadInst &LI);
299 bool visitStoreInst(StoreInst &SI);
300 bool visitCallInst(CallInst &ICI);
301 bool visitFreezeInst(FreezeInst &FI);
302
303private:
304 Scatterer scatter(Instruction *Point, Value *V, const VectorSplit &VS);
305 void gather(Instruction *Op, const ValueVector &CV, const VectorSplit &VS);
306 void replaceUses(Instruction *Op, Value *CV);
307 bool canTransferMetadata(unsigned Kind);
308 void transferMetadataAndIRFlags(Instruction *Op, const ValueVector &CV);
309 std::optional<VectorSplit> getVectorSplit(Type *Ty);
310 std::optional<VectorLayout> getVectorLayout(Type *Ty, Align Alignment,
311 const DataLayout &DL);
312 bool finish();
313
314 template<typename T> bool splitUnary(Instruction &, const T &);
315 template<typename T> bool splitBinary(Instruction &, const T &);
316
317 bool splitCall(CallInst &CI);
318
319 ScatterMap Scattered;
320 GatherList Gathered;
321 bool Scalarized;
322
323 SmallVector<WeakTrackingVH, 32> PotentiallyDeadInstrs;
324
325 DominatorTree *DT;
326 const TargetTransformInfo *TTI;
327
328 const bool ScalarizeVariableInsertExtract;
329 const bool ScalarizeLoadStore;
330 const unsigned ScalarizeMinBits;
331};
332
// Legacy pass-manager wrapper around ScalarizerVisitor: runOnFunction()
// fetches the dominator tree and target transform info and hands them,
// together with Options, to a ScalarizerVisitor.
333class ScalarizerLegacyPass : public FunctionPass {
334public:
 // Pass identification token, handed to the FunctionPass base constructor.
335 static char ID;
 // Options forwarded to the ScalarizerVisitor created in runOnFunction().
336 ScalarizerPassOptions Options;
 // Construct the pass with value-initialized options.
337 ScalarizerLegacyPass() : FunctionPass(ID), Options() {}
 // Construct the pass with the given options.
338 ScalarizerLegacyPass(const ScalarizerPassOptions &Options);
339 bool runOnFunction(Function &F) override;
340 void getAnalysisUsage(AnalysisUsage &AU) const override;
341};
342
343} // end anonymous namespace
344
345ScalarizerLegacyPass::ScalarizerLegacyPass(const ScalarizerPassOptions &Options)
346 : FunctionPass(ID), Options(Options) {}
347
348void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
349 AU.addRequired<DominatorTreeWrapperPass>();
350 AU.addRequired<TargetTransformInfoWrapperPass>();
351 AU.addPreserved<DominatorTreeWrapperPass>();
352}
353
354char ScalarizerLegacyPass::ID = 0;
355INITIALIZE_PASS_BEGIN(ScalarizerLegacyPass, "scalarizer",
356 "Scalarize vector operations", false, false)
357INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
358INITIALIZE_PASS_END(ScalarizerLegacyPass, "scalarizer",
359 "Scalarize vector operations", false, false)
360
361Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
362 const VectorSplit &VS, ValueVector *cachePtr)
363 : BB(bb), BBI(bbi), V(v), VS(VS), CachePtr(cachePtr) {
364 IsPointer = V->getType()->isPointerTy();
365 if (!CachePtr) {
366 Tmp.resize(N: VS.NumFragments, NV: nullptr);
367 } else {
368 assert((CachePtr->empty() || VS.NumFragments == CachePtr->size() ||
369 IsPointer) &&
370 "Inconsistent vector sizes");
371 if (VS.NumFragments > CachePtr->size())
372 CachePtr->resize(N: VS.NumFragments, NV: nullptr);
373 }
374}
375
376// Return fragment Frag, creating a new Value for it if necessary.
377Value *Scatterer::operator[](unsigned Frag) {
378 ValueVector &CV = CachePtr ? *CachePtr : Tmp;
379 // Try to reuse a previous value.
380 if (CV[Frag])
381 return CV[Frag];
382 IRBuilder<> Builder(BB, BBI);
383 if (IsPointer) {
384 if (Frag == 0)
385 CV[Frag] = V;
386 else
387 CV[Frag] = Builder.CreateConstGEP1_32(Ty: VS.SplitTy, Ptr: V, Idx0: Frag,
388 Name: V->getName() + ".i" + Twine(Frag));
389 return CV[Frag];
390 }
391
392 Type *FragmentTy = VS.getFragmentType(I: Frag);
393
394 if (auto *VecTy = dyn_cast<FixedVectorType>(Val: FragmentTy)) {
395 SmallVector<int> Mask;
396 for (unsigned J = 0; J < VecTy->getNumElements(); ++J)
397 Mask.push_back(Elt: Frag * VS.NumPacked + J);
398 CV[Frag] =
399 Builder.CreateShuffleVector(V1: V, V2: PoisonValue::get(T: V->getType()), Mask,
400 Name: V->getName() + ".i" + Twine(Frag));
401 } else {
402 // Search through a chain of InsertElementInsts looking for element Frag.
403 // Record other elements in the cache. The new V is still suitable
404 // for all uncached indices.
405 while (true) {
406 InsertElementInst *Insert = dyn_cast<InsertElementInst>(Val: V);
407 if (!Insert)
408 break;
409 ConstantInt *Idx = dyn_cast<ConstantInt>(Val: Insert->getOperand(i_nocapture: 2));
410 if (!Idx)
411 break;
412 unsigned J = Idx->getZExtValue();
413 V = Insert->getOperand(i_nocapture: 0);
414 if (Frag * VS.NumPacked == J) {
415 CV[Frag] = Insert->getOperand(i_nocapture: 1);
416 return CV[Frag];
417 }
418
419 if (VS.NumPacked == 1 && !CV[J]) {
420 // Only cache the first entry we find for each index we're not actively
421 // searching for. This prevents us from going too far up the chain and
422 // caching incorrect entries.
423 CV[J] = Insert->getOperand(i_nocapture: 1);
424 }
425 }
426 CV[Frag] = Builder.CreateExtractElement(Vec: V, Idx: Frag * VS.NumPacked,
427 Name: V->getName() + ".i" + Twine(Frag));
428 }
429
430 return CV[Frag];
431}
432
433bool ScalarizerLegacyPass::runOnFunction(Function &F) {
434 if (skipFunction(F))
435 return false;
436
437 DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
438 const TargetTransformInfo *TTI =
439 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
440 ScalarizerVisitor Impl(DT, TTI, Options);
441 return Impl.visit(F);
442}
443
444FunctionPass *llvm::createScalarizerPass(const ScalarizerPassOptions &Options) {
445 return new ScalarizerLegacyPass(Options);
446}
447
448bool ScalarizerVisitor::visit(Function &F) {
449 assert(Gathered.empty() && Scattered.empty());
450
451 Scalarized = false;
452
453 // To ensure we replace gathered components correctly we need to do an ordered
454 // traversal of the basic blocks in the function.
455 ReversePostOrderTraversal<BasicBlock *> RPOT(&F.getEntryBlock());
456 for (BasicBlock *BB : RPOT) {
457 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
458 Instruction *I = &*II;
459 bool Done = InstVisitor::visit(I);
460 ++II;
461 if (Done && I->getType()->isVoidTy())
462 I->eraseFromParent();
463 }
464 }
465 return finish();
466}
467
468// Return a scattered form of V that can be accessed by Point. V must be a
469// vector or a pointer to a vector.
470Scatterer ScalarizerVisitor::scatter(Instruction *Point, Value *V,
471 const VectorSplit &VS) {
472 if (Argument *VArg = dyn_cast<Argument>(Val: V)) {
473 // Put the scattered form of arguments in the entry block,
474 // so that it can be used everywhere.
475 Function *F = VArg->getParent();
476 BasicBlock *BB = &F->getEntryBlock();
477 return Scatterer(BB, BB->begin(), V, VS, &Scattered[{V, VS.SplitTy}]);
478 }
479 if (Instruction *VOp = dyn_cast<Instruction>(Val: V)) {
480 // When scalarizing PHI nodes we might try to examine/rewrite InsertElement
481 // nodes in predecessors. If those predecessors are unreachable from entry,
482 // then the IR in those blocks could have unexpected properties resulting in
483 // infinite loops in Scatterer::operator[]. By simply treating values
484 // originating from instructions in unreachable blocks as undef we do not
485 // need to analyse them further.
486 if (!DT->isReachableFromEntry(A: VOp->getParent()))
487 return Scatterer(Point->getParent(), Point->getIterator(),
488 PoisonValue::get(T: V->getType()), VS);
489 // Put the scattered form of an instruction directly after the
490 // instruction, skipping over PHI nodes and debug intrinsics.
491 BasicBlock *BB = VOp->getParent();
492 return Scatterer(
493 BB, skipPastPhiNodesAndDbg(Itr: std::next(x: BasicBlock::iterator(VOp))), V, VS,
494 &Scattered[{V, VS.SplitTy}]);
495 }
496 // In the fallback case, just put the scattered before Point and
497 // keep the result local to Point.
498 return Scatterer(Point->getParent(), Point->getIterator(), V, VS);
499}
500
501// Replace Op with the gathered form of the components in CV. Defer the
502// deletion of Op and creation of the gathered form to the end of the pass,
503// so that we can avoid creating the gathered form if all uses of Op are
504// replaced with uses of CV.
505void ScalarizerVisitor::gather(Instruction *Op, const ValueVector &CV,
506 const VectorSplit &VS) {
507 transferMetadataAndIRFlags(Op, CV);
508
509 // If we already have a scattered form of Op (created from ExtractElements
510 // of Op itself), replace them with the new form.
511 ValueVector &SV = Scattered[{Op, VS.SplitTy}];
512 if (!SV.empty()) {
513 for (unsigned I = 0, E = SV.size(); I != E; ++I) {
514 Value *V = SV[I];
515 if (V == nullptr || SV[I] == CV[I])
516 continue;
517
518 Instruction *Old = cast<Instruction>(Val: V);
519 if (isa<Instruction>(Val: CV[I]))
520 CV[I]->takeName(V: Old);
521 Old->replaceAllUsesWith(V: CV[I]);
522 PotentiallyDeadInstrs.emplace_back(Args&: Old);
523 }
524 }
525 SV = CV;
526 Gathered.push_back(Elt: GatherList::value_type(Op, &SV));
527}
528
529// Replace Op with CV and collect Op has a potentially dead instruction.
530void ScalarizerVisitor::replaceUses(Instruction *Op, Value *CV) {
531 if (CV != Op) {
532 Op->replaceAllUsesWith(V: CV);
533 PotentiallyDeadInstrs.emplace_back(Args&: Op);
534 Scalarized = true;
535 }
536}
537
538// Return true if it is safe to transfer the given metadata tag from
539// vector to scalar instructions.
540bool ScalarizerVisitor::canTransferMetadata(unsigned Tag) {
541 return (Tag == LLVMContext::MD_tbaa
542 || Tag == LLVMContext::MD_fpmath
543 || Tag == LLVMContext::MD_tbaa_struct
544 || Tag == LLVMContext::MD_invariant_load
545 || Tag == LLVMContext::MD_alias_scope
546 || Tag == LLVMContext::MD_noalias
547 || Tag == LLVMContext::MD_mem_parallel_loop_access
548 || Tag == LLVMContext::MD_access_group);
549}
550
551// Transfer metadata from Op to the instructions in CV if it is known
552// to be safe to do so.
553void ScalarizerVisitor::transferMetadataAndIRFlags(Instruction *Op,
554 const ValueVector &CV) {
555 SmallVector<std::pair<unsigned, MDNode *>, 4> MDs;
556 Op->getAllMetadataOtherThanDebugLoc(MDs);
557 for (Value *V : CV) {
558 if (Instruction *New = dyn_cast<Instruction>(Val: V)) {
559 for (const auto &MD : MDs)
560 if (canTransferMetadata(Tag: MD.first))
561 New->setMetadata(KindID: MD.first, Node: MD.second);
562 New->copyIRFlags(V: Op);
563 if (Op->getDebugLoc() && !New->getDebugLoc())
564 New->setDebugLoc(Op->getDebugLoc());
565 }
566 }
567}
568
569// Determine how Ty is split, if at all.
570std::optional<VectorSplit> ScalarizerVisitor::getVectorSplit(Type *Ty) {
571 VectorSplit Split;
572 Split.VecTy = dyn_cast<FixedVectorType>(Val: Ty);
573 if (!Split.VecTy)
574 return {};
575
576 unsigned NumElems = Split.VecTy->getNumElements();
577 Type *ElemTy = Split.VecTy->getElementType();
578
579 if (NumElems == 1 || ElemTy->isPointerTy() ||
580 2 * ElemTy->getScalarSizeInBits() > ScalarizeMinBits) {
581 Split.NumPacked = 1;
582 Split.NumFragments = NumElems;
583 Split.SplitTy = ElemTy;
584 } else {
585 Split.NumPacked = ScalarizeMinBits / ElemTy->getScalarSizeInBits();
586 if (Split.NumPacked >= NumElems)
587 return {};
588
589 Split.NumFragments = divideCeil(Numerator: NumElems, Denominator: Split.NumPacked);
590 Split.SplitTy = FixedVectorType::get(ElementType: ElemTy, NumElts: Split.NumPacked);
591
592 unsigned RemainderElems = NumElems % Split.NumPacked;
593 if (RemainderElems > 1)
594 Split.RemainderTy = FixedVectorType::get(ElementType: ElemTy, NumElts: RemainderElems);
595 else if (RemainderElems == 1)
596 Split.RemainderTy = ElemTy;
597 }
598
599 return Split;
600}
601
602// Try to fill in Layout from Ty, returning true on success. Alignment is
603// the alignment of the vector, or std::nullopt if the ABI default should be
604// used.
605std::optional<VectorLayout>
606ScalarizerVisitor::getVectorLayout(Type *Ty, Align Alignment,
607 const DataLayout &DL) {
608 std::optional<VectorSplit> VS = getVectorSplit(Ty);
609 if (!VS)
610 return {};
611
612 VectorLayout Layout;
613 Layout.VS = *VS;
614 // Check that we're dealing with full-byte fragments.
615 if (!DL.typeSizeEqualsStoreSize(Ty: VS->SplitTy) ||
616 (VS->RemainderTy && !DL.typeSizeEqualsStoreSize(Ty: VS->RemainderTy)))
617 return {};
618 Layout.VecAlign = Alignment;
619 Layout.SplitSize = DL.getTypeStoreSize(Ty: VS->SplitTy);
620 return Layout;
621}
622
623// Scalarize one-operand instruction I, using Split(Builder, X, Name)
624// to create an instruction like I with operand X and name Name.
625template<typename Splitter>
626bool ScalarizerVisitor::splitUnary(Instruction &I, const Splitter &Split) {
627 std::optional<VectorSplit> VS = getVectorSplit(Ty: I.getType());
628 if (!VS)
629 return false;
630
631 std::optional<VectorSplit> OpVS;
632 if (I.getOperand(i: 0)->getType() == I.getType()) {
633 OpVS = VS;
634 } else {
635 OpVS = getVectorSplit(Ty: I.getOperand(i: 0)->getType());
636 if (!OpVS || VS->NumPacked != OpVS->NumPacked)
637 return false;
638 }
639
640 IRBuilder<> Builder(&I);
641 Scatterer Op = scatter(Point: &I, V: I.getOperand(i: 0), VS: *OpVS);
642 assert(Op.size() == VS->NumFragments && "Mismatched unary operation");
643 ValueVector Res;
644 Res.resize(N: VS->NumFragments);
645 for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag)
646 Res[Frag] = Split(Builder, Op[Frag], I.getName() + ".i" + Twine(Frag));
647 gather(Op: &I, CV: Res, VS: *VS);
648 return true;
649}
650
651// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name)
652// to create an instruction like I with operands X and Y and name Name.
653template<typename Splitter>
654bool ScalarizerVisitor::splitBinary(Instruction &I, const Splitter &Split) {
655 std::optional<VectorSplit> VS = getVectorSplit(Ty: I.getType());
656 if (!VS)
657 return false;
658
659 std::optional<VectorSplit> OpVS;
660 if (I.getOperand(i: 0)->getType() == I.getType()) {
661 OpVS = VS;
662 } else {
663 OpVS = getVectorSplit(Ty: I.getOperand(i: 0)->getType());
664 if (!OpVS || VS->NumPacked != OpVS->NumPacked)
665 return false;
666 }
667
668 IRBuilder<> Builder(&I);
669 Scatterer VOp0 = scatter(Point: &I, V: I.getOperand(i: 0), VS: *OpVS);
670 Scatterer VOp1 = scatter(Point: &I, V: I.getOperand(i: 1), VS: *OpVS);
671 assert(VOp0.size() == VS->NumFragments && "Mismatched binary operation");
672 assert(VOp1.size() == VS->NumFragments && "Mismatched binary operation");
673 ValueVector Res;
674 Res.resize(N: VS->NumFragments);
675 for (unsigned Frag = 0; Frag < VS->NumFragments; ++Frag) {
676 Value *Op0 = VOp0[Frag];
677 Value *Op1 = VOp1[Frag];
678 Res[Frag] = Split(Builder, Op0, Op1, I.getName() + ".i" + Twine(Frag));
679 }
680 gather(Op: &I, CV: Res, VS: *VS);
681 return true;
682}
683
684/// If a call to a vector typed intrinsic function, split into a scalar call per
685/// element if possible for the intrinsic.
///
/// Returns true if CI was replaced by one call per fragment (recorded through
/// gather()), and false if CI was left untouched: its type is not split, the
/// callee is not a directly called, trivially scalarizable intrinsic, or an
/// operand / struct field would be split with a different packing than the
/// result.
686bool ScalarizerVisitor::splitCall(CallInst &CI) {
687 Type *CallType = CI.getType();
 // For a struct of equally sized fixed vectors the split is derived from the
 // first field; the remaining fields are checked against it further down.
688 bool AreAllVectorsOfMatchingSize = isStructOfMatchingFixedVectors(Ty: CallType);
689 std::optional<VectorSplit> VS;
690 if (AreAllVectorsOfMatchingSize)
691 VS = getVectorSplit(Ty: CallType->getContainedType(i: 0));
692 else
693 VS = getVectorSplit(Ty: CallType);
694 if (!VS)
695 return false;
696
 // Only calls with a known callee can be inspected.
697 Function *F = CI.getCalledFunction();
698 if (!F)
699 return false;
700
701 Intrinsic::ID ID = F->getIntrinsicID();
702
703 if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
704 return false;
705
706 // Number of call arguments, vector and scalar alike.
707 unsigned NumArgs = CI.arg_size();
708
 // Per-argument state: the operand itself when it is not a fixed vector, its
 // Scatterer when it is, and the position of its type in Tys (or -1).
 // Note that the local Scattered shadows the member of the same name.
709 ValueVector ScalarOperands(NumArgs);
710 SmallVector<Scatterer, 8> Scattered(NumArgs);
711 SmallVector<int> OverloadIdx(NumArgs, -1);
712
 // Overload types used to look up the per-fragment intrinsic declaration.
713 SmallVector<llvm::Type *, 3> Tys;
714 // Add return type if intrinsic is overloaded on it.
715 if (isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx: -1, TTI))
716 Tys.push_back(Elt: VS->SplitTy);
717
718 if (AreAllVectorsOfMatchingSize) {
719 for (unsigned I = 1; I < CallType->getNumContainedTypes(); I++) {
720 std::optional<VectorSplit> CurrVS =
721 getVectorSplit(Ty: cast<FixedVectorType>(Val: CallType->getContainedType(i: I)));
722 // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
723 // VectorSplit is not returned and we will bailout of handling this call.
724 // The secondary bailout case is if NumPacked does not match. This can
725 // happen if ScalarizeMinBits is not set to the default. This means with
726 // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
727 // the struct elements have the same bitness.
728 if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
729 return false;
730 if (isVectorIntrinsicWithStructReturnOverloadAtField(ID, RetIdx: I, TTI))
731 Tys.push_back(Elt: CurrVS->SplitTy);
732 }
733 }
734 // Assumes that any vector type has the same number of elements as the return
735 // vector type, which is true for all current intrinsics.
736 for (unsigned I = 0; I != NumArgs; ++I) {
737 Value *OpI = CI.getOperand(i_nocapture: I);
738 if ([[maybe_unused]] auto *OpVecTy =
739 dyn_cast<FixedVectorType>(Val: OpI->getType())) {
740 assert(OpVecTy->getNumElements() == VS->VecTy->getNumElements());
741 std::optional<VectorSplit> OpVS = getVectorSplit(Ty: OpI->getType());
742 if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
743 // The natural split of the operand doesn't match the result. This could
744 // happen if the vector elements are different and the ScalarizeMinBits
745 // option is used.
746 //
747 // We could in principle handle this case as well, at the cost of
748 // complicating the scattering machinery to support multiple scattering
749 // granularities for a single value.
750 return false;
751 }
752
753 Scattered[I] = scatter(Point: &CI, V: OpI, VS: *OpVS);
754 if (isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx: I, TTI)) {
 // Remember where this operand's type sits in Tys so that it can be
 // patched for the remainder fragment.
755 OverloadIdx[I] = Tys.size();
756 Tys.push_back(Elt: OpVS->SplitTy);
757 }
758 } else {
759 ScalarOperands[I] = OpI;
760 if (isVectorIntrinsicWithOverloadTypeAtArg(ID, OpdIdx: I, TTI))
761 Tys.push_back(Elt: OpI->getType());
762 }
763 }
764
765 ValueVector Res(VS->NumFragments);
766 ValueVector ScalarCallOps(NumArgs);
767
 // Declaration used for all complete fragments.
768 Function *NewIntrin =
769 Intrinsic::getOrInsertDeclaration(M: F->getParent(), id: ID, Tys);
770 IRBuilder<> Builder(&CI);
771
772 // Perform actual scalarization, taking care to preserve any scalar operands.
773 for (unsigned I = 0; I < VS->NumFragments; ++I) {
774 bool IsRemainder = I == VS->NumFragments - 1 && VS->RemainderTy;
775 ScalarCallOps.clear();
776
 // NOTE(review): this writes Tys[0] unconditionally, which appears to assume
 // that Tys is non-empty and that slot 0 may hold the result's remainder
 // type. Struct-field entries pushed above are not updated for the remainder.
 // Confirm this holds for every intrinsic that can reach this point.
777 if (IsRemainder)
778 Tys[0] = VS->RemainderTy;
779
780 for (unsigned J = 0; J != NumArgs; ++J) {
781 if (isVectorIntrinsicWithScalarOpAtArg(ID, ScalarOpdIdx: J, TTI)) {
782 ScalarCallOps.push_back(Elt: ScalarOperands[J]);
783 } else {
784 ScalarCallOps.push_back(Elt: Scattered[J][I]);
785 if (IsRemainder && OverloadIdx[J] >= 0)
786 Tys[OverloadIdx[J]] = Scattered[J][I]->getType();
787 }
788 }
789
 // The remainder fragment has different types, so it needs its own
 // declaration.
790 if (IsRemainder)
791 NewIntrin = Intrinsic::getOrInsertDeclaration(M: F->getParent(), id: ID, Tys);
792
793 Res[I] = Builder.CreateCall(Callee: NewIntrin, Args: ScalarCallOps,
794 Name: CI.getName() + ".i" + Twine(I));
795 }
796
797 gather(Op: &CI, CV: Res, VS: *VS);
798 return true;
799}
800
801bool ScalarizerVisitor::visitSelectInst(SelectInst &SI) {
802 std::optional<VectorSplit> VS = getVectorSplit(Ty: SI.getType());
803 if (!VS)
804 return false;
805
806 std::optional<VectorSplit> CondVS;
807 if (isa<FixedVectorType>(Val: SI.getCondition()->getType())) {
808 CondVS = getVectorSplit(Ty: SI.getCondition()->getType());
809 if (!CondVS || CondVS->NumPacked != VS->NumPacked) {
810 // This happens when ScalarizeMinBits is used.
811 return false;
812 }
813 }
814
815 IRBuilder<> Builder(&SI);
816 Scatterer VOp1 = scatter(Point: &SI, V: SI.getOperand(i_nocapture: 1), VS: *VS);
817 Scatterer VOp2 = scatter(Point: &SI, V: SI.getOperand(i_nocapture: 2), VS: *VS);
818 assert(VOp1.size() == VS->NumFragments && "Mismatched select");
819 assert(VOp2.size() == VS->NumFragments && "Mismatched select");
820 ValueVector Res;
821 Res.resize(N: VS->NumFragments);
822
823 if (CondVS) {
824 Scatterer VOp0 = scatter(Point: &SI, V: SI.getOperand(i_nocapture: 0), VS: *CondVS);
825 assert(VOp0.size() == CondVS->NumFragments && "Mismatched select");
826 for (unsigned I = 0; I < VS->NumFragments; ++I) {
827 Value *Op0 = VOp0[I];
828 Value *Op1 = VOp1[I];
829 Value *Op2 = VOp2[I];
830 Res[I] = Builder.CreateSelect(C: Op0, True: Op1, False: Op2,
831 Name: SI.getName() + ".i" + Twine(I));
832 }
833 } else {
834 Value *Op0 = SI.getOperand(i_nocapture: 0);
835 for (unsigned I = 0; I < VS->NumFragments; ++I) {
836 Value *Op1 = VOp1[I];
837 Value *Op2 = VOp2[I];
838 Res[I] = Builder.CreateSelect(C: Op0, True: Op1, False: Op2,
839 Name: SI.getName() + ".i" + Twine(I));
840 }
841 }
842 gather(Op: &SI, CV: Res, VS: *VS);
843 return true;
844}
845
846bool ScalarizerVisitor::visitICmpInst(ICmpInst &ICI) {
847 return splitBinary(I&: ICI, Split: ICmpSplitter(ICI));
848}
849
850bool ScalarizerVisitor::visitFCmpInst(FCmpInst &FCI) {
851 return splitBinary(I&: FCI, Split: FCmpSplitter(FCI));
852}
853
854bool ScalarizerVisitor::visitUnaryOperator(UnaryOperator &UO) {
855 return splitUnary(I&: UO, Split: UnarySplitter(UO));
856}
857
858bool ScalarizerVisitor::visitBinaryOperator(BinaryOperator &BO) {
859 return splitBinary(I&: BO, Split: BinarySplitter(BO));
860}
861
862bool ScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
863 std::optional<VectorSplit> VS = getVectorSplit(Ty: GEPI.getType());
864 if (!VS)
865 return false;
866
867 IRBuilder<> Builder(&GEPI);
868 unsigned NumIndices = GEPI.getNumIndices();
869
870 // The base pointer and indices might be scalar even if it's a vector GEP.
871 SmallVector<Value *, 8> ScalarOps{1 + NumIndices};
872 SmallVector<Scatterer, 8> ScatterOps{1 + NumIndices};
873
874 for (unsigned I = 0; I < 1 + NumIndices; ++I) {
875 if (auto *VecTy =
876 dyn_cast<FixedVectorType>(Val: GEPI.getOperand(i_nocapture: I)->getType())) {
877 std::optional<VectorSplit> OpVS = getVectorSplit(Ty: VecTy);
878 if (!OpVS || OpVS->NumPacked != VS->NumPacked) {
879 // This can happen when ScalarizeMinBits is used.
880 return false;
881 }
882 ScatterOps[I] = scatter(Point: &GEPI, V: GEPI.getOperand(i_nocapture: I), VS: *OpVS);
883 } else {
884 ScalarOps[I] = GEPI.getOperand(i_nocapture: I);
885 }
886 }
887
888 ValueVector Res;
889 Res.resize(N: VS->NumFragments);
890 for (unsigned I = 0; I < VS->NumFragments; ++I) {
891 SmallVector<Value *, 8> SplitOps;
892 SplitOps.resize(N: 1 + NumIndices);
893 for (unsigned J = 0; J < 1 + NumIndices; ++J) {
894 if (ScalarOps[J])
895 SplitOps[J] = ScalarOps[J];
896 else
897 SplitOps[J] = ScatterOps[J][I];
898 }
899 Res[I] = Builder.CreateGEP(Ty: GEPI.getSourceElementType(), Ptr: SplitOps[0],
900 IdxList: ArrayRef(SplitOps).drop_front(),
901 Name: GEPI.getName() + ".i" + Twine(I));
902 if (GEPI.isInBounds())
903 if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Val: Res[I]))
904 NewGEPI->setIsInBounds();
905 }
906 gather(Op: &GEPI, CV: Res, VS: *VS);
907 return true;
908}
909
910bool ScalarizerVisitor::visitCastInst(CastInst &CI) {
911 std::optional<VectorSplit> DestVS = getVectorSplit(Ty: CI.getDestTy());
912 if (!DestVS)
913 return false;
914
915 std::optional<VectorSplit> SrcVS = getVectorSplit(Ty: CI.getSrcTy());
916 if (!SrcVS || SrcVS->NumPacked != DestVS->NumPacked)
917 return false;
918
919 IRBuilder<> Builder(&CI);
920 Scatterer Op0 = scatter(Point: &CI, V: CI.getOperand(i_nocapture: 0), VS: *SrcVS);
921 assert(Op0.size() == SrcVS->NumFragments && "Mismatched cast");
922 ValueVector Res;
923 Res.resize(N: DestVS->NumFragments);
924 for (unsigned I = 0; I < DestVS->NumFragments; ++I)
925 Res[I] =
926 Builder.CreateCast(Op: CI.getOpcode(), V: Op0[I], DestTy: DestVS->getFragmentType(I),
927 Name: CI.getName() + ".i" + Twine(I));
928 gather(Op: &CI, CV: Res, VS: *DestVS);
929 return true;
930}
931
932bool ScalarizerVisitor::visitBitCastInst(BitCastInst &BCI) {
933 std::optional<VectorSplit> DstVS = getVectorSplit(Ty: BCI.getDestTy());
934 std::optional<VectorSplit> SrcVS = getVectorSplit(Ty: BCI.getSrcTy());
935 if (!DstVS || !SrcVS || DstVS->RemainderTy || SrcVS->RemainderTy)
936 return false;
937
938 const bool isPointerTy = DstVS->VecTy->getElementType()->isPointerTy();
939
940 // Vectors of pointers are always fully scalarized.
941 assert(!isPointerTy || (DstVS->NumPacked == 1 && SrcVS->NumPacked == 1));
942
943 IRBuilder<> Builder(&BCI);
944 Scatterer Op0 = scatter(Point: &BCI, V: BCI.getOperand(i_nocapture: 0), VS: *SrcVS);
945 ValueVector Res;
946 Res.resize(N: DstVS->NumFragments);
947
948 unsigned DstSplitBits = DstVS->SplitTy->getPrimitiveSizeInBits();
949 unsigned SrcSplitBits = SrcVS->SplitTy->getPrimitiveSizeInBits();
950
951 if (isPointerTy || DstSplitBits == SrcSplitBits) {
952 assert(DstVS->NumFragments == SrcVS->NumFragments);
953 for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
954 Res[I] = Builder.CreateBitCast(V: Op0[I], DestTy: DstVS->getFragmentType(I),
955 Name: BCI.getName() + ".i" + Twine(I));
956 }
957 } else if (SrcSplitBits % DstSplitBits == 0) {
958 // Convert each source fragment to the same-sized destination vector and
959 // then scatter the result to the destination.
960 VectorSplit MidVS;
961 MidVS.NumPacked = DstVS->NumPacked;
962 MidVS.NumFragments = SrcSplitBits / DstSplitBits;
963 MidVS.VecTy = FixedVectorType::get(ElementType: DstVS->VecTy->getElementType(),
964 NumElts: MidVS.NumPacked * MidVS.NumFragments);
965 MidVS.SplitTy = DstVS->SplitTy;
966
967 unsigned ResI = 0;
968 for (unsigned I = 0; I < SrcVS->NumFragments; ++I) {
969 Value *V = Op0[I];
970
971 // Look through any existing bitcasts before converting to <N x t2>.
972 // In the best case, the resulting conversion might be a no-op.
973 Instruction *VI;
974 while ((VI = dyn_cast<Instruction>(Val: V)) &&
975 VI->getOpcode() == Instruction::BitCast)
976 V = VI->getOperand(i: 0);
977
978 V = Builder.CreateBitCast(V, DestTy: MidVS.VecTy, Name: V->getName() + ".cast");
979
980 Scatterer Mid = scatter(Point: &BCI, V, VS: MidVS);
981 for (unsigned J = 0; J < MidVS.NumFragments; ++J)
982 Res[ResI++] = Mid[J];
983 }
984 } else if (DstSplitBits % SrcSplitBits == 0) {
985 // Gather enough source fragments to make up a destination fragment and
986 // then convert to the destination type.
987 VectorSplit MidVS;
988 MidVS.NumFragments = DstSplitBits / SrcSplitBits;
989 MidVS.NumPacked = SrcVS->NumPacked;
990 MidVS.VecTy = FixedVectorType::get(ElementType: SrcVS->VecTy->getElementType(),
991 NumElts: MidVS.NumPacked * MidVS.NumFragments);
992 MidVS.SplitTy = SrcVS->SplitTy;
993
994 unsigned SrcI = 0;
995 SmallVector<Value *, 8> ConcatOps;
996 ConcatOps.resize(N: MidVS.NumFragments);
997 for (unsigned I = 0; I < DstVS->NumFragments; ++I) {
998 for (unsigned J = 0; J < MidVS.NumFragments; ++J)
999 ConcatOps[J] = Op0[SrcI++];
1000 Value *V = concatenate(Builder, Fragments: ConcatOps, VS: MidVS,
1001 Name: BCI.getName() + ".i" + Twine(I));
1002 Res[I] = Builder.CreateBitCast(V, DestTy: DstVS->getFragmentType(I),
1003 Name: BCI.getName() + ".i" + Twine(I));
1004 }
1005 } else {
1006 return false;
1007 }
1008
1009 gather(Op: &BCI, CV: Res, VS: *DstVS);
1010 return true;
1011}
1012
1013bool ScalarizerVisitor::visitInsertElementInst(InsertElementInst &IEI) {
1014 std::optional<VectorSplit> VS = getVectorSplit(Ty: IEI.getType());
1015 if (!VS)
1016 return false;
1017
1018 IRBuilder<> Builder(&IEI);
1019 Scatterer Op0 = scatter(Point: &IEI, V: IEI.getOperand(i_nocapture: 0), VS: *VS);
1020 Value *NewElt = IEI.getOperand(i_nocapture: 1);
1021 Value *InsIdx = IEI.getOperand(i_nocapture: 2);
1022
1023 ValueVector Res;
1024 Res.resize(N: VS->NumFragments);
1025
1026 if (auto *CI = dyn_cast<ConstantInt>(Val: InsIdx)) {
1027 unsigned Idx = CI->getZExtValue();
1028 unsigned Fragment = Idx / VS->NumPacked;
1029 for (unsigned I = 0; I < VS->NumFragments; ++I) {
1030 if (I == Fragment) {
1031 bool IsPacked = VS->NumPacked > 1;
1032 if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
1033 !VS->RemainderTy->isVectorTy())
1034 IsPacked = false;
1035 if (IsPacked) {
1036 Res[I] =
1037 Builder.CreateInsertElement(Vec: Op0[I], NewElt, Idx: Idx % VS->NumPacked);
1038 } else {
1039 Res[I] = NewElt;
1040 }
1041 } else {
1042 Res[I] = Op0[I];
1043 }
1044 }
1045 } else {
1046 // Never split a variable insertelement that isn't fully scalarized.
1047 if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
1048 return false;
1049
1050 for (unsigned I = 0; I < VS->NumFragments; ++I) {
1051 Value *ShouldReplace =
1052 Builder.CreateICmpEQ(LHS: InsIdx, RHS: ConstantInt::get(Ty: InsIdx->getType(), V: I),
1053 Name: InsIdx->getName() + ".is." + Twine(I));
1054 Value *OldElt = Op0[I];
1055 Res[I] = Builder.CreateSelect(C: ShouldReplace, True: NewElt, False: OldElt,
1056 Name: IEI.getName() + ".i" + Twine(I));
1057 }
1058 }
1059
1060 gather(Op: &IEI, CV: Res, VS: *VS);
1061 return true;
1062}
1063
1064bool ScalarizerVisitor::visitExtractValueInst(ExtractValueInst &EVI) {
1065 Value *Op = EVI.getOperand(i_nocapture: 0);
1066 Type *OpTy = Op->getType();
1067 ValueVector Res;
1068 if (!isStructOfMatchingFixedVectors(Ty: OpTy))
1069 return false;
1070 if (CallInst *CI = dyn_cast<CallInst>(Val: Op)) {
1071 Function *F = CI->getCalledFunction();
1072 if (!F)
1073 return false;
1074 Intrinsic::ID ID = F->getIntrinsicID();
1075 if (ID == Intrinsic::not_intrinsic || !isTriviallyScalarizable(ID, TTI))
1076 return false;
1077 // Note: Fall through means Operand is a`CallInst` and it is defined in
1078 // `isTriviallyScalarizable`.
1079 } else
1080 return false;
1081 Type *VecType = cast<FixedVectorType>(Val: OpTy->getContainedType(i: 0));
1082 std::optional<VectorSplit> VS = getVectorSplit(Ty: VecType);
1083 if (!VS)
1084 return false;
1085 for (unsigned I = 1; I < OpTy->getNumContainedTypes(); I++) {
1086 std::optional<VectorSplit> CurrVS =
1087 getVectorSplit(Ty: cast<FixedVectorType>(Val: OpTy->getContainedType(i: I)));
1088 // It is possible for VectorSplit.NumPacked >= NumElems. If that happens a
1089 // VectorSplit is not returned and we will bailout of handling this call.
1090 // The secondary bailout case is if NumPacked does not match. This can
1091 // happen if ScalarizeMinBits is not set to the default. This means with
1092 // certain ScalarizeMinBits intrinsics like frexp will only scalarize when
1093 // the struct elements have the same bitness.
1094 if (!CurrVS || CurrVS->NumPacked != VS->NumPacked)
1095 return false;
1096 }
1097 IRBuilder<> Builder(&EVI);
1098 Scatterer Op0 = scatter(Point: &EVI, V: Op, VS: *VS);
1099 assert(!EVI.getIndices().empty() && "Make sure an index exists");
1100 // Note for our use case we only care about the top level index.
1101 unsigned Index = EVI.getIndices()[0];
1102 for (unsigned OpIdx = 0; OpIdx < Op0.size(); ++OpIdx) {
1103 Value *ResElem = Builder.CreateExtractValue(
1104 Agg: Op0[OpIdx], Idxs: Index, Name: EVI.getName() + ".elem" + Twine(Index));
1105 Res.push_back(Elt: ResElem);
1106 }
1107
1108 gather(Op: &EVI, CV: Res, VS: *VS);
1109 return true;
1110}
1111
1112bool ScalarizerVisitor::visitExtractElementInst(ExtractElementInst &EEI) {
1113 std::optional<VectorSplit> VS = getVectorSplit(Ty: EEI.getOperand(i_nocapture: 0)->getType());
1114 if (!VS)
1115 return false;
1116
1117 IRBuilder<> Builder(&EEI);
1118 Scatterer Op0 = scatter(Point: &EEI, V: EEI.getOperand(i_nocapture: 0), VS: *VS);
1119 Value *ExtIdx = EEI.getOperand(i_nocapture: 1);
1120
1121 if (auto *CI = dyn_cast<ConstantInt>(Val: ExtIdx)) {
1122 unsigned Idx = CI->getZExtValue();
1123 unsigned Fragment = Idx / VS->NumPacked;
1124 Value *Res = Op0[Fragment];
1125 bool IsPacked = VS->NumPacked > 1;
1126 if (Fragment == VS->NumFragments - 1 && VS->RemainderTy &&
1127 !VS->RemainderTy->isVectorTy())
1128 IsPacked = false;
1129 if (IsPacked)
1130 Res = Builder.CreateExtractElement(Vec: Res, Idx: Idx % VS->NumPacked);
1131 replaceUses(Op: &EEI, CV: Res);
1132 return true;
1133 }
1134
1135 // Never split a variable extractelement that isn't fully scalarized.
1136 if (!ScalarizeVariableInsertExtract || VS->NumPacked > 1)
1137 return false;
1138
1139 Value *Res = PoisonValue::get(T: VS->VecTy->getElementType());
1140 for (unsigned I = 0; I < VS->NumFragments; ++I) {
1141 Value *ShouldExtract =
1142 Builder.CreateICmpEQ(LHS: ExtIdx, RHS: ConstantInt::get(Ty: ExtIdx->getType(), V: I),
1143 Name: ExtIdx->getName() + ".is." + Twine(I));
1144 Value *Elt = Op0[I];
1145 Res = Builder.CreateSelect(C: ShouldExtract, True: Elt, False: Res,
1146 Name: EEI.getName() + ".upto" + Twine(I));
1147 }
1148 replaceUses(Op: &EEI, CV: Res);
1149 return true;
1150}
1151
1152bool ScalarizerVisitor::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
1153 std::optional<VectorSplit> VS = getVectorSplit(Ty: SVI.getType());
1154 std::optional<VectorSplit> VSOp =
1155 getVectorSplit(Ty: SVI.getOperand(i_nocapture: 0)->getType());
1156 if (!VS || !VSOp || VS->NumPacked > 1 || VSOp->NumPacked > 1)
1157 return false;
1158
1159 Scatterer Op0 = scatter(Point: &SVI, V: SVI.getOperand(i_nocapture: 0), VS: *VSOp);
1160 Scatterer Op1 = scatter(Point: &SVI, V: SVI.getOperand(i_nocapture: 1), VS: *VSOp);
1161 ValueVector Res;
1162 Res.resize(N: VS->NumFragments);
1163
1164 for (unsigned I = 0; I < VS->NumFragments; ++I) {
1165 int Selector = SVI.getMaskValue(Elt: I);
1166 if (Selector < 0)
1167 Res[I] = PoisonValue::get(T: VS->VecTy->getElementType());
1168 else if (unsigned(Selector) < Op0.size())
1169 Res[I] = Op0[Selector];
1170 else
1171 Res[I] = Op1[Selector - Op0.size()];
1172 }
1173 gather(Op: &SVI, CV: Res, VS: *VS);
1174 return true;
1175}
1176
1177bool ScalarizerVisitor::visitPHINode(PHINode &PHI) {
1178 std::optional<VectorSplit> VS = getVectorSplit(Ty: PHI.getType());
1179 if (!VS)
1180 return false;
1181
1182 IRBuilder<> Builder(&PHI);
1183 ValueVector Res;
1184 Res.resize(N: VS->NumFragments);
1185
1186 unsigned NumOps = PHI.getNumOperands();
1187 for (unsigned I = 0; I < VS->NumFragments; ++I) {
1188 Res[I] = Builder.CreatePHI(Ty: VS->getFragmentType(I), NumReservedValues: NumOps,
1189 Name: PHI.getName() + ".i" + Twine(I));
1190 }
1191
1192 for (unsigned I = 0; I < NumOps; ++I) {
1193 Scatterer Op = scatter(Point: &PHI, V: PHI.getIncomingValue(i: I), VS: *VS);
1194 BasicBlock *IncomingBlock = PHI.getIncomingBlock(i: I);
1195 for (unsigned J = 0; J < VS->NumFragments; ++J)
1196 cast<PHINode>(Val: Res[J])->addIncoming(V: Op[J], BB: IncomingBlock);
1197 }
1198 gather(Op: &PHI, CV: Res, VS: *VS);
1199 return true;
1200}
1201
1202bool ScalarizerVisitor::visitLoadInst(LoadInst &LI) {
1203 if (!ScalarizeLoadStore)
1204 return false;
1205 if (!LI.isSimple())
1206 return false;
1207
1208 std::optional<VectorLayout> Layout = getVectorLayout(
1209 Ty: LI.getType(), Alignment: LI.getAlign(), DL: LI.getDataLayout());
1210 if (!Layout)
1211 return false;
1212
1213 IRBuilder<> Builder(&LI);
1214 Scatterer Ptr = scatter(Point: &LI, V: LI.getPointerOperand(), VS: Layout->VS);
1215 ValueVector Res;
1216 Res.resize(N: Layout->VS.NumFragments);
1217
1218 for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
1219 Res[I] = Builder.CreateAlignedLoad(Ty: Layout->VS.getFragmentType(I), Ptr: Ptr[I],
1220 Align: Align(Layout->getFragmentAlign(Frag: I)),
1221 Name: LI.getName() + ".i" + Twine(I));
1222 }
1223 gather(Op: &LI, CV: Res, VS: Layout->VS);
1224 return true;
1225}
1226
1227bool ScalarizerVisitor::visitStoreInst(StoreInst &SI) {
1228 if (!ScalarizeLoadStore)
1229 return false;
1230 if (!SI.isSimple())
1231 return false;
1232
1233 Value *FullValue = SI.getValueOperand();
1234 std::optional<VectorLayout> Layout = getVectorLayout(
1235 Ty: FullValue->getType(), Alignment: SI.getAlign(), DL: SI.getDataLayout());
1236 if (!Layout)
1237 return false;
1238
1239 IRBuilder<> Builder(&SI);
1240 Scatterer VPtr = scatter(Point: &SI, V: SI.getPointerOperand(), VS: Layout->VS);
1241 Scatterer VVal = scatter(Point: &SI, V: FullValue, VS: Layout->VS);
1242
1243 ValueVector Stores;
1244 Stores.resize(N: Layout->VS.NumFragments);
1245 for (unsigned I = 0; I < Layout->VS.NumFragments; ++I) {
1246 Value *Val = VVal[I];
1247 Value *Ptr = VPtr[I];
1248 Stores[I] =
1249 Builder.CreateAlignedStore(Val, Ptr, Align: Layout->getFragmentAlign(Frag: I));
1250 }
1251 transferMetadataAndIRFlags(Op: &SI, CV: Stores);
1252 return true;
1253}
1254
1255bool ScalarizerVisitor::visitCallInst(CallInst &CI) {
1256 return splitCall(CI);
1257}
1258
1259bool ScalarizerVisitor::visitFreezeInst(FreezeInst &FI) {
1260 return splitUnary(I&: FI, Split: [](IRBuilder<> &Builder, Value *Op, const Twine &Name) {
1261 return Builder.CreateFreeze(V: Op, Name);
1262 });
1263}
1264
1265// Delete the instructions that we scalarized. If a full vector result
1266// is still needed, recreate it using InsertElements.
1267bool ScalarizerVisitor::finish() {
1268 // The presence of data in Gathered or Scattered indicates changes
1269 // made to the Function.
1270 if (Gathered.empty() && Scattered.empty() && !Scalarized)
1271 return false;
1272 for (const auto &GMI : Gathered) {
1273 Instruction *Op = GMI.first;
1274 ValueVector &CV = *GMI.second;
1275 if (!Op->use_empty()) {
1276 // The value is still needed, so recreate it using a series of
1277 // insertelements and/or shufflevectors.
1278 Value *Res;
1279 if (auto *Ty = dyn_cast<FixedVectorType>(Val: Op->getType())) {
1280 BasicBlock *BB = Op->getParent();
1281 IRBuilder<> Builder(Op);
1282 if (isa<PHINode>(Val: Op))
1283 Builder.SetInsertPoint(TheBB: BB, IP: BB->getFirstInsertionPt());
1284
1285 VectorSplit VS = *getVectorSplit(Ty);
1286 assert(VS.NumFragments == CV.size());
1287
1288 Res = concatenate(Builder, Fragments: CV, VS, Name: Op->getName());
1289
1290 Res->takeName(V: Op);
1291 } else if (auto *Ty = dyn_cast<StructType>(Val: Op->getType())) {
1292 BasicBlock *BB = Op->getParent();
1293 IRBuilder<> Builder(Op);
1294 if (isa<PHINode>(Val: Op))
1295 Builder.SetInsertPoint(TheBB: BB, IP: BB->getFirstInsertionPt());
1296
1297 // Iterate over each element in the struct
1298 unsigned NumOfStructElements = Ty->getNumElements();
1299 SmallVector<ValueVector, 4> ElemCV(NumOfStructElements);
1300 for (unsigned I = 0; I < NumOfStructElements; ++I) {
1301 for (auto *CVelem : CV) {
1302 Value *Elem = Builder.CreateExtractValue(
1303 Agg: CVelem, Idxs: I, Name: Op->getName() + ".elem" + Twine(I));
1304 ElemCV[I].push_back(Elt: Elem);
1305 }
1306 }
1307 Res = PoisonValue::get(T: Ty);
1308 for (unsigned I = 0; I < NumOfStructElements; ++I) {
1309 Type *ElemTy = Ty->getElementType(N: I);
1310 assert(isa<FixedVectorType>(ElemTy) &&
1311 "Only Structs of all FixedVectorType supported");
1312 VectorSplit VS = *getVectorSplit(Ty: ElemTy);
1313 assert(VS.NumFragments == CV.size());
1314
1315 Value *ConcatenatedVector =
1316 concatenate(Builder, Fragments: ElemCV[I], VS, Name: Op->getName());
1317 Res = Builder.CreateInsertValue(Agg: Res, Val: ConcatenatedVector, Idxs: I,
1318 Name: Op->getName() + ".insert");
1319 }
1320 } else {
1321 assert(CV.size() == 1 && Op->getType() == CV[0]->getType());
1322 Res = CV[0];
1323 if (Op == Res)
1324 continue;
1325 }
1326 Op->replaceAllUsesWith(V: Res);
1327 }
1328 PotentiallyDeadInstrs.emplace_back(Args&: Op);
1329 }
1330 Gathered.clear();
1331 Scattered.clear();
1332 Scalarized = false;
1333
1334 RecursivelyDeleteTriviallyDeadInstructionsPermissive(DeadInsts&: PotentiallyDeadInstrs);
1335
1336 return true;
1337}
1338
1339PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM) {
1340 DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(IR&: F);
1341 const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(IR&: F);
1342 ScalarizerVisitor Impl(DT, TTI, Options);
1343 bool Changed = Impl.visit(F);
1344 PreservedAnalyses PA;
1345 PA.preserve<DominatorTreeAnalysis>();
1346 return Changed ? PA : PreservedAnalyses::all();
1347}
1348