//===- LoadStoreVec.cpp - Vectorizer pass for short load-store chains -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/LoadStoreVec.h"
10#include "llvm/SandboxIR/Module.h"
11#include "llvm/SandboxIR/Region.h"
12#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
13#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
14#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
15
16namespace llvm {
17
18namespace sandboxir {
19
20std::optional<Type *> LoadStoreVec::canVectorize(ArrayRef<Instruction *> Bndl,
21 Scheduler &Sched) {
22 // Check if in the same BB.
23 if (LegalityAnalysis::differentBlock(Instrs: Bndl))
24 return std::nullopt;
25
26 // Check if instructions repeat.
27 if (!LegalityAnalysis::areUnique(Values: Bndl))
28 return std::nullopt;
29
30 // TODO: This is target-dependent.
31 // Don't mix integer with floating point.
32 bool IsFloat = false;
33 bool IsInteger = false;
34 for ([[maybe_unused]] auto *I : Bndl) {
35 if (Utils::getExpectedType(V: I)->getScalarType()->isFloatingPointTy())
36 IsFloat = true;
37 else
38 IsInteger = true;
39 }
40 if (IsFloat && IsInteger)
41 return std::nullopt;
42
43 // Check scheduling.
44 if (!Sched.trySchedule(Instrs: Bndl))
45 return std::nullopt;
46
47 return VecUtils::getCombinedVectorTypeFor(Bndl, DL: *DL);
48}
49
50void LoadStoreVec::tryEraseDeadInstrs(ArrayRef<Instruction *> Stores,
51 ArrayRef<Instruction *> Loads) {
52 SmallPtrSet<Instruction *, 8> DeadCandidates;
53 for (auto *SI : Stores) {
54 if (auto *PtrI =
55 dyn_cast<Instruction>(Val: cast<StoreInst>(Val: SI)->getPointerOperand()))
56 DeadCandidates.insert(Ptr: PtrI);
57 SI->eraseFromParent();
58 }
59 for (auto *LI : Loads) {
60 if (auto *PtrI =
61 dyn_cast<Instruction>(Val: cast<LoadInst>(Val: LI)->getPointerOperand()))
62 DeadCandidates.insert(Ptr: PtrI);
63 cast<LoadInst>(Val: LI)->eraseFromParent();
64 }
65 for (auto *PtrI : DeadCandidates)
66 if (!PtrI->hasNUsesOrMore(Num: 1))
67 PtrI->eraseFromParent();
68}
69
/// Pass entry point: tries to replace a bundle of consecutive scalar stores,
/// whose stored values are consecutive single-use scalar loads, with one
/// vector load feeding one vector store.
/// \returns true iff the region was modified.
bool LoadStoreVec::runOnRegion(Region &Rgn, const Analyses &A) {
  // The region's Aux vector carries the candidate bundle; presumably the
  // region builder only puts StoreInsts in it (cast below relies on that) —
  // TODO confirm.
  SmallVector<Instruction *, 8> Bndl(Rgn.getAux().begin(), Rgn.getAux().end());
  // Need at least two scalars to form a vector.
  if (Bndl.size() < 2)
    return false;
  Function &F = *Bndl[0]->getParent()->getParent();
  // Cache the DataLayout; also used later by canVectorize() via the DL member.
  DL = &F.getParent()->getDataLayout();
  auto &Ctx = F.getContext();
  Scheduler Sched(A.getAA(), Ctx);
  // The stores must access strictly consecutive memory (SCEV-based check).
  if (!VecUtils::areConsecutive<StoreInst, Instruction>(
          Bndl, SE&: A.getScalarEvolution(), DL: *DL))
    return false;
  // Legality + trial scheduling of the store bundle (mutates Sched state);
  // the returned vector type is recomputed below when actually needed.
  if (!canVectorize(Bndl, Sched))
    return false;

  // Collect the stored values; these are the candidate load operands.
  SmallVector<Value *, 4> Operands;
  Operands.reserve(N: Bndl.size());
  for (auto *I : Bndl) {
    auto *Op = cast<StoreInst>(Val: I)->getValueOperand();
    Operands.push_back(Elt: Op);
  }
  BasicBlock *BB = Bndl[0]->getParent();
  // TODO: For now we only support load operands.
  // TODO: For now we don't cross BBs.
  // TODO: For now don't vectorize if the loads have external uses.
  // Every operand must be a load, in the same BB as the stores, with the
  // store as its only user (hasNUsesOrMore(2) rejects multi-use loads).
  if (!all_of(Range&: Operands, P: [BB](Value *V) {
        auto *LI = dyn_cast<LoadInst>(Val: V);
        if (LI == nullptr)
          return false;
        if (LI->getParent() != BB)
          return false;
        if (LI->hasNUsesOrMore(Num: 2))
          return false;
        return true;
      }))
    return false;
  // TODO: Try to avoid the extra copy to an instruction vector.
  SmallVector<Instruction *, 8> Loads;
  Loads.reserve(N: Operands.size());
  for (Value *Op : Operands)
    Loads.push_back(Elt: cast<Instruction>(Val: Op));

  // The loads must also be consecutive in memory and schedulable.
  bool Consecutive = VecUtils::areConsecutive<LoadInst, Instruction>(
      Bndl: Loads, SE&: A.getScalarEvolution(), DL: *DL);
  if (!Consecutive)
    return false;
  if (!canVectorize(Bndl: Loads, Sched))
    return false;

  // Generate vector store and vector load
  Type *Ty = VecUtils::getCombinedVectorTypeFor(Bndl, DL: *DL);
  // The pointer of the first (lowest-address) load addresses the whole vector.
  Value *LdPtr = cast<LoadInst>(Val: Loads[0])->getPointerOperand();
  // TODO: Compute alignment.
  Align LdAlign(1);
  // Insert right after the bottom-most load so all scalar loads dominate it.
  auto LdWhereIt = std::next(x: VecUtils::getLowest(Instrs: Loads)->getIterator());
  auto *VecLd =
      LoadInst::create(Ty, Ptr: LdPtr, Align: LdAlign, Pos: LdWhereIt, Ctx, Name: "VecIinitL");

  Value *StPtr = cast<StoreInst>(Val: Bndl[0])->getPointerOperand();
  // TODO: Compute alignment.
  Align StAlign(1);
  // Insert right after the bottom-most store, before erasing the scalars.
  auto StWhereIt = std::next(x: VecUtils::getLowest(Instrs: Bndl)->getIterator());
  StoreInst::create(V: VecLd, Ptr: StPtr, Align: StAlign, Pos: StWhereIt, Ctx);

  // Drop the scalar stores/loads and any pointer instructions left dead.
  tryEraseDeadInstrs(Stores: Bndl, Loads);
  return true;
}
136
137} // namespace sandboxir
138
139} // namespace llvm
140