1//===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass merges loads/stores to/from sequential memory addresses into vector
10// loads/stores. Although there's nothing GPU-specific in here, this pass is
// motivated by the microarchitectural quirks of NVIDIA and AMD GPUs.
12//
13// (For simplicity below we talk about loads only, but everything also applies
14// to stores.)
15//
16// This pass is intended to be run late in the pipeline, after other
17// vectorization opportunities have been exploited. So the assumption here is
18// that immediately following our new vector load we'll need to extract out the
19// individual elements of the load, so we can operate on them individually.
20//
21// On CPUs this transformation is usually not beneficial, because extracting the
22// elements of a vector register is expensive on most architectures. It's
23// usually better just to load each element individually into its own scalar
24// register.
25//
// However, NVIDIA and AMD GPUs don't have proper vector registers. Instead, a
27// "vector load" loads directly into a series of scalar registers. In effect,
28// extracting the elements of the vector is free. It's therefore always
29// beneficial to vectorize a sequence of loads on these architectures.
30//
31// Vectorizing (perhaps a better name might be "coalescing") loads can have
32// large performance impacts on GPU kernels, and opportunities for vectorizing
33// are common in GPU code. This pass tries very hard to find such
34// opportunities; its runtime is quadratic in the number of loads in a BB.
35//
36// Some CPU architectures, such as ARM, have instructions that load into
37// multiple scalar registers, similar to a GPU vectorized load. In theory ARM
38// could use this pass (with some modifications), but currently it implements
39// its own pass to do something similar to what we do here.
40//
41// Overview of the algorithm and terminology in this pass:
42//
43// - Break up each basic block into pseudo-BBs, composed of instructions which
44// are guaranteed to transfer control to their successors.
45// - Within a single pseudo-BB, find all loads, and group them into
46// "equivalence classes" according to getUnderlyingObject() and loaded
47// element size. Do the same for stores.
48// - For each equivalence class, greedily build "chains". Each chain has a
49// leader instruction, and every other member of the chain has a known
50// constant offset from the first instr in the chain.
51// - Break up chains so that they contain only contiguous accesses of legal
52// size with no intervening may-alias instrs.
53// - Convert each chain to vector instructions.
54//
55// The O(n^2) behavior of this pass comes from initially building the chains.
56// In the worst case we have to compare each new instruction to all of those
57// that came before. To limit this, we only calculate the offset to the leaders
58// of the N most recently-used chains.
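//
// For example (simplified pseudo-IR, in the style of the example later in
// this file):
//
//   load0 = load i32 ptr0
//   load1 = load i32 ptr1      ; ptr1 == ptr0 + 4 bytes
//
// becomes, after this pass,
//
//   loadv = load <2 x i32> ptr0
//   load0 = extractelement loadv, 0
//   load1 = extractelement loadv, 1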
59
60#include "llvm/Transforms/Vectorize/LoadStoreVectorizer.h"
61#include "llvm/ADT/APInt.h"
62#include "llvm/ADT/ArrayRef.h"
63#include "llvm/ADT/DenseMap.h"
64#include "llvm/ADT/MapVector.h"
65#include "llvm/ADT/PostOrderIterator.h"
66#include "llvm/ADT/STLExtras.h"
67#include "llvm/ADT/Sequence.h"
68#include "llvm/ADT/SmallPtrSet.h"
69#include "llvm/ADT/SmallVector.h"
70#include "llvm/ADT/Statistic.h"
71#include "llvm/ADT/iterator_range.h"
72#include "llvm/Analysis/AliasAnalysis.h"
73#include "llvm/Analysis/AssumptionCache.h"
74#include "llvm/Analysis/MemoryLocation.h"
75#include "llvm/Analysis/ScalarEvolution.h"
76#include "llvm/Analysis/TargetTransformInfo.h"
77#include "llvm/Analysis/ValueTracking.h"
78#include "llvm/Analysis/VectorUtils.h"
79#include "llvm/IR/Attributes.h"
80#include "llvm/IR/BasicBlock.h"
81#include "llvm/IR/ConstantRange.h"
82#include "llvm/IR/Constants.h"
83#include "llvm/IR/DataLayout.h"
84#include "llvm/IR/DerivedTypes.h"
85#include "llvm/IR/Dominators.h"
86#include "llvm/IR/Function.h"
87#include "llvm/IR/GetElementPtrTypeIterator.h"
88#include "llvm/IR/IRBuilder.h"
89#include "llvm/IR/InstrTypes.h"
90#include "llvm/IR/Instruction.h"
91#include "llvm/IR/Instructions.h"
92#include "llvm/IR/LLVMContext.h"
93#include "llvm/IR/Module.h"
94#include "llvm/IR/Type.h"
95#include "llvm/IR/Value.h"
96#include "llvm/InitializePasses.h"
97#include "llvm/Pass.h"
98#include "llvm/Support/Alignment.h"
99#include "llvm/Support/Casting.h"
100#include "llvm/Support/Debug.h"
101#include "llvm/Support/KnownBits.h"
102#include "llvm/Support/MathExtras.h"
103#include "llvm/Support/ModRef.h"
104#include "llvm/Support/raw_ostream.h"
105#include "llvm/Transforms/Utils/Local.h"
106#include <algorithm>
107#include <cassert>
108#include <cstdint>
109#include <cstdlib>
110#include <iterator>
111#include <numeric>
112#include <optional>
113#include <tuple>
114#include <type_traits>
115#include <utility>
116#include <vector>
117
118using namespace llvm;
119
120#define DEBUG_TYPE "load-store-vectorizer"
121
122STATISTIC(NumVectorInstructions, "Number of vector accesses generated");
123STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized");
124
125namespace {
126
127// Equivalence class key, the initial tuple by which we group loads/stores.
128// Loads/stores with different EqClassKeys are never merged.
129//
// (We could in theory remove element-size from this tuple. We'd just need
131// to fix up the vector packing/unpacking code.)
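//
// For example (illustrative): two i32 loads from the same alloca in
// addrspace(0) share the key {alloca, /*AddrSpace=*/0, /*ElementSize=*/32,
// /*IsLoad=*/1}, so they may be merged with each other, but never with a
// store or with an access of a different element size.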
132using EqClassKey =
133 std::tuple<const Value * /* result of getUnderlyingObject() */,
134 unsigned /* AddrSpace */,
135 unsigned /* Load/Store element size bits */,
136 char /* IsLoad; char b/c bool can't be a DenseMap key */
137 >;
138[[maybe_unused]] llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
139 const EqClassKey &K) {
140 const auto &[UnderlyingObject, AddrSpace, ElementSize, IsLoad] = K;
141 return OS << (IsLoad ? "load" : "store") << " of " << *UnderlyingObject
142 << " of element size " << ElementSize << " bits in addrspace "
143 << AddrSpace;
144}
145
146// A Chain is a set of instructions such that:
147// - All instructions have the same equivalence class, so in particular all are
148// loads, or all are stores.
149// - We know the address accessed by the i'th chain elem relative to the
150// chain's leader instruction, which is the first instr of the chain in BB
151// order.
152//
153// Chains have two canonical orderings:
154// - BB order, sorted by Instr->comesBefore.
155// - Offset order, sorted by OffsetFromLeader.
156// This pass switches back and forth between these orders.
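//
// For example (illustrative): if the leader is `load i32 p[0]`, the chain
// might hold elements at byte offsets 0, 16, and 4 from the leader; offset
// order visits them as 0, 4, 16, while BB order visits them in program order.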
157struct ChainElem {
158 Instruction *Inst;
159 APInt OffsetFromLeader;
160 ChainElem(Instruction *Inst, APInt OffsetFromLeader)
161 : Inst(std::move(Inst)), OffsetFromLeader(std::move(OffsetFromLeader)) {}
162};
163using Chain = SmallVector<ChainElem, 1>;
164
void sortChainInBBOrder(Chain &C) {
  sort(C, [](auto &A, auto &B) { return A.Inst->comesBefore(B.Inst); });
}

void sortChainInOffsetOrder(Chain &C) {
  sort(C, [](const auto &A, const auto &B) {
    if (A.OffsetFromLeader != B.OffsetFromLeader)
      return A.OffsetFromLeader.slt(B.OffsetFromLeader);
    return A.Inst->comesBefore(B.Inst); // stable tiebreaker
  });
}
176
177[[maybe_unused]] void dumpChain(ArrayRef<ChainElem> C) {
178 for (const auto &E : C) {
179 dbgs() << " " << *E.Inst << " (offset " << E.OffsetFromLeader << ")\n";
180 }
181}
182
183using EquivalenceClassMap =
184 MapVector<EqClassKey, SmallVector<Instruction *, 8>>;
185
186// FIXME: Assuming stack alignment of 4 is always good enough
187constexpr unsigned StackAdjustedAlignment = 4;
188
Instruction *propagateMetadata(Instruction *I, const Chain &C) {
  SmallVector<Value *, 8> Values;
  for (const ChainElem &E : C)
    Values.emplace_back(E.Inst);
  return propagateMetadata(I, Values);
}

bool isInvariantLoad(const Instruction *I) {
  const LoadInst *LI = dyn_cast<LoadInst>(I);
  return LI != nullptr && LI->hasMetadata(LLVMContext::MD_invariant_load);
}
200
201/// Reorders the instructions that I depends on (the instructions defining its
202/// operands), to ensure they dominate I.
203void reorder(Instruction *I) {
204 SmallPtrSet<Instruction *, 16> InstructionsToMove;
205 SmallVector<Instruction *, 16> Worklist;
206
207 Worklist.emplace_back(Args&: I);
208 while (!Worklist.empty()) {
209 Instruction *IW = Worklist.pop_back_val();
210 int NumOperands = IW->getNumOperands();
211 for (int Idx = 0; Idx < NumOperands; Idx++) {
212 Instruction *IM = dyn_cast<Instruction>(Val: IW->getOperand(i: Idx));
213 if (!IM || IM->getOpcode() == Instruction::PHI)
214 continue;
215
216 // If IM is in another BB, no need to move it, because this pass only
217 // vectorizes instructions within one BB.
218 if (IM->getParent() != I->getParent())
219 continue;
220
221 assert(IM != I && "Unexpected cycle while re-ordering instructions");
222
223 if (!IM->comesBefore(Other: I)) {
224 InstructionsToMove.insert(Ptr: IM);
225 Worklist.emplace_back(Args&: IM);
226 }
227 }
228 }
229
230 // All instructions to move should follow I. Start from I, not from begin().
231 for (auto BBI = I->getIterator(), E = I->getParent()->end(); BBI != E;) {
232 Instruction *IM = &*(BBI++);
233 if (!InstructionsToMove.contains(Ptr: IM))
234 continue;
235 IM->moveBefore(InsertPos: I->getIterator());
236 }
237}
238
239class Vectorizer {
240 Function &F;
241 AliasAnalysis &AA;
242 AssumptionCache &AC;
243 DominatorTree &DT;
244 ScalarEvolution &SE;
245 TargetTransformInfo &TTI;
246 const DataLayout &DL;
247 IRBuilder<> Builder;
248
249 /// We could erase instrs right after vectorizing them, but that can mess up
250 /// our BB iterators, and also can make the equivalence class keys point to
251 /// freed memory. This is fixable, but it's simpler just to wait until we're
252 /// done with the BB and erase all at once.
253 SmallVector<Instruction *, 128> ToErase;
254
255 /// We insert load/store instructions and GEPs to fill gaps and extend chains
256 /// to enable vectorization. Keep track and delete them later.
257 DenseSet<Instruction *> ExtraElements;
258
259public:
260 Vectorizer(Function &F, AliasAnalysis &AA, AssumptionCache &AC,
261 DominatorTree &DT, ScalarEvolution &SE, TargetTransformInfo &TTI)
262 : F(F), AA(AA), AC(AC), DT(DT), SE(SE), TTI(TTI),
263 DL(F.getDataLayout()), Builder(SE.getContext()) {}
264
265 bool run();
266
267private:
268 static const unsigned MaxDepth = 3;
269
270 /// Runs the vectorizer on a "pseudo basic block", which is a range of
271 /// instructions [Begin, End) within one BB all of which have
272 /// isGuaranteedToTransferExecutionToSuccessor(I) == true.
273 bool runOnPseudoBB(BasicBlock::iterator Begin, BasicBlock::iterator End);
274
275 /// Runs the vectorizer on one equivalence class, i.e. one set of loads/stores
276 /// in the same BB with the same value for getUnderlyingObject() etc.
277 bool runOnEquivalenceClass(const EqClassKey &EqClassKey,
278 ArrayRef<Instruction *> EqClass);
279
280 /// Runs the vectorizer on one chain, i.e. a subset of an equivalence class
281 /// where all instructions access a known, constant offset from the first
282 /// instruction.
283 bool runOnChain(Chain &C);
284
285 /// Splits the chain into subchains of instructions which read/write a
286 /// contiguous block of memory. Discards any length-1 subchains (because
287 /// there's nothing to vectorize in there). Also attempts to fill gaps with
288 /// "extra" elements to artificially make chains contiguous in some cases.
289 std::vector<Chain> splitChainByContiguity(Chain &C);
290
291 /// Splits the chain into subchains where it's safe to hoist loads up to the
292 /// beginning of the sub-chain and it's safe to sink loads up to the end of
293 /// the sub-chain. Discards any length-1 subchains. Also attempts to extend
294 /// non-power-of-two chains by adding "extra" elements in some cases.
295 std::vector<Chain> splitChainByMayAliasInstrs(Chain &C);
296
297 /// Splits the chain into subchains that make legal, aligned accesses.
298 /// Discards any length-1 subchains.
299 std::vector<Chain> splitChainByAlignment(Chain &C);
300
301 /// Converts the instrs in the chain into a single vectorized load or store.
302 /// Adds the old scalar loads/stores to ToErase.
303 bool vectorizeChain(Chain &C);
304
305 /// Tries to compute the offset in bytes PtrB - PtrA.
306 std::optional<APInt> getConstantOffset(Value *PtrA, Value *PtrB,
307 Instruction *ContextInst,
308 unsigned Depth = 0);
309 std::optional<APInt> getConstantOffsetComplexAddrs(Value *PtrA, Value *PtrB,
310 Instruction *ContextInst,
311 unsigned Depth);
312 std::optional<APInt> getConstantOffsetSelects(Value *PtrA, Value *PtrB,
313 Instruction *ContextInst,
314 unsigned Depth);
315
316 /// Gets the element type of the vector that the chain will load or store.
317 /// This is nontrivial because the chain may contain elements of different
318 /// types; e.g. it's legal to have a chain that contains both i32 and float.
319 Type *getChainElemTy(const Chain &C);
320
321 /// Determines whether ChainElem can be moved up (if IsLoad) or down (if
322 /// !IsLoad) to ChainBegin -- i.e. there are no intervening may-alias
323 /// instructions.
324 ///
325 /// The map ChainElemOffsets must contain all of the elements in
326 /// [ChainBegin, ChainElem] and their offsets from some arbitrary base
327 /// address. It's ok if it contains additional entries.
328 template <bool IsLoadChain>
329 bool isSafeToMove(
330 Instruction *ChainElem, Instruction *ChainBegin,
331 const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets,
332 BatchAAResults &BatchAA);
333
334 /// Merges the equivalence classes if they have underlying objects that differ
335 /// by one level of indirection (i.e., one is a getelementptr and the other is
336 /// the base pointer in that getelementptr).
337 void mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const;
338
339 /// Collects loads and stores grouped by "equivalence class", where:
340 /// - all elements in an eq class are a load or all are a store,
341 /// - they all load/store the same element size (it's OK to have e.g. i8 and
342 /// <4 x i8> in the same class, but not i32 and <4 x i8>), and
343 /// - they all have the same value for getUnderlyingObject().
344 EquivalenceClassMap collectEquivalenceClasses(BasicBlock::iterator Begin,
345 BasicBlock::iterator End);
346
347 /// Partitions Instrs into "chains" where every instruction has a known
348 /// constant offset from the first instr in the chain.
349 ///
350 /// Postcondition: For all i, ret[i][0].second == 0, because the first instr
351 /// in the chain is the leader, and an instr touches distance 0 from itself.
352 std::vector<Chain> gatherChains(ArrayRef<Instruction *> Instrs);
353
354 /// Checks if a potential vector load/store with a given alignment is allowed
355 /// and fast. Aligned accesses are always allowed and fast, while misaligned
356 /// accesses depend on TTI checks to determine whether they can and should be
357 /// vectorized or kept as element-wise accesses.
358 bool accessIsAllowedAndFast(unsigned SizeBytes, unsigned AS, Align Alignment,
359 unsigned VecElemBits) const;
360
361 /// Create a new GEP and a new Load/Store instruction such that the GEP
362 /// is pointing at PrevElem + Offset. In the case of stores, store poison.
363 /// Extra elements will either be combined into a masked load/store or
364 /// deleted before the end of the pass.
365 ChainElem createExtraElementAfter(const ChainElem &PrevElem, Type *Ty,
366 APInt Offset, StringRef Prefix,
367 Align Alignment = Align());
368
369 /// Create a mask that masks off the extra elements in the chain, to be used
370 /// for the creation of a masked load/store vector.
371 Value *createMaskForExtraElements(const ArrayRef<ChainElem> C,
372 FixedVectorType *VecTy);
373
374 /// Delete dead GEPs and extra Load/Store instructions created by
375 /// createExtraElementAfter
376 void deleteExtraElements();
377};
378
379class LoadStoreVectorizerLegacyPass : public FunctionPass {
380public:
381 static char ID;
382
383 LoadStoreVectorizerLegacyPass() : FunctionPass(ID) {}
384
385 bool runOnFunction(Function &F) override;
386
387 StringRef getPassName() const override {
388 return "GPU Load and Store Vectorizer";
389 }
390
391 void getAnalysisUsage(AnalysisUsage &AU) const override {
392 AU.addRequired<AAResultsWrapperPass>();
393 AU.addRequired<AssumptionCacheTracker>();
394 AU.addRequired<ScalarEvolutionWrapperPass>();
395 AU.addRequired<DominatorTreeWrapperPass>();
396 AU.addRequired<TargetTransformInfoWrapperPass>();
397 AU.setPreservesCFG();
398 }
399};
400
401} // end anonymous namespace
402
403char LoadStoreVectorizerLegacyPass::ID = 0;
404
405INITIALIZE_PASS_BEGIN(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
406 "Vectorize load and Store instructions", false, false)
407INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
408INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker);
409INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
410INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
411INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
412INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
413INITIALIZE_PASS_END(LoadStoreVectorizerLegacyPass, DEBUG_TYPE,
414 "Vectorize load and store instructions", false, false)
415
416Pass *llvm::createLoadStoreVectorizerPass() {
417 return new LoadStoreVectorizerLegacyPass();
418}
419
420bool LoadStoreVectorizerLegacyPass::runOnFunction(Function &F) {
421 // Don't vectorize when the attribute NoImplicitFloat is used.
422 if (skipFunction(F) || F.hasFnAttribute(Kind: Attribute::NoImplicitFloat))
423 return false;
424
425 AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
426 DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
427 ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
428 TargetTransformInfo &TTI =
429 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
430
431 AssumptionCache &AC =
432 getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
433
434 return Vectorizer(F, AA, AC, DT, SE, TTI).run();
435}
436
437PreservedAnalyses LoadStoreVectorizerPass::run(Function &F,
438 FunctionAnalysisManager &AM) {
439 // Don't vectorize when the attribute NoImplicitFloat is used.
440 if (F.hasFnAttribute(Kind: Attribute::NoImplicitFloat))
441 return PreservedAnalyses::all();
442
443 AliasAnalysis &AA = AM.getResult<AAManager>(IR&: F);
444 DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(IR&: F);
445 ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(IR&: F);
446 TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(IR&: F);
447 AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(IR&: F);
448
449 bool Changed = Vectorizer(F, AA, AC, DT, SE, TTI).run();
450 PreservedAnalyses PA;
451 PA.preserveSet<CFGAnalyses>();
452 return Changed ? PA : PreservedAnalyses::all();
453}
454
455bool Vectorizer::run() {
456 bool Changed = false;
457 // Break up the BB if there are any instrs which aren't guaranteed to transfer
458 // execution to their successor.
459 //
460 // Consider, for example:
461 //
462 // def assert_arr_len(int n) { if (n < 2) exit(); }
463 //
464 // load arr[0]
  //   call assert_arr_len(arr.length)
466 // load arr[1]
467 //
468 // Even though assert_arr_len does not read or write any memory, we can't
469 // speculate the second load before the call. More info at
470 // https://github.com/llvm/llvm-project/issues/52950.
471 for (BasicBlock *BB : post_order(G: &F)) {
472 // BB must at least have a terminator.
473 assert(!BB->empty());
474
475 SmallVector<BasicBlock::iterator, 8> Barriers;
476 Barriers.emplace_back(Args: BB->begin());
477 for (Instruction &I : *BB)
478 if (!isGuaranteedToTransferExecutionToSuccessor(I: &I))
479 Barriers.emplace_back(Args: I.getIterator());
480 Barriers.emplace_back(Args: BB->end());
481
482 for (auto It = Barriers.begin(), End = std::prev(x: Barriers.end()); It != End;
483 ++It)
484 Changed |= runOnPseudoBB(Begin: *It, End: *std::next(x: It));
485
486 for (Instruction *I : ToErase) {
487 // These will get deleted in deleteExtraElements.
488 // This is because ExtraElements will include both extra elements
489 // that *were* vectorized and extra elements that *were not*
490 // vectorized. ToErase will only include extra elements that *were*
491 // vectorized, so in order to avoid double deletion we skip them here and
492 // handle them in deleteExtraElements.
493 if (ExtraElements.contains(V: I))
494 continue;
495 auto *PtrOperand = getLoadStorePointerOperand(V: I);
496 if (I->use_empty())
497 I->eraseFromParent();
498 RecursivelyDeleteTriviallyDeadInstructions(V: PtrOperand);
499 }
500 ToErase.clear();
501 deleteExtraElements();
502 }
503
504 return Changed;
505}
506
507bool Vectorizer::runOnPseudoBB(BasicBlock::iterator Begin,
508 BasicBlock::iterator End) {
509 LLVM_DEBUG({
510 dbgs() << "LSV: Running on pseudo-BB [" << *Begin << " ... ";
511 if (End != Begin->getParent()->end())
512 dbgs() << *End;
513 else
514 dbgs() << "<BB end>";
515 dbgs() << ")\n";
516 });
517
518 bool Changed = false;
519 for (const auto &[EqClassKey, EqClass] :
520 collectEquivalenceClasses(Begin, End))
521 Changed |= runOnEquivalenceClass(EqClassKey, EqClass);
522
523 return Changed;
524}
525
526bool Vectorizer::runOnEquivalenceClass(const EqClassKey &EqClassKey,
527 ArrayRef<Instruction *> EqClass) {
528 bool Changed = false;
529
530 LLVM_DEBUG({
531 dbgs() << "LSV: Running on equivalence class of size " << EqClass.size()
532 << " keyed on " << EqClassKey << ":\n";
533 for (Instruction *I : EqClass)
534 dbgs() << " " << *I << "\n";
535 });
536
537 std::vector<Chain> Chains = gatherChains(Instrs: EqClass);
538 LLVM_DEBUG(dbgs() << "LSV: Got " << Chains.size()
539 << " nontrivial chains.\n";);
540 for (Chain &C : Chains)
541 Changed |= runOnChain(C);
542 return Changed;
543}
544
545bool Vectorizer::runOnChain(Chain &C) {
546 LLVM_DEBUG({
547 dbgs() << "LSV: Running on chain with " << C.size() << " instructions:\n";
548 dumpChain(C);
549 });
550
551 // Split up the chain into increasingly smaller chains, until we can finally
552 // vectorize the chains.
553 //
554 // (Don't be scared by the depth of the loop nest here. These operations are
555 // all at worst O(n lg n) in the number of instructions, and splitting chains
556 // doesn't change the number of instrs. So the whole loop nest is O(n lg n).)
557 bool Changed = false;
558 for (auto &C : splitChainByMayAliasInstrs(C))
559 for (auto &C : splitChainByContiguity(C))
560 for (auto &C : splitChainByAlignment(C))
561 Changed |= vectorizeChain(C);
562 return Changed;
563}
564
565std::vector<Chain> Vectorizer::splitChainByMayAliasInstrs(Chain &C) {
566 if (C.empty())
567 return {};
568
569 sortChainInBBOrder(C);
570
571 LLVM_DEBUG({
572 dbgs() << "LSV: splitChainByMayAliasInstrs considering chain:\n";
573 dumpChain(C);
574 });
575
  // We know that elements in the chain with non-overlapping offsets can't
577 // alias, but AA may not be smart enough to figure this out. Use a
578 // hashtable.
579 DenseMap<Instruction *, APInt /*OffsetFromLeader*/> ChainOffsets;
580 for (const auto &E : C)
581 ChainOffsets.insert(KV: {&*E.Inst, E.OffsetFromLeader});
582
583 // Across a single invocation of this function the IR is not changing, so
584 // using a batched Alias Analysis is safe and can reduce compile time.
585 BatchAAResults BatchAA(AA);
586
587 // Loads get hoisted up to the first load in the chain. Stores get sunk
588 // down to the last store in the chain. Our algorithm for loads is:
589 //
590 // - Take the first element of the chain. This is the start of a new chain.
591 // - Take the next element of `Chain` and check for may-alias instructions
592 // up to the start of NewChain. If no may-alias instrs, add it to
593 // NewChain. Otherwise, start a new NewChain.
594 //
595 // For stores it's the same except in the reverse direction.
596 //
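  // For example (illustrative): for a load chain [A, B] with a may-aliasing
  // store S between A and B in BB order, B cannot be hoisted above S to join
  // A, so A ends its chain and B starts a new one.
  //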
597 // We expect IsLoad to be an std::bool_constant.
598 auto Impl = [&](auto IsLoad) {
599 // MSVC is unhappy if IsLoad is a capture, so pass it as an arg.
600 auto [ChainBegin, ChainEnd] = [&](auto IsLoad) {
601 if constexpr (IsLoad())
602 return std::make_pair(x: C.begin(), y: C.end());
603 else
604 return std::make_pair(x: C.rbegin(), y: C.rend());
605 }(IsLoad);
606 assert(ChainBegin != ChainEnd);
607
608 std::vector<Chain> Chains;
609 SmallVector<ChainElem, 1> NewChain;
610 NewChain.emplace_back(*ChainBegin);
611 for (auto ChainIt = std::next(ChainBegin); ChainIt != ChainEnd; ++ChainIt) {
612 if (isSafeToMove<IsLoad>(ChainIt->Inst, NewChain.front().Inst,
613 ChainOffsets, BatchAA)) {
614 LLVM_DEBUG(dbgs() << "LSV: No intervening may-alias instrs; can merge "
615 << *ChainIt->Inst << " into " << *ChainBegin->Inst
616 << "\n");
617 NewChain.emplace_back(*ChainIt);
618 } else {
619 LLVM_DEBUG(
620 dbgs() << "LSV: Found intervening may-alias instrs; cannot merge "
621 << *ChainIt->Inst << " into " << *ChainBegin->Inst << "\n");
622 if (NewChain.size() > 1) {
623 LLVM_DEBUG({
624 dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
625 dumpChain(NewChain);
626 });
627 Chains.emplace_back(args: std::move(NewChain));
628 }
629
630 // Start a new chain.
631 NewChain = SmallVector<ChainElem, 1>({*ChainIt});
632 }
633 }
634 if (NewChain.size() > 1) {
635 LLVM_DEBUG({
636 dbgs() << "LSV: got nontrivial chain without aliasing instrs:\n";
637 dumpChain(NewChain);
638 });
639 Chains.emplace_back(args: std::move(NewChain));
640 }
641 return Chains;
642 };
643
644 if (isa<LoadInst>(Val: C[0].Inst))
645 return Impl(/*IsLoad=*/std::bool_constant<true>());
646
647 assert(isa<StoreInst>(C[0].Inst));
648 return Impl(/*IsLoad=*/std::bool_constant<false>());
649}
650
651std::vector<Chain> Vectorizer::splitChainByContiguity(Chain &C) {
652 if (C.empty())
653 return {};
654
655 sortChainInOffsetOrder(C);
656
657 LLVM_DEBUG({
658 dbgs() << "LSV: splitChainByContiguity considering chain:\n";
659 dumpChain(C);
660 });
661
662 // If the chain is not contiguous, we try to fill the gap with "extra"
663 // elements to artificially make it contiguous, to try to enable
664 // vectorization. We only fill gaps if there is potential to end up with a
665 // legal masked load/store given the target, address space, and element type.
666 // At this point, when querying the TTI, optimistically assume max alignment
667 // and max vector size, as splitChainByAlignment will ensure the final vector
668 // shape passes the legalization check.
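  //
  // For example (illustrative): i32 accesses at offsets 0, 4, and 12 leave a
  // one-element gap at offset 8. We insert an extra i32 access at offset 8 so
  // the chain becomes contiguous and can later form a masked <4 x i32>
  // load/store in which the extra lane is masked off.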
669 unsigned AS = getLoadStoreAddressSpace(I: C[0].Inst);
670 Type *ElementType = getLoadStoreType(I: C[0].Inst)->getScalarType();
671 unsigned MaxVecRegBits = TTI.getLoadStoreVecRegBitWidth(AddrSpace: AS);
672 Align OptimisticAlign = Align(MaxVecRegBits / 8);
673 unsigned int MaxVectorNumElems =
674 MaxVecRegBits / DL.getTypeSizeInBits(Ty: ElementType);
675 // Note: This check decides whether to try to fill gaps based on the masked
676 // legality of the target's maximum vector size (getLoadStoreVecRegBitWidth).
677 // If a target *does not* support a masked load/store with this max vector
678 // size, but *does* support a masked load/store with a *smaller* vector size,
679 // that optimization will be missed. This does not occur in any of the targets
680 // that currently support this API.
681 FixedVectorType *OptimisticVectorType =
682 FixedVectorType::get(ElementType, NumElts: MaxVectorNumElems);
683 bool TryFillGaps =
684 isa<LoadInst>(Val: C[0].Inst)
685 ? TTI.isLegalMaskedLoad(DataType: OptimisticVectorType, Alignment: OptimisticAlign, AddressSpace: AS,
686 MaskKind: TTI::MaskKind::ConstantMask)
687 : TTI.isLegalMaskedStore(DataType: OptimisticVectorType, Alignment: OptimisticAlign, AddressSpace: AS,
688 MaskKind: TTI::MaskKind::ConstantMask);
689
690 // Cache the best aligned element in the chain for use when creating extra
691 // elements.
692 Align BestAlignedElemAlign = getLoadStoreAlignment(I: C[0].Inst);
693 APInt OffsetOfBestAlignedElemFromLeader = C[0].OffsetFromLeader;
694 for (const auto &E : C) {
695 Align ElementAlignment = getLoadStoreAlignment(I: E.Inst);
696 if (ElementAlignment > BestAlignedElemAlign) {
697 BestAlignedElemAlign = ElementAlignment;
698 OffsetOfBestAlignedElemFromLeader = E.OffsetFromLeader;
699 }
700 }
701
702 auto DeriveAlignFromBestAlignedElem = [&](APInt NewElemOffsetFromLeader) {
703 return commonAlignment(
704 A: BestAlignedElemAlign,
705 Offset: (NewElemOffsetFromLeader - OffsetOfBestAlignedElemFromLeader)
706 .abs()
707 .getLimitedValue());
708 };
709
710 unsigned ASPtrBits = DL.getIndexSizeInBits(AS);
711
712 std::vector<Chain> Ret;
713 Ret.push_back(x: {C.front()});
714
715 unsigned ChainElemTyBits = DL.getTypeSizeInBits(Ty: getChainElemTy(C));
716 ChainElem &Prev = C[0];
717 for (auto It = std::next(x: C.begin()), End = C.end(); It != End; ++It) {
718 auto &CurChain = Ret.back();
719
720 APInt PrevSzBytes =
721 APInt(ASPtrBits, DL.getTypeStoreSize(Ty: getLoadStoreType(I: Prev.Inst)));
722 APInt PrevReadEnd = Prev.OffsetFromLeader + PrevSzBytes;
723 unsigned SzBytes = DL.getTypeStoreSize(Ty: getLoadStoreType(I: It->Inst));
724
725 // Add this instruction to the end of the current chain, or start a new one.
726 assert(
727 8 * SzBytes % ChainElemTyBits == 0 &&
728 "Every chain-element size must be a multiple of the element size after "
729 "vectorization.");
730 APInt ReadEnd = It->OffsetFromLeader + SzBytes;
731 // Allow redundancy: partial or full overlap counts as contiguous.
732 bool AreContiguous = false;
733 if (It->OffsetFromLeader.sle(RHS: PrevReadEnd)) {
734 // Check overlap is a multiple of the element size after vectorization.
735 uint64_t Overlap = (PrevReadEnd - It->OffsetFromLeader).getZExtValue();
736 if (8 * Overlap % ChainElemTyBits == 0)
737 AreContiguous = true;
738 }
739
740 LLVM_DEBUG(dbgs() << "LSV: Instruction is "
741 << (AreContiguous ? "contiguous" : "chain-breaker")
742 << *It->Inst << " (starts at offset "
743 << It->OffsetFromLeader << ")\n");
744
745 // If the chain is not contiguous, try to fill in gaps between Prev and
746 // Curr. For now, we aren't filling gaps between load/stores of different
747 // sizes. Additionally, as a conservative heuristic, we only fill gaps of
748 // 1-2 elements. Generating loads/stores with too many unused bytes has a
749 // side effect of increasing register pressure (on NVIDIA targets at least),
750 // which could cancel out the benefits of reducing number of load/stores.
751 bool GapFilled = false;
752 if (!AreContiguous && TryFillGaps && PrevSzBytes == SzBytes) {
753 APInt GapSzBytes = It->OffsetFromLeader - PrevReadEnd;
754 if (GapSzBytes == PrevSzBytes) {
        // There is a single-element gap between Prev and Curr; create one
        // extra element.
756 ChainElem NewElem = createExtraElementAfter(
757 PrevElem: Prev, Ty: getLoadStoreType(I: Prev.Inst), Offset: PrevSzBytes, Prefix: "GapFill",
758 Alignment: DeriveAlignFromBestAlignedElem(PrevReadEnd));
759 CurChain.push_back(Elt: NewElem);
760 GapFilled = true;
761 }
      // The gap between Prev and Curr spans two elements; only create the two
      // extra elements if Prev is the first element in a sequence of four.
      // This has the highest chance of resulting in a beneficial vectorization.
765 if ((GapSzBytes == 2 * PrevSzBytes) && (CurChain.size() % 4 == 1)) {
766 ChainElem NewElem1 = createExtraElementAfter(
767 PrevElem: Prev, Ty: getLoadStoreType(I: Prev.Inst), Offset: PrevSzBytes, Prefix: "GapFill",
768 Alignment: DeriveAlignFromBestAlignedElem(PrevReadEnd));
769 ChainElem NewElem2 = createExtraElementAfter(
770 PrevElem: NewElem1, Ty: getLoadStoreType(I: Prev.Inst), Offset: PrevSzBytes, Prefix: "GapFill",
771 Alignment: DeriveAlignFromBestAlignedElem(PrevReadEnd + PrevSzBytes));
772 CurChain.push_back(Elt: NewElem1);
773 CurChain.push_back(Elt: NewElem2);
774 GapFilled = true;
775 }
776 }
777
778 if (AreContiguous || GapFilled)
779 CurChain.push_back(Elt: *It);
780 else
781 Ret.push_back(x: {*It});
782 // In certain cases when handling redundant elements with partial overlaps,
783 // the previous element may still extend beyond the current element. Only
784 // update Prev if the current element is the new end of the chain.
785 if (ReadEnd.sge(RHS: PrevReadEnd))
786 Prev = *It;
787 }
788
789 // Filter out length-1 chains, these are uninteresting.
790 llvm::erase_if(C&: Ret, P: [](const auto &Chain) { return Chain.size() <= 1; });
791 return Ret;
792}
793
794Type *Vectorizer::getChainElemTy(const Chain &C) {
795 assert(!C.empty());
796 // The rules are:
797 // - If there are any pointer types in the chain, use an integer type.
798 // - Prefer an integer type if it appears in the chain.
799 // - Otherwise, use the first type in the chain.
800 //
801 // The rule about pointer types is a simplification when we merge e.g. a load
802 // of a ptr and a double. There's no direct conversion from a ptr to a
803 // double; it requires a ptrtoint followed by a bitcast.
804 //
805 // It's unclear to me if the other rules have any practical effect, but we do
806 // it to match this pass's previous behavior.
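  //
  // For example (illustrative): a chain of {i32, float} uses i32, and a chain
  // of {ptr, double} with 64-bit pointers uses i64.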
807 if (any_of(Range: C, P: [](const ChainElem &E) {
808 return getLoadStoreType(I: E.Inst)->getScalarType()->isPointerTy();
809 })) {
810 return Type::getIntNTy(
811 C&: F.getContext(),
812 N: DL.getTypeSizeInBits(Ty: getLoadStoreType(I: C[0].Inst)->getScalarType()));
813 }
814
815 for (const ChainElem &E : C)
816 if (Type *T = getLoadStoreType(I: E.Inst)->getScalarType(); T->isIntegerTy())
817 return T;
818 return getLoadStoreType(I: C[0].Inst)->getScalarType();
819}
820
821std::vector<Chain> Vectorizer::splitChainByAlignment(Chain &C) {
822 // We use a simple greedy algorithm.
823 // - Given a chain of length N, find all prefixes that
824 // (a) are not longer than the max register length, and
825 // (b) are a power of 2.
826 // - Starting from the longest prefix, try to create a vector of that length.
827 // - If one of them works, great. Repeat the algorithm on any remaining
828 // elements in the chain.
829 // - If none of them work, discard the first element and repeat on a chain
830 // of length N-1.
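  //
  // For example (illustrative): with 6 x i32 elements and 16-byte vector
  // registers, we first try the 4-element prefix; if it vectorizes, we repeat
  // on the remaining 2 elements, and if no prefix works we drop C[0] and
  // retry on the 5-element tail.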
831 if (C.empty())
832 return {};
833
834 sortChainInOffsetOrder(C);
835
836 LLVM_DEBUG({
837 dbgs() << "LSV: splitChainByAlignment considering chain:\n";
838 dumpChain(C);
839 });
840
841 bool IsLoadChain = isa<LoadInst>(Val: C[0].Inst);
842 auto GetVectorFactor = [&](unsigned VF, unsigned LoadStoreSize,
843 unsigned ChainSizeBytes, VectorType *VecTy) {
844 return IsLoadChain ? TTI.getLoadVectorFactor(VF, LoadSize: LoadStoreSize,
845 ChainSizeInBytes: ChainSizeBytes, VecTy)
846 : TTI.getStoreVectorFactor(VF, StoreSize: LoadStoreSize,
847 ChainSizeInBytes: ChainSizeBytes, VecTy);
848 };
849
850#ifndef NDEBUG
851 for (const auto &E : C) {
852 Type *Ty = getLoadStoreType(E.Inst)->getScalarType();
853 assert(isPowerOf2_32(DL.getTypeSizeInBits(Ty)) &&
854 "Should have filtered out non-power-of-two elements in "
855 "collectEquivalenceClasses.");
856 }
857#endif
858
859 unsigned AS = getLoadStoreAddressSpace(I: C[0].Inst);
860 unsigned VecRegBytes = TTI.getLoadStoreVecRegBitWidth(AddrSpace: AS) / 8;
861
862 // For compile time reasons, we cache whether or not the superset
863 // of all candidate chains contains any extra loads/stores from earlier gap
864 // filling.
865 bool CandidateChainsMayContainExtraLoadsStores = any_of(
866 Range&: C, P: [this](const ChainElem &E) { return ExtraElements.contains(V: E.Inst); });
867
868 std::vector<Chain> Ret;
869 for (unsigned CBegin = 0; CBegin < C.size(); ++CBegin) {
870 // Find candidate chains of size not greater than the largest vector reg.
871 // These chains are over the closed interval [CBegin, CEnd].
872 SmallVector<std::pair<unsigned /*CEnd*/, unsigned /*SizeBytes*/>, 8>
873 CandidateChains;
    // Need to compute the size of every candidate chain from its beginning
    // because chain elements may overlap.
876 unsigned Sz = DL.getTypeStoreSize(Ty: getLoadStoreType(I: C[CBegin].Inst));
877 APInt PrevReadEnd = C[CBegin].OffsetFromLeader + Sz;
878 for (unsigned CEnd = CBegin + 1, Size = C.size(); CEnd < Size; ++CEnd) {
879 APInt ReadEnd = C[CEnd].OffsetFromLeader +
880 DL.getTypeStoreSize(Ty: getLoadStoreType(I: C[CEnd].Inst));
881 unsigned BytesAdded =
882 PrevReadEnd.sle(RHS: ReadEnd) ? (ReadEnd - PrevReadEnd).getSExtValue() : 0;
883 Sz += BytesAdded;
884 if (Sz > VecRegBytes)
885 break;
886 CandidateChains.emplace_back(Args&: CEnd, Args&: Sz);
887 PrevReadEnd = APIntOps::smax(A: PrevReadEnd, B: ReadEnd);
888 }
889
890 // Consider the longest chain first.
891 for (auto It = CandidateChains.rbegin(), End = CandidateChains.rend();
892 It != End; ++It) {
893 auto [CEnd, SizeBytes] = *It;
894 LLVM_DEBUG(
895 dbgs() << "LSV: splitChainByAlignment considering candidate chain ["
896 << *C[CBegin].Inst << " ... " << *C[CEnd].Inst << "]\n");
897
898 Type *VecElemTy = getChainElemTy(C);
      // Note: VecElemTy's size is a power of 2, but it might be less than one
      // byte. For example, we can vectorize 2 x <2 x i4> into <4 x i4>, in
      // which case VecElemTy is i4.
902 unsigned VecElemBits = DL.getTypeSizeInBits(Ty: VecElemTy);
903
904 // SizeBytes and VecElemBits are powers of 2, so they divide evenly.
905 assert((8 * SizeBytes) % VecElemBits == 0);
906 unsigned NumVecElems = 8 * SizeBytes / VecElemBits;
907 FixedVectorType *VecTy = FixedVectorType::get(ElementType: VecElemTy, NumElts: NumVecElems);
908 unsigned VF = 8 * VecRegBytes / VecElemBits;
909
910 // Check that TTI is happy with this vectorization factor.
911 unsigned TargetVF = GetVectorFactor(VF, VecElemBits,
912 VecElemBits * NumVecElems / 8, VecTy);
913 if (TargetVF != VF && TargetVF < NumVecElems) {
914 LLVM_DEBUG(
915 dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
916 "because TargetVF="
917 << TargetVF << " != VF=" << VF
918 << " and TargetVF < NumVecElems=" << NumVecElems << "\n");
919 continue;
920 }
921
922 // If we're loading/storing from an alloca, align it if possible.
923 //
924 // FIXME: We eagerly upgrade the alignment, regardless of whether TTI
925 // tells us this is beneficial. This feels a bit odd, but it matches
926 // existing tests. This isn't *so* bad, because at most we align to 4
927 // bytes (current value of StackAdjustedAlignment).
928 //
929 // FIXME: We will upgrade the alignment of the alloca even if it turns out
930 // we can't vectorize for some other reason.
931 Value *PtrOperand = getLoadStorePointerOperand(V: C[CBegin].Inst);
932 bool IsAllocaAccess = AS == DL.getAllocaAddrSpace() &&
933 isa<AllocaInst>(Val: PtrOperand->stripPointerCasts());
934 Align Alignment = getLoadStoreAlignment(I: C[CBegin].Inst);
935 Align PrefAlign = Align(StackAdjustedAlignment);
936 if (IsAllocaAccess && Alignment.value() % SizeBytes != 0 &&
937 accessIsAllowedAndFast(SizeBytes, AS, Alignment: PrefAlign, VecElemBits)) {
938 Align NewAlign = getOrEnforceKnownAlignment(
939 V: PtrOperand, PrefAlign, DL, CxtI: C[CBegin].Inst, AC: nullptr, DT: &DT);
940 if (NewAlign >= Alignment) {
941 LLVM_DEBUG(dbgs()
942 << "LSV: splitByChain upgrading alloca alignment from "
943 << Alignment.value() << " to " << NewAlign.value()
944 << "\n");
945 Alignment = NewAlign;
946 }
947 }
948
949 Chain ExtendingLoadsStores;
950 if (!accessIsAllowedAndFast(SizeBytes, AS, Alignment, VecElemBits)) {
951 // If we have a non-power-of-2 element count, attempt to extend the
952 // chain to the next power-of-2 if it makes the access allowed and
953 // fast.
954 bool AllowedAndFast = false;
955 if (NumVecElems < TargetVF && !isPowerOf2_32(Value: NumVecElems) &&
956 VecElemBits >= 8) {
957 // TargetVF may be a lot higher than NumVecElems,
958 // so only extend to the next power of 2.
959 assert(VecElemBits % 8 == 0);
960 unsigned VecElemBytes = VecElemBits / 8;
961 unsigned NewNumVecElems = PowerOf2Ceil(A: NumVecElems);
962 unsigned NewSizeBytes = VecElemBytes * NewNumVecElems;
963
964 assert(isPowerOf2_32(TargetVF) &&
965 "TargetVF expected to be a power of 2");
966 assert(NewNumVecElems <= TargetVF &&
967 "Should not extend past TargetVF");
968
969 LLVM_DEBUG(dbgs()
970 << "LSV: attempting to extend chain of " << NumVecElems
971 << " " << (IsLoadChain ? "loads" : "stores") << " to "
972 << NewNumVecElems << " elements\n");
973 bool IsLegalToExtend =
974 IsLoadChain ? TTI.isLegalMaskedLoad(
975 DataType: FixedVectorType::get(ElementType: VecElemTy, NumElts: NewNumVecElems),
976 Alignment, AddressSpace: AS, MaskKind: TTI::MaskKind::ConstantMask)
977 : TTI.isLegalMaskedStore(
978 DataType: FixedVectorType::get(ElementType: VecElemTy, NumElts: NewNumVecElems),
979 Alignment, AddressSpace: AS, MaskKind: TTI::MaskKind::ConstantMask);
980 // Only artificially increase the chain if it would be AllowedAndFast
981 // and if the resulting masked load/store will be legal for the
982 // target.
983 if (IsLegalToExtend &&
984 accessIsAllowedAndFast(SizeBytes: NewSizeBytes, AS, Alignment,
985 VecElemBits)) {
986 LLVM_DEBUG(dbgs()
987 << "LSV: extending " << (IsLoadChain ? "load" : "store")
988 << " chain of " << NumVecElems << " "
989 << (IsLoadChain ? "loads" : "stores")
990 << " with total byte size of " << SizeBytes << " to "
991 << NewNumVecElems << " "
992 << (IsLoadChain ? "loads" : "stores")
993 << " with total byte size of " << NewSizeBytes
994 << ", TargetVF=" << TargetVF << " \n");
995
996 // Create (NewNumVecElems - NumVecElems) extra elements.
997 // We are basing each extra element on CBegin, which means the
998 // offsets should be based on SizeBytes, which represents the offset
999 // from CBegin to the current end of the chain.
1000 unsigned ASPtrBits = DL.getIndexSizeInBits(AS);
1001 for (unsigned I = 0; I < (NewNumVecElems - NumVecElems); I++) {
1002 ChainElem NewElem = createExtraElementAfter(
1003 PrevElem: C[CBegin], Ty: VecElemTy,
1004 Offset: APInt(ASPtrBits, SizeBytes + I * VecElemBytes), Prefix: "Extend");
1005 ExtendingLoadsStores.push_back(Elt: NewElem);
1006 }
1007
1008 // Update the size and number of elements for upcoming checks.
1009 SizeBytes = NewSizeBytes;
1010 NumVecElems = NewNumVecElems;
1011 AllowedAndFast = true;
1012 }
1013 }
1014 if (!AllowedAndFast) {
1015 // We were not able to achieve legality by extending the chain.
1016 LLVM_DEBUG(dbgs()
1017 << "LSV: splitChainByAlignment discarding candidate chain "
1018 "because its alignment is not AllowedAndFast: "
1019 << Alignment.value() << "\n");
1020 continue;
1021 }
1022 }
1023
1024 if ((IsLoadChain &&
1025 !TTI.isLegalToVectorizeLoadChain(ChainSizeInBytes: SizeBytes, Alignment, AddrSpace: AS)) ||
1026 (!IsLoadChain &&
1027 !TTI.isLegalToVectorizeStoreChain(ChainSizeInBytes: SizeBytes, Alignment, AddrSpace: AS))) {
1028 LLVM_DEBUG(
1029 dbgs() << "LSV: splitChainByAlignment discarding candidate chain "
1030 "because !isLegalToVectorizeLoad/StoreChain.");
1031 continue;
1032 }
1033
1034 if (CandidateChainsMayContainExtraLoadsStores) {
1035 // If the candidate chain contains extra loads/stores from an earlier
1036 // optimization, confirm legality now. This filter is essential because
1037 // when filling gaps in splitChainByContiguity, we queried the API to
1038 // check that (for a given element type and address space) there *may*
1039 // have been a legal masked load/store we could possibly create. Now, we
1040 // need to check if the actual chain we ended up with is legal to turn
1041 // into a masked load/store. This is relevant for NVPTX, for example,
1042 // where a masked store is only legal if we have ended up with a 256-bit
1043 // vector.
1044 bool CurrCandContainsExtraLoadsStores = llvm::any_of(
1045 Range: ArrayRef<ChainElem>(C).slice(N: CBegin, M: CEnd - CBegin + 1),
1046 P: [this](const ChainElem &E) {
1047 return ExtraElements.contains(V: E.Inst);
1048 });
1049
1050 if (CurrCandContainsExtraLoadsStores &&
1051 (IsLoadChain ? !TTI.isLegalMaskedLoad(
1052 DataType: FixedVectorType::get(ElementType: VecElemTy, NumElts: NumVecElems),
1053 Alignment, AddressSpace: AS, MaskKind: TTI::MaskKind::ConstantMask)
1054 : !TTI.isLegalMaskedStore(
1055 DataType: FixedVectorType::get(ElementType: VecElemTy, NumElts: NumVecElems),
1056 Alignment, AddressSpace: AS, MaskKind: TTI::MaskKind::ConstantMask))) {
1057 LLVM_DEBUG(dbgs()
1058 << "LSV: splitChainByAlignment discarding candidate chain "
1059 "because it contains extra loads/stores that we cannot "
1060 "legally vectorize into a masked load/store \n");
1061 continue;
1062 }
1063 }
1064
1065 // Hooray, we can vectorize this chain!
1066 Chain &NewChain = Ret.emplace_back();
1067 for (unsigned I = CBegin; I <= CEnd; ++I)
1068 NewChain.emplace_back(Args&: C[I]);
1069 for (ChainElem E : ExtendingLoadsStores)
1070 NewChain.emplace_back(Args&: E);
1071 CBegin = CEnd; // Skip over the instructions we've added to the chain.
1072 break;
1073 }
1074 }
1075 return Ret;
1076}
1077
1078bool Vectorizer::vectorizeChain(Chain &C) {
1079 if (C.size() < 2)
1080 return false;
1081
1082 bool ChainContainsExtraLoadsStores = llvm::any_of(
1083 Range&: C, P: [this](const ChainElem &E) { return ExtraElements.contains(V: E.Inst); });
1084
  // If we are left with a two-element chain and one of the elements is an
  // extra element, there is nothing real to vectorize; bail out.
1087 if (C.size() == 2 && ChainContainsExtraLoadsStores)
1088 return false;
1089
1090 sortChainInOffsetOrder(C);
1091
1092 LLVM_DEBUG({
1093 dbgs() << "LSV: Vectorizing chain of " << C.size() << " instructions:\n";
1094 dumpChain(C);
1095 });
1096
1097 Type *VecElemTy = getChainElemTy(C);
1098 bool IsLoadChain = isa<LoadInst>(Val: C[0].Inst);
1099 unsigned AS = getLoadStoreAddressSpace(I: C[0].Inst);
1100 unsigned BytesAdded = DL.getTypeStoreSize(Ty: getLoadStoreType(I: &*C[0].Inst));
1101 APInt PrevReadEnd = C[0].OffsetFromLeader + BytesAdded;
1102 unsigned ChainBytes = BytesAdded;
1103 for (auto It = std::next(x: C.begin()), End = C.end(); It != End; ++It) {
1104 unsigned SzBytes = DL.getTypeStoreSize(Ty: getLoadStoreType(I: &*It->Inst));
1105 APInt ReadEnd = It->OffsetFromLeader + SzBytes;
1106 // Update ChainBytes considering possible overlap.
1107 BytesAdded =
1108 PrevReadEnd.sle(RHS: ReadEnd) ? (ReadEnd - PrevReadEnd).getSExtValue() : 0;
1109 ChainBytes += BytesAdded;
1110 PrevReadEnd = APIntOps::smax(A: PrevReadEnd, B: ReadEnd);
1111 }
1112
1113 assert(8 * ChainBytes % DL.getTypeSizeInBits(VecElemTy) == 0);
1114 // VecTy is a power of 2 and 1 byte at smallest, but VecElemTy may be smaller
1115 // than 1 byte (e.g. VecTy == <32 x i1>).
1116 unsigned NumElem = 8 * ChainBytes / DL.getTypeSizeInBits(Ty: VecElemTy);
1117 Type *VecTy = FixedVectorType::get(ElementType: VecElemTy, NumElts: NumElem);
1118
1119 Align Alignment = getLoadStoreAlignment(I: C[0].Inst);
1120 // If this is a load/store of an alloca, we might have upgraded the alloca's
1121 // alignment earlier. Get the new alignment.
1122 if (AS == DL.getAllocaAddrSpace()) {
1123 Alignment = std::max(
1124 a: Alignment,
1125 b: getOrEnforceKnownAlignment(V: getLoadStorePointerOperand(V: C[0].Inst),
1126 PrefAlign: MaybeAlign(), DL, CxtI: C[0].Inst, AC: nullptr, DT: &DT));
1127 }
1128
1129 // All elements of the chain must have the same scalar-type size.
1130#ifndef NDEBUG
1131 for (const ChainElem &E : C)
1132 assert(DL.getTypeStoreSize(getLoadStoreType(E.Inst)->getScalarType()) ==
1133 DL.getTypeStoreSize(VecElemTy));
1134#endif
1135
1136 Instruction *VecInst;
1137 if (IsLoadChain) {
1138 // Loads get hoisted to the location of the first load in the chain. We may
1139 // also need to hoist the (transitive) operands of the loads.
1140 Builder.SetInsertPoint(
1141 llvm::min_element(Range&: C, C: [](const auto &A, const auto &B) {
1142 return A.Inst->comesBefore(B.Inst);
1143 })->Inst);
1144
1145 // If the chain contains extra loads, we need to vectorize into a
1146 // masked load.
1147 if (ChainContainsExtraLoadsStores) {
1148 assert(TTI.isLegalMaskedLoad(VecTy, Alignment, AS,
1149 TTI::MaskKind::ConstantMask));
1150 Value *Mask = createMaskForExtraElements(C, VecTy: cast<FixedVectorType>(Val: VecTy));
1151 VecInst = Builder.CreateMaskedLoad(
1152 Ty: VecTy, Ptr: getLoadStorePointerOperand(V: C[0].Inst), Alignment, Mask);
1153 } else {
1154 // This can happen due to a chain of redundant loads.
1155 // In this case, just use the element-type, and avoid ExtractElement.
1156 if (NumElem == 1)
1157 VecTy = VecElemTy;
1158 // Chain is in offset order, so C[0] is the instr with the lowest offset,
1159 // i.e. the root of the vector.
1160 VecInst = Builder.CreateAlignedLoad(
1161 Ty: VecTy, Ptr: getLoadStorePointerOperand(V: C[0].Inst), Align: Alignment);
1162 }
1163
1164 for (const ChainElem &E : C) {
1165 Instruction *I = E.Inst;
1166 Value *V;
1167 Type *T = getLoadStoreType(I);
1168 unsigned EOffset =
1169 (E.OffsetFromLeader - C[0].OffsetFromLeader).getZExtValue();
1170 unsigned VecIdx = 8 * EOffset / DL.getTypeSizeInBits(Ty: VecElemTy);
1171 if (!VecTy->isVectorTy()) {
1172 V = VecInst;
1173 } else if (auto *VT = dyn_cast<FixedVectorType>(Val: T)) {
1174 auto Mask = llvm::to_vector<8>(
1175 Range: llvm::seq<int>(Begin: VecIdx, End: VecIdx + VT->getNumElements()));
1176 V = Builder.CreateShuffleVector(V: VecInst, Mask, Name: I->getName());
1177 } else {
1178 V = Builder.CreateExtractElement(Vec: VecInst, Idx: Builder.getInt32(C: VecIdx),
1179 Name: I->getName());
1180 }
1181 if (V->getType() != I->getType())
1182 V = Builder.CreateBitOrPointerCast(V, DestTy: I->getType());
1183 I->replaceAllUsesWith(V);
1184 }
1185
1186 // Finally, we need to reorder the instrs in the BB so that the (transitive)
1187 // operands of VecInst appear before it. To see why, suppose we have
1188 // vectorized the following code:
1189 //
1190 // ptr1 = gep a, 1
1191 // load1 = load i32 ptr1
1192 // ptr0 = gep a, 0
1193 // load0 = load i32 ptr0
1194 //
1195 // We will put the vectorized load at the location of the earliest load in
1196 // the BB, i.e. load1. We get:
1197 //
1198 // ptr1 = gep a, 1
1199 // loadv = load <2 x i32> ptr0
1200 // load0 = extractelement loadv, 0
1201 // load1 = extractelement loadv, 1
1202 // ptr0 = gep a, 0
1203 //
1204 // Notice that loadv uses ptr0, which is defined *after* it!
1205 reorder(I: VecInst);
1206 } else {
1207 // Stores get sunk to the location of the last store in the chain.
1208 Builder.SetInsertPoint(llvm::max_element(Range&: C, C: [](auto &A, auto &B) {
1209 return A.Inst->comesBefore(B.Inst);
1210 })->Inst);
1211
1212 // Build the vector to store.
1213 Value *Vec = PoisonValue::get(T: VecTy);
1214 auto InsertElem = [&](Value *V, unsigned VecIdx) {
1215 if (V->getType() != VecElemTy)
1216 V = Builder.CreateBitOrPointerCast(V, DestTy: VecElemTy);
1217 Vec = Builder.CreateInsertElement(Vec, NewElt: V, Idx: Builder.getInt32(C: VecIdx));
1218 };
1219 for (const ChainElem &E : C) {
1220 auto *I = cast<StoreInst>(Val: E.Inst);
1221 unsigned EOffset =
1222 (E.OffsetFromLeader - C[0].OffsetFromLeader).getZExtValue();
1223 unsigned VecIdx = 8 * EOffset / DL.getTypeSizeInBits(Ty: VecElemTy);
1224 if (FixedVectorType *VT =
1225 dyn_cast<FixedVectorType>(Val: getLoadStoreType(I))) {
1226 for (int J = 0, JE = VT->getNumElements(); J < JE; ++J) {
1227 InsertElem(Builder.CreateExtractElement(Vec: I->getValueOperand(),
1228 Idx: Builder.getInt32(C: J)),
1229 VecIdx++);
1230 }
1231 } else {
1232 InsertElem(I->getValueOperand(), VecIdx);
1233 }
1234 }
1235
    // If the chain contains extra stores, we need to vectorize into a
    // masked store.
1238 if (ChainContainsExtraLoadsStores) {
1239 assert(TTI.isLegalMaskedStore(Vec->getType(), Alignment, AS,
1240 TTI::MaskKind::ConstantMask));
1241 Value *Mask =
1242 createMaskForExtraElements(C, VecTy: cast<FixedVectorType>(Val: Vec->getType()));
1243 VecInst = Builder.CreateMaskedStore(
1244 Val: Vec, Ptr: getLoadStorePointerOperand(V: C[0].Inst), Alignment, Mask);
1245 } else {
1246 // Chain is in offset order, so C[0] is the instr with the lowest offset,
1247 // i.e. the root of the vector.
1248 VecInst = Builder.CreateAlignedStore(
1249 Val: Vec, Ptr: getLoadStorePointerOperand(V: C[0].Inst), Align: Alignment);
1250 }
1251 }
1252
1253 propagateMetadata(I: VecInst, C);
1254
1255 for (const ChainElem &E : C)
1256 ToErase.emplace_back(Args: E.Inst);
1257
1258 ++NumVectorInstructions;
1259 NumScalarsVectorized += C.size();
1260 return true;
1261}
1262
1263template <bool IsLoadChain>
1264bool Vectorizer::isSafeToMove(
1265 Instruction *ChainElem, Instruction *ChainBegin,
1266 const DenseMap<Instruction *, APInt /*OffsetFromLeader*/> &ChainOffsets,
1267 BatchAAResults &BatchAA) {
1268 LLVM_DEBUG(dbgs() << "LSV: isSafeToMove(" << *ChainElem << " -> "
1269 << *ChainBegin << ")\n");
1270
1271 assert(isa<LoadInst>(ChainElem) == IsLoadChain);
1272 if (ChainElem == ChainBegin)
1273 return true;
1274
1275 // Invariant loads can always be reordered; by definition they are not
1276 // clobbered by stores.
1277 if (isInvariantLoad(I: ChainElem))
1278 return true;
1279
1280 auto BBIt = std::next([&] {
1281 if constexpr (IsLoadChain)
1282 return BasicBlock::reverse_iterator(ChainElem);
1283 else
1284 return BasicBlock::iterator(ChainElem);
1285 }());
1286 auto BBItEnd = std::next([&] {
1287 if constexpr (IsLoadChain)
1288 return BasicBlock::reverse_iterator(ChainBegin);
1289 else
1290 return BasicBlock::iterator(ChainBegin);
1291 }());
1292
1293 const APInt &ChainElemOffset = ChainOffsets.at(Val: ChainElem);
1294 const unsigned ChainElemSize =
1295 DL.getTypeStoreSize(Ty: getLoadStoreType(I: ChainElem));
1296
1297 for (; BBIt != BBItEnd; ++BBIt) {
1298 Instruction *I = &*BBIt;
1299
1300 if (!I->mayReadOrWriteMemory())
1301 continue;
1302
1303 // Loads can be reordered with other loads.
1304 if (IsLoadChain && isa<LoadInst>(Val: I))
1305 continue;
1306
1307 // Stores can be sunk below invariant loads.
1308 if (!IsLoadChain && isInvariantLoad(I))
1309 continue;
1310
    // If I is in the chain, we can tell whether it aliases ChainElem by
    // checking what offset ChainElem accesses. This may be better than AA is
    // able to do.
1313 //
1314 // We should really only have duplicate offsets for stores (the duplicate
1315 // loads should be CSE'ed), but in case we have a duplicate load, we'll
1316 // split the chain so we don't have to handle this case specially.
1317 if (auto OffsetIt = ChainOffsets.find(Val: I); OffsetIt != ChainOffsets.end()) {
1318 // I and ChainElem overlap if:
1319 // - I and ChainElem have the same offset, OR
1320 // - I's offset is less than ChainElem's, but I touches past the
1321 // beginning of ChainElem, OR
1322 // - ChainElem's offset is less than I's, but ChainElem touches past the
1323 // beginning of I.
1324 const APInt &IOffset = OffsetIt->second;
1325 unsigned IElemSize = DL.getTypeStoreSize(Ty: getLoadStoreType(I));
1326 if (IOffset == ChainElemOffset ||
1327 (IOffset.sle(RHS: ChainElemOffset) &&
1328 (IOffset + IElemSize).sgt(RHS: ChainElemOffset)) ||
1329 (ChainElemOffset.sle(RHS: IOffset) &&
1330 (ChainElemOffset + ChainElemSize).sgt(RHS: OffsetIt->second))) {
1331 LLVM_DEBUG({
1332 // Double check that AA also sees this alias. If not, we probably
1333 // have a bug.
1334 ModRefInfo MR =
1335 BatchAA.getModRefInfo(I, MemoryLocation::get(ChainElem));
1336 assert(IsLoadChain ? isModSet(MR) : isModOrRefSet(MR));
1337 dbgs() << "LSV: Found alias in chain: " << *I << "\n";
1338 });
1339 return false; // We found an aliasing instruction; bail.
1340 }
1341
1342 continue; // We're confident there's no alias.
1343 }
1344
1345 LLVM_DEBUG(dbgs() << "LSV: Querying AA for " << *I << "\n");
1346 ModRefInfo MR = BatchAA.getModRefInfo(I, OptLoc: MemoryLocation::get(Inst: ChainElem));
1347 if (IsLoadChain ? isModSet(MRI: MR) : isModOrRefSet(MRI: MR)) {
1348 LLVM_DEBUG(dbgs() << "LSV: Found alias in chain:\n"
1349 << " Aliasing instruction:\n"
1350 << " " << *I << '\n'
1351 << " Aliased instruction and pointer:\n"
1352 << " " << *ChainElem << '\n'
1353 << " " << *getLoadStorePointerOperand(ChainElem)
1354 << '\n');
1355
1356 return false;
1357 }
1358 }
1359 return true;
1360}
1361
1362static bool checkNoWrapFlags(Instruction *I, bool Signed) {
1363 BinaryOperator *BinOpI = cast<BinaryOperator>(Val: I);
1364 return (Signed && BinOpI->hasNoSignedWrap()) ||
1365 (!Signed && BinOpI->hasNoUnsignedWrap());
1366}

static bool checkIfSafeAddSequence(const APInt &IdxDiff, Instruction *AddOpA,
                                   unsigned MatchingOpIdxA, Instruction *AddOpB,
                                   unsigned MatchingOpIdxB, bool Signed) {
  LLVM_DEBUG(dbgs() << "LSV: checkIfSafeAddSequence IdxDiff=" << IdxDiff
                    << ", AddOpA=" << *AddOpA << ", MatchingOpIdxA="
                    << MatchingOpIdxA << ", AddOpB=" << *AddOpB
                    << ", MatchingOpIdxB=" << MatchingOpIdxB
                    << ", Signed=" << Signed << "\n");
  // If AddOpA and AddOpB are both adds with NSW/NUW that share one operand, we
  // can guarantee that the transformation is safe if we can prove that AddOpA
  // won't overflow when IdxDiff is added to its other operand.
  // For example:
  //  %tmp7 = add nsw i32 %tmp2, %v0
  //  %tmp8 = sext i32 %tmp7 to i64
  //  ...
  //  %tmp11 = add nsw i32 %v0, 1
  //  %tmp12 = add nsw i32 %tmp2, %tmp11
  //  %tmp13 = sext i32 %tmp12 to i64
  //
  // Both %tmp7 and %tmp12 have the nsw flag and the first operand is %tmp2.
  // It's guaranteed that adding 1 to %tmp7 won't overflow because %tmp11 adds
  // 1 to %v0 and both %tmp11 and %tmp12 have the nsw flag.
  assert(AddOpA->getOpcode() == Instruction::Add &&
         AddOpB->getOpcode() == Instruction::Add &&
         checkNoWrapFlags(AddOpA, Signed) && checkNoWrapFlags(AddOpB, Signed));
  if (AddOpA->getOperand(MatchingOpIdxA) ==
      AddOpB->getOperand(MatchingOpIdxB)) {
    Value *OtherOperandA = AddOpA->getOperand(MatchingOpIdxA == 1 ? 0 : 1);
    Value *OtherOperandB = AddOpB->getOperand(MatchingOpIdxB == 1 ? 0 : 1);
    Instruction *OtherInstrA = dyn_cast<Instruction>(OtherOperandA);
    Instruction *OtherInstrB = dyn_cast<Instruction>(OtherOperandB);
    // Match `x +nsw/nuw y` and `x +nsw/nuw (y +nsw/nuw IdxDiff)`.
    if (OtherInstrB && OtherInstrB->getOpcode() == Instruction::Add &&
        checkNoWrapFlags(OtherInstrB, Signed) &&
        isa<ConstantInt>(OtherInstrB->getOperand(1))) {
      int64_t CstVal =
          cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
      if (OtherInstrB->getOperand(0) == OtherOperandA &&
          IdxDiff.getSExtValue() == CstVal)
        return true;
    }
    // Match `x +nsw/nuw (y +nsw/nuw -IdxDiff)` and `x +nsw/nuw y`.
    if (OtherInstrA && OtherInstrA->getOpcode() == Instruction::Add &&
        checkNoWrapFlags(OtherInstrA, Signed) &&
        isa<ConstantInt>(OtherInstrA->getOperand(1))) {
      int64_t CstVal =
          cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
      if (OtherInstrA->getOperand(0) == OtherOperandB &&
          IdxDiff.getSExtValue() == -CstVal)
        return true;
    }
    // Match `x +nsw/nuw (y +nsw/nuw c)` and
    // `x +nsw/nuw (y +nsw/nuw (c + IdxDiff))`.
    if (OtherInstrA && OtherInstrB &&
        OtherInstrA->getOpcode() == Instruction::Add &&
        OtherInstrB->getOpcode() == Instruction::Add &&
        checkNoWrapFlags(OtherInstrA, Signed) &&
        checkNoWrapFlags(OtherInstrB, Signed) &&
        isa<ConstantInt>(OtherInstrA->getOperand(1)) &&
        isa<ConstantInt>(OtherInstrB->getOperand(1))) {
      int64_t CstValA =
          cast<ConstantInt>(OtherInstrA->getOperand(1))->getSExtValue();
      int64_t CstValB =
          cast<ConstantInt>(OtherInstrB->getOperand(1))->getSExtValue();
      if (OtherInstrA->getOperand(0) == OtherInstrB->getOperand(0) &&
          IdxDiff.getSExtValue() == (CstValB - CstValA))
        return true;
    }
  }
  return false;
}

std::optional<APInt> Vectorizer::getConstantOffsetComplexAddrs(
    Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) {
  LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs PtrA=" << *PtrA
                    << " PtrB=" << *PtrB << " ContextInst=" << *ContextInst
                    << " Depth=" << Depth << "\n");
  auto *GEPA = dyn_cast<GetElementPtrInst>(PtrA);
  auto *GEPB = dyn_cast<GetElementPtrInst>(PtrB);
  if (!GEPA || !GEPB)
    return getConstantOffsetSelects(PtrA, PtrB, ContextInst, Depth);

  // Look through GEPs after checking they're the same except for the last
  // index.
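  //
  // For example (hypothetical IR), the two pointers
  //   %gepA = getelementptr inbounds i32, ptr %base, i64 %i
  //   %gepB = getelementptr inbounds i32, ptr %base, i64 %j
  // have identical operands except for the last index, so we only need to
  // reason about the difference between %i and %j.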
  if (GEPA->getNumOperands() != GEPB->getNumOperands() ||
      GEPA->getPointerOperand() != GEPB->getPointerOperand())
    return std::nullopt;
  gep_type_iterator GTIA = gep_type_begin(GEPA);
  gep_type_iterator GTIB = gep_type_begin(GEPB);
  for (unsigned I = 0, E = GEPA->getNumIndices() - 1; I < E; ++I) {
    if (GTIA.getOperand() != GTIB.getOperand())
      return std::nullopt;
    ++GTIA;
    ++GTIB;
  }

  Instruction *OpA = dyn_cast<Instruction>(GTIA.getOperand());
  Instruction *OpB = dyn_cast<Instruction>(GTIB.getOperand());
  if (!OpA || !OpB || OpA->getOpcode() != OpB->getOpcode() ||
      OpA->getType() != OpB->getType())
    return std::nullopt;

  uint64_t Stride = GTIA.getSequentialElementStride(DL);

  // Only look through a ZExt/SExt.
  if (!isa<SExtInst>(OpA) && !isa<ZExtInst>(OpA))
    return std::nullopt;

  bool Signed = isa<SExtInst>(OpA);

  // At this point A could be a function parameter, i.e. not an instruction.
  Value *ValA = OpA->getOperand(0);
  OpB = dyn_cast<Instruction>(OpB->getOperand(0));
  if (!OpB || ValA->getType() != OpB->getType())
    return std::nullopt;

  const SCEV *OffsetSCEVA = SE.getSCEV(ValA);
  const SCEV *OffsetSCEVB = SE.getSCEV(OpB);
  const SCEV *IdxDiffSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA);
  if (IdxDiffSCEV == SE.getCouldNotCompute())
    return std::nullopt;

  ConstantRange IdxDiffRange = SE.getSignedRange(IdxDiffSCEV);
  if (!IdxDiffRange.isSingleElement())
    return std::nullopt;
  APInt IdxDiff = *IdxDiffRange.getSingleElement();

  LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetComplexAddrs IdxDiff=" << IdxDiff
                    << "\n");

  // Now we need to prove that adding IdxDiff to ValA won't overflow.
  bool Safe = false;

  // First attempt: if OpB is an add with NSW/NUW, and OpB is IdxDiff added to
  // ValA, we're okay.
  if (OpB->getOpcode() == Instruction::Add &&
      isa<ConstantInt>(OpB->getOperand(1)) &&
      IdxDiff.sle(cast<ConstantInt>(OpB->getOperand(1))->getSExtValue()) &&
      checkNoWrapFlags(OpB, Signed))
    Safe = true;

  // Second attempt: check if we have eligible add NSW/NUW instruction
  // sequences.
  OpA = dyn_cast<Instruction>(ValA);
  if (!Safe && OpA && OpA->getOpcode() == Instruction::Add &&
      OpB->getOpcode() == Instruction::Add && checkNoWrapFlags(OpA, Signed) &&
      checkNoWrapFlags(OpB, Signed)) {
    // In the checks below, a matching operand in OpA and OpB is an operand
    // which is the same in those two instructions. Below we account for
    // possible orders of the operands of these add instructions.
    for (unsigned MatchingOpIdxA : {0, 1})
      for (unsigned MatchingOpIdxB : {0, 1})
        if (!Safe)
          Safe = checkIfSafeAddSequence(IdxDiff, OpA, MatchingOpIdxA, OpB,
                                        MatchingOpIdxB, Signed);
  }

  unsigned BitWidth = ValA->getType()->getScalarSizeInBits();

  // Third attempt:
  //
  // Assuming IdxDiff is positive: if all set bits of IdxDiff, and all higher
  // order bits other than the sign bit, are known to be zero in ValA, we can
  // add IdxDiff to it while guaranteeing no overflow of any sort.
  //
  // If IdxDiff is negative, do the same, but use OpB instead of ValA.
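  //
  // Illustrative example (hypothetical values): with IdxDiff = 4 (0b100), if
  // bit 2 and every higher bit of ValA are known to be zero, then ValA <= 3,
  // so ValA + 4 <= 7 and the addition cannot wrap in either the signed or the
  // unsigned sense.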
  if (!Safe) {
    // When computing known bits, use the GEPs as context instructions, since
    // they likely are in the same BB as the load/store.
    KnownBits Known(BitWidth);
    computeKnownBits((IdxDiff.sge(0) ? ValA : OpB), Known, DL, &AC, ContextInst,
                     &DT);
    APInt BitsAllowedToBeSet = Known.Zero.zext(IdxDiff.getBitWidth());
    if (Signed)
      BitsAllowedToBeSet.clearBit(BitWidth - 1);
    Safe = BitsAllowedToBeSet.uge(IdxDiff.abs());
  }

  if (Safe)
    return IdxDiff * Stride;
  return std::nullopt;
}

std::optional<APInt> Vectorizer::getConstantOffsetSelects(
    Value *PtrA, Value *PtrB, Instruction *ContextInst, unsigned Depth) {
  if (Depth++ == MaxDepth)
    return std::nullopt;

  if (auto *SelectA = dyn_cast<SelectInst>(PtrA)) {
    if (auto *SelectB = dyn_cast<SelectInst>(PtrB)) {
      if (SelectA->getCondition() != SelectB->getCondition())
        return std::nullopt;
      LLVM_DEBUG(dbgs() << "LSV: getConstantOffsetSelects, PtrA=" << *PtrA
                        << ", PtrB=" << *PtrB << ", ContextInst="
                        << *ContextInst << ", Depth=" << Depth << "\n");
      std::optional<APInt> TrueDiff = getConstantOffset(
          SelectA->getTrueValue(), SelectB->getTrueValue(), ContextInst, Depth);
      if (!TrueDiff)
        return std::nullopt;
      std::optional<APInt> FalseDiff =
          getConstantOffset(SelectA->getFalseValue(), SelectB->getFalseValue(),
                            ContextInst, Depth);
      if (TrueDiff == FalseDiff)
        return TrueDiff;
    }
  }
  return std::nullopt;
}

void Vectorizer::mergeEquivalenceClasses(EquivalenceClassMap &EQClasses) const {
  if (EQClasses.size() < 2) // There is nothing to merge.
    return;

  // The reduced key has all elements of the EqClassKey except the underlying
  // object. Check that EqClassKey has 4 elements and define the reduced key.
  static_assert(std::tuple_size_v<EqClassKey> == 4,
                "EqClassKey has changed - EqClassReducedKey needs changes too");
  using EqClassReducedKey =
      std::tuple<std::tuple_element_t<1, EqClassKey> /* AddrSpace */,
                 std::tuple_element_t<2, EqClassKey> /* Element size */,
                 std::tuple_element_t<3, EqClassKey> /* IsLoad */>;
  using ECReducedKeyToUnderlyingObjectMap =
      MapVector<EqClassReducedKey,
                SmallPtrSet<std::tuple_element_t<0, EqClassKey>, 4>>;

  // Form a map from the reduced key (without the underlying object) to the
  // underlying objects: 1 reduced key to many underlying objects, to form
  // groups of potentially merge-able equivalence classes.
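  //
  // For example (hypothetical values), i32 loads whose underlying objects are
  // %a and %b but which otherwise share {addrspace 1, 32-bit elements, IsLoad}
  // map to the same reduced key, so their two equivalence classes become
  // candidates for merging.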
  ECReducedKeyToUnderlyingObjectMap RedKeyToUOMap;
  bool FoundPotentiallyOptimizableEC = false;
  for (const auto &EC : EQClasses) {
    const auto &Key = EC.first;
    EqClassReducedKey RedKey{std::get<1>(Key), std::get<2>(Key),
                             std::get<3>(Key)};
    auto &UOMap = RedKeyToUOMap[RedKey];
    UOMap.insert(std::get<0>(Key));
    if (UOMap.size() > 1)
      FoundPotentiallyOptimizableEC = true;
  }
  if (!FoundPotentiallyOptimizableEC)
    return;

  LLVM_DEBUG({
    dbgs() << "LSV: mergeEquivalenceClasses: before merging:\n";
    for (const auto &EC : EQClasses) {
      dbgs() << "  Key: {" << EC.first << "}\n";
      for (const auto &Inst : EC.second)
        dbgs() << "    Inst: " << *Inst << '\n';
    }
  });
  LLVM_DEBUG({
    dbgs() << "LSV: mergeEquivalenceClasses: RedKeyToUOMap:\n";
    for (const auto &RedKeyToUO : RedKeyToUOMap) {
      dbgs() << "  Reduced key: {" << std::get<0>(RedKeyToUO.first) << ", "
             << std::get<1>(RedKeyToUO.first) << ", "
             << static_cast<int>(std::get<2>(RedKeyToUO.first)) << "} --> "
             << RedKeyToUO.second.size() << " underlying objects:\n";
      for (auto UObject : RedKeyToUO.second)
        dbgs() << "    " << *UObject << '\n';
    }
  });

  using UObjectToUObjectMap = DenseMap<const Value *, const Value *>;

  // Compute the ultimate targets for a set of underlying objects.
  auto GetUltimateTargets =
      [](SmallPtrSetImpl<const Value *> &UObjects) -> UObjectToUObjectMap {
    UObjectToUObjectMap IndirectionMap;
    for (const auto *UObject : UObjects) {
      const unsigned MaxLookupDepth = 1; // look for 1-level indirections only
      const auto *UltimateTarget = getUnderlyingObject(UObject, MaxLookupDepth);
      if (UltimateTarget != UObject)
        IndirectionMap[UObject] = UltimateTarget;
    }
    UObjectToUObjectMap UltimateTargetsMap;
    for (const auto *UObject : UObjects) {
      auto Target = UObject;
      auto It = IndirectionMap.find(Target);
      for (; It != IndirectionMap.end(); It = IndirectionMap.find(Target))
        Target = It->second;
      UltimateTargetsMap[UObject] = Target;
    }
    return UltimateTargetsMap;
  };
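  //
  // For example (hypothetical values): if getUnderlyingObject maps %p2 to %p1
  // and %p1 to %p0, the indirection chain is followed until it ends, so both
  // %p2 and %p1 get %p0 as their ultimate target.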

  // For each item in RedKeyToUOMap, if it has more than one underlying object,
  // try to merge the equivalence classes.
  for (auto &[RedKey, UObjects] : RedKeyToUOMap) {
    if (UObjects.size() < 2)
      continue;
    auto UTMap = GetUltimateTargets(UObjects);
    for (const auto &[UObject, UltimateTarget] : UTMap) {
      if (UObject == UltimateTarget)
        continue;

      EqClassKey KeyFrom{UObject, std::get<0>(RedKey), std::get<1>(RedKey),
                         std::get<2>(RedKey)};
      EqClassKey KeyTo{UltimateTarget, std::get<0>(RedKey), std::get<1>(RedKey),
                       std::get<2>(RedKey)};
      // The entry for KeyFrom is guaranteed to exist, unlike KeyTo. Thus,
      // request the reference to the instructions vector for KeyTo first.
      const auto &VecTo = EQClasses[KeyTo];
      const auto &VecFrom = EQClasses[KeyFrom];
      SmallVector<Instruction *, 8> MergedVec;
      std::merge(VecFrom.begin(), VecFrom.end(), VecTo.begin(), VecTo.end(),
                 std::back_inserter(MergedVec),
                 [](Instruction *A, Instruction *B) {
                   return A && B && A->comesBefore(B);
                 });
      EQClasses[KeyTo] = std::move(MergedVec);
      EQClasses.erase(KeyFrom);
    }
  }
  LLVM_DEBUG({
    dbgs() << "LSV: mergeEquivalenceClasses: after merging:\n";
    for (const auto &EC : EQClasses) {
      dbgs() << "  Key: {" << EC.first << "}\n";
      for (const auto &Inst : EC.second)
        dbgs() << "    Inst: " << *Inst << '\n';
    }
  });
}

EquivalenceClassMap
Vectorizer::collectEquivalenceClasses(BasicBlock::iterator Begin,
                                      BasicBlock::iterator End) {
  EquivalenceClassMap Ret;

  auto GetUnderlyingObject = [](const Value *Ptr) -> const Value * {
    const Value *ObjPtr = llvm::getUnderlyingObject(Ptr);
    if (const auto *Sel = dyn_cast<SelectInst>(ObjPtr)) {
      // The selects themselves are distinct instructions even if they share
      // the same condition and evaluate to consecutive pointers for the true
      // and false values of the condition. Therefore, using the selects
      // themselves for grouping would put consecutive accesses into different
      // lists; they would never even be checked for being consecutive, let
      // alone vectorized. Group by the select condition instead.
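      //
      // For example (hypothetical IR):
      //   %p = select i1 %c, ptr %x, ptr %y
      //   %q = select i1 %c, ptr %x4, ptr %y4   ; %x4/%y4 are %x/%y plus 4
      // Grouping both accesses under %c keeps them in the same equivalence
      // class, so they can later be recognized as consecutive.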
      return Sel->getCondition();
    }
    return ObjPtr;
  };

  for (Instruction &I : make_range(Begin, End)) {
    auto *LI = dyn_cast<LoadInst>(&I);
    auto *SI = dyn_cast<StoreInst>(&I);
    if (!LI && !SI)
      continue;

    if ((LI && !LI->isSimple()) || (SI && !SI->isSimple()))
      continue;

    if ((LI && !TTI.isLegalToVectorizeLoad(LI)) ||
        (SI && !TTI.isLegalToVectorizeStore(SI)))
      continue;

    Type *Ty = getLoadStoreType(&I);
    if (!VectorType::isValidElementType(Ty->getScalarType()))
      continue;

    // Skip weird non-byte sizes. They probably aren't worth the effort of
    // handling correctly.
    unsigned TySize = DL.getTypeSizeInBits(Ty);
    if ((TySize % 8) != 0)
      continue;
    // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain
    // functions currently use an integer type for the vectorized load/store
    // and do not support casting between that integer type and a vector of
    // pointers (e.g. i64 to <2 x i16*>).
    if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy())
      continue;

    Value *Ptr = getLoadStorePointerOperand(&I);
    unsigned AS = Ptr->getType()->getPointerAddressSpace();
    unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS);

    unsigned VF = VecRegSize / TySize;
    VectorType *VecTy = dyn_cast<VectorType>(Ty);

    // Only handle power-of-two sized elements.
    if ((!VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(Ty))) ||
        (VecTy && !isPowerOf2_32(DL.getTypeSizeInBits(VecTy->getScalarType()))))
      continue;

    // No point in looking at these if they're too big to vectorize.
    if (TySize > VecRegSize / 2 ||
        (VecTy && TTI.getLoadVectorFactor(VF, TySize, TySize / 8, VecTy) == 0))
      continue;

    Ret[{GetUnderlyingObject(Ptr), AS,
         DL.getTypeSizeInBits(getLoadStoreType(&I)->getScalarType()),
         /*IsLoad=*/LI != nullptr}]
        .emplace_back(&I);
  }

  mergeEquivalenceClasses(Ret);
  return Ret;
}

std::vector<Chain> Vectorizer::gatherChains(ArrayRef<Instruction *> Instrs) {
  if (Instrs.empty())
    return {};

  unsigned AS = getLoadStoreAddressSpace(Instrs[0]);
  unsigned ASPtrBits = DL.getIndexSizeInBits(AS);

#ifndef NDEBUG
  // Check that Instrs is in BB order and all have the same addr space.
  for (size_t I = 1; I < Instrs.size(); ++I) {
    assert(Instrs[I - 1]->comesBefore(Instrs[I]));
    assert(getLoadStoreAddressSpace(Instrs[I]) == AS);
  }
#endif

  // Machinery to build an MRU-hashtable of Chains.
  //
  // (Ideally this could be done with MapVector, but as currently implemented,
  // moving an element to the front of a MapVector is O(n).)
  struct InstrListElem : ilist_node<InstrListElem>,
                         std::pair<Instruction *, Chain> {
    explicit InstrListElem(Instruction *I)
        : std::pair<Instruction *, Chain>(I, {}) {}
  };
  struct InstrListElemDenseMapInfo {
    using PtrInfo = DenseMapInfo<InstrListElem *>;
    using IInfo = DenseMapInfo<Instruction *>;
    static InstrListElem *getEmptyKey() { return PtrInfo::getEmptyKey(); }
    static InstrListElem *getTombstoneKey() {
      return PtrInfo::getTombstoneKey();
    }
    static unsigned getHashValue(const InstrListElem *E) {
      return IInfo::getHashValue(E->first);
    }
    static bool isEqual(const InstrListElem *A, const InstrListElem *B) {
      if (A == getEmptyKey() || B == getEmptyKey())
        return A == getEmptyKey() && B == getEmptyKey();
      if (A == getTombstoneKey() || B == getTombstoneKey())
        return A == getTombstoneKey() && B == getTombstoneKey();
      return IInfo::isEqual(A->first, B->first);
    }
  };
  SpecificBumpPtrAllocator<InstrListElem> Allocator;
  simple_ilist<InstrListElem> MRU;
  DenseSet<InstrListElem *, InstrListElemDenseMapInfo> Chains;

  // Compare each instruction in `Instrs` to the leaders of the N most
  // recently-used chains. This limits the O(n^2) behavior of this pass while
  // also allowing us to build arbitrarily long chains.
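  //
  // Concretely, with MaxChainsToTry = 64 below, each instruction is compared
  // against at most the 64 most recently extended chains before it starts a
  // new chain of its own.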
  for (Instruction *I : Instrs) {
    constexpr int MaxChainsToTry = 64;

    bool MatchFound = false;
    auto ChainIter = MRU.begin();
    for (size_t J = 0; J < MaxChainsToTry && ChainIter != MRU.end();
         ++J, ++ChainIter) {
      if (std::optional<APInt> Offset = getConstantOffset(
              getLoadStorePointerOperand(ChainIter->first),
              getLoadStorePointerOperand(I),
              /*ContextInst=*/
              (ChainIter->first->comesBefore(I) ? I : ChainIter->first))) {
        // `Offset` might not have the expected number of bits, if e.g. AS has a
        // different number of bits than opaque pointers.
        ChainIter->second.emplace_back(I, Offset.value());
        // Move ChainIter to the front of the MRU list.
        MRU.remove(*ChainIter);
        MRU.push_front(*ChainIter);
        MatchFound = true;
        break;
      }
    }

    if (!MatchFound) {
      APInt ZeroOffset(ASPtrBits, 0);
      InstrListElem *E = new (Allocator.Allocate()) InstrListElem(I);
      E->second.emplace_back(I, ZeroOffset);
      MRU.push_front(*E);
      Chains.insert(E);
    }
  }

  std::vector<Chain> Ret;
  Ret.reserve(Chains.size());
  // Iterate over MRU rather than Chains so the order is deterministic.
  for (auto &E : MRU)
    if (E.second.size() > 1)
      Ret.emplace_back(std::move(E.second));
  return Ret;
}

std::optional<APInt> Vectorizer::getConstantOffset(Value *PtrA, Value *PtrB,
                                                   Instruction *ContextInst,
                                                   unsigned Depth) {
  LLVM_DEBUG(dbgs() << "LSV: getConstantOffset, PtrA=" << *PtrA
                    << ", PtrB=" << *PtrB << ", ContextInst=" << *ContextInst
                    << ", Depth=" << Depth << "\n");
  // We'll ultimately return a value of this bit width, even if computations
  // happen in a different width.
  unsigned OrigBitWidth = DL.getIndexTypeSizeInBits(PtrA->getType());
  APInt OffsetA(OrigBitWidth, 0);
  APInt OffsetB(OrigBitWidth, 0);
  PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
  PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB);
  unsigned NewPtrBitWidth = DL.getTypeStoreSizeInBits(PtrA->getType());
  if (NewPtrBitWidth != DL.getTypeStoreSizeInBits(PtrB->getType()))
    return std::nullopt;

  // If we have to shrink the pointer, stripAndAccumulateInBoundsConstantOffsets
  // should properly handle a possible overflow and the value should fit into
  // the smallest data type used in the cast/gep chain.
  assert(OffsetA.getSignificantBits() <= NewPtrBitWidth &&
         OffsetB.getSignificantBits() <= NewPtrBitWidth);

  OffsetA = OffsetA.sextOrTrunc(NewPtrBitWidth);
  OffsetB = OffsetB.sextOrTrunc(NewPtrBitWidth);
  if (PtrA == PtrB)
    return (OffsetB - OffsetA).sextOrTrunc(OrigBitWidth);

  // Try to compute B - A.
  const SCEV *DistScev = SE.getMinusSCEV(SE.getSCEV(PtrB), SE.getSCEV(PtrA));
  if (DistScev != SE.getCouldNotCompute()) {
    LLVM_DEBUG(dbgs() << "LSV: SCEV PtrB - PtrA =" << *DistScev << "\n");
    ConstantRange DistRange = SE.getSignedRange(DistScev);
    if (DistRange.isSingleElement()) {
      // Handle index width (the width of Dist) != pointer width (the width of
      // the Offset*s at this point).
      APInt Dist = DistRange.getSingleElement()->sextOrTrunc(NewPtrBitWidth);
      return (OffsetB - OffsetA + Dist).sextOrTrunc(OrigBitWidth);
    }
  }
  if (std::optional<APInt> Diff =
          getConstantOffsetComplexAddrs(PtrA, PtrB, ContextInst, Depth))
    return (OffsetB - OffsetA + Diff->sext(OffsetB.getBitWidth()))
        .sextOrTrunc(OrigBitWidth);
  return std::nullopt;
}

bool Vectorizer::accessIsAllowedAndFast(unsigned SizeBytes, unsigned AS,
                                        Align Alignment,
                                        unsigned VecElemBits) const {
  // Aligned vector accesses are ALWAYS faster than element-wise accesses.
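  // For example, a 16-byte access with 16-byte (or 32-byte) alignment passes
  // this check, while the same access with only 8-byte alignment does not and
  // falls through to the TTI query below.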
  if (Alignment.value() % SizeBytes == 0)
    return true;

  // Ask TTI whether misaligned accesses are faster as vector or element-wise.
  unsigned VectorizedSpeed = 0;
  bool AllowsMisaligned = TTI.allowsMisalignedMemoryAccesses(
      F.getContext(), SizeBytes * 8, AS, Alignment, &VectorizedSpeed);
  if (!AllowsMisaligned) {
    LLVM_DEBUG(
        dbgs() << "LSV: Access of " << SizeBytes << "B in addrspace " << AS
               << " with alignment " << Alignment.value()
               << " is misaligned, and therefore can't be vectorized.\n");
    return false;
  }

  unsigned ElementwiseSpeed = 0;
  TTI.allowsMisalignedMemoryAccesses(F.getContext(), VecElemBits, AS,
                                     Alignment, &ElementwiseSpeed);
  if (VectorizedSpeed < ElementwiseSpeed) {
    LLVM_DEBUG(dbgs() << "LSV: Access of " << SizeBytes << "B in addrspace "
                      << AS << " with alignment " << Alignment.value()
                      << " has relative speed " << VectorizedSpeed
                      << ", which is lower than the elementwise speed of "
                      << ElementwiseSpeed
                      << ". Therefore this access won't be vectorized.\n");
    return false;
  }
  return true;
}

ChainElem Vectorizer::createExtraElementAfter(const ChainElem &Prev, Type *Ty,
                                              APInt Offset, StringRef Prefix,
                                              Align Alignment) {
  Instruction *NewElement = nullptr;
  Builder.SetInsertPoint(Prev.Inst->getNextNode());
  if (LoadInst *PrevLoad = dyn_cast<LoadInst>(Prev.Inst)) {
    Value *NewGep = Builder.CreatePtrAdd(
        PrevLoad->getPointerOperand(), Builder.getInt(Offset), Prefix + "GEP");
    LLVM_DEBUG(dbgs() << "LSV: Extra GEP Created:\n" << *NewGep << "\n");
    NewElement = Builder.CreateAlignedLoad(Ty, NewGep, Alignment, Prefix);
  } else {
    StoreInst *PrevStore = cast<StoreInst>(Prev.Inst);

    Value *NewGep = Builder.CreatePtrAdd(
        PrevStore->getPointerOperand(), Builder.getInt(Offset), Prefix + "GEP");
    LLVM_DEBUG(dbgs() << "LSV: Extra GEP Created:\n" << *NewGep << "\n");
    NewElement =
        Builder.CreateAlignedStore(PoisonValue::get(Ty), NewGep, Alignment);
  }

  // Attach all metadata to the new element.
  // propagateMetadata will fold it into the final vector when applicable.
  NewElement->copyMetadata(*Prev.Inst);

  // Cache created elements for tracking and cleanup.
  ExtraElements.insert(NewElement);

  APInt NewOffsetFromLeader = Prev.OffsetFromLeader + Offset;
  LLVM_DEBUG(dbgs() << "LSV: Extra Element Created:\n"
                    << *NewElement
                    << " OffsetFromLeader: " << NewOffsetFromLeader << "\n");
  return ChainElem{NewElement, NewOffsetFromLeader};
}

Value *Vectorizer::createMaskForExtraElements(const ArrayRef<ChainElem> C,
                                              FixedVectorType *VecTy) {
  // Start each mask element as false.
  SmallVector<Constant *, 64> MaskElts(VecTy->getNumElements(),
                                       Builder.getInt1(false));
  // Iterate over the chain and set the corresponding mask element to true for
  // each element that is not an extra element.
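  //
  // For example (hypothetical chain), for a <4 x i32> vector where only the
  // elements at offsets 0 and 8 bytes come from real chain members, the
  // resulting mask is <i1 1, i1 0, i1 1, i1 0>.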
  for (const ChainElem &E : C) {
    if (ExtraElements.contains(E.Inst))
      continue;
    unsigned EOffset =
        (E.OffsetFromLeader - C[0].OffsetFromLeader).getZExtValue();
    unsigned VecIdx =
        8 * EOffset / DL.getTypeSizeInBits(VecTy->getScalarType());
    if (FixedVectorType *VT =
            dyn_cast<FixedVectorType>(getLoadStoreType(E.Inst)))
      for (unsigned J = 0; J < VT->getNumElements(); ++J)
        MaskElts[VecIdx + J] = Builder.getInt1(true);
    else
      MaskElts[VecIdx] = Builder.getInt1(true);
  }
  return ConstantVector::get(MaskElts);
}

void Vectorizer::deleteExtraElements() {
  for (auto *ExtraElement : ExtraElements) {
    if (isa<LoadInst>(ExtraElement)) {
      [[maybe_unused]] bool Deleted =
          RecursivelyDeleteTriviallyDeadInstructions(ExtraElement);
      assert(Deleted && "Extra Load should always be trivially dead");
    } else {
      // Unlike Extra Loads, Extra Stores won't be "dead", but should all be
      // deleted regardless. They will have either been combined into a masked
      // store, or will be left behind and need to be cleaned up.
      auto *PtrOperand = getLoadStorePointerOperand(ExtraElement);
      ExtraElement->eraseFromParent();
      RecursivelyDeleteTriviallyDeadInstructions(PtrOperand);
    }
  }

  ExtraElements.clear();
}