1//===- LoadStoreVec.cpp - Vectorizer pass short load-store chains ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/LoadStoreVec.h"
10#include "llvm/SandboxIR/Module.h"
11#include "llvm/SandboxIR/Region.h"
12#include "llvm/Support/CommandLine.h"
13#include "llvm/Support/InstructionCost.h"
14#include "llvm/Transforms/Vectorize/SandboxVectorizer/Debug.h"
15#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
16#include "llvm/Transforms/Vectorize/SandboxVectorizer/RegionWithScore.h"
17#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
18#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
19
20namespace llvm {
21
22extern cl::opt<int> CostThreshold; // Defined in TransactionAcceptOrRevert.cpp
23
24namespace sandboxir {
25
26#define DEBUG_PREFIX_LOCAL DEBUG_PREFIX "LoadStoreVec: "
27
28std::optional<Type *> LoadStoreVec::canVectorize(ArrayRef<Instruction *> Bndl,
29 Scheduler &Sched) {
30 // Check if in the same BB.
31 if (LegalityAnalysis::differentBlock(Instrs: Bndl))
32 return std::nullopt;
33
34 // Check if instructions repeat.
35 if (!LegalityAnalysis::areUnique(Values: Bndl))
36 return std::nullopt;
37
38 // Check scheduling.
39 if (!Sched.trySchedule(Instrs: Bndl))
40 return std::nullopt;
41
42 return VecUtils::getCombinedVectorTypeFor(Bndl, DL: *DL);
43}
44
45void LoadStoreVec::tryEraseDeadInstrs(ArrayRef<Instruction *> Stores,
46 ArrayRef<Value *> Operands) {
47 SmallPtrSet<Instruction *, 8> DeadCandidates;
48 for (auto *SI : Stores) {
49 if (auto *PtrI =
50 dyn_cast<Instruction>(Val: cast<StoreInst>(Val: SI)->getPointerOperand()))
51 DeadCandidates.insert(Ptr: PtrI);
52 SI->eraseFromParent();
53 }
54 for (auto *Op : Operands) {
55 if (auto *LI = dyn_cast<LoadInst>(Val: Op)) {
56 if (auto *PtrI =
57 dyn_cast<Instruction>(Val: cast<LoadInst>(Val: LI)->getPointerOperand()))
58 DeadCandidates.insert(Ptr: PtrI);
59 cast<LoadInst>(Val: LI)->eraseFromParent();
60 }
61 }
62 for (auto *PtrI : DeadCandidates)
63 if (!PtrI->hasNUsesOrMore(Num: 1))
64 PtrI->eraseFromParent();
65}
66
67bool LoadStoreVec::runOnRegion(Region &Rgn, const Analyses &A) {
68 SmallVector<Instruction *, 8> Bndl(Rgn.getAux().begin(), Rgn.getAux().end());
69 if (Bndl.size() < 2)
70 return false;
71 Function &F = *Bndl[0]->getParent()->getParent();
72 DL = &F.getParent()->getDataLayout();
73 auto &Ctx = F.getContext();
74 Scheduler Sched(A.getAA(), Ctx);
75 if (!VecUtils::areConsecutive<StoreInst, Instruction>(
76 Bndl, SE&: A.getScalarEvolution(), DL: *DL))
77 return false;
78 if (!canVectorize(Bndl, Sched))
79 return false;
80
81 const auto &SB = cast<RegionWithScore>(Val&: Rgn).getScoreboard();
82 InstructionCost CostBefore = SB.getAfterCost() - SB.getBeforeCost();
83
84 SmallVector<Value *, 4> Operands;
85 Operands.reserve(N: Bndl.size());
86 for (auto *I : Bndl) {
87 auto *Op = cast<StoreInst>(Val: I)->getValueOperand();
88 Operands.push_back(Elt: Op);
89 }
90 BasicBlock *BB = Bndl[0]->getParent();
91 // TODO: For now we only support load operands.
92 // TODO: For now we don't cross BBs.
93 // TODO: For now don't vectorize if the loads have external uses.
94 bool AllLoads = all_of(Range&: Operands, P: [BB](Value *V) {
95 auto *LI = dyn_cast<LoadInst>(Val: V);
96 if (LI == nullptr)
97 return false;
98 // TODO: For now we don't cross BBs.
99 if (LI->getParent() != BB)
100 return false;
101 if (LI->hasNUsesOrMore(Num: 2))
102 return false;
103 return true;
104 });
105 bool AllConstants =
106 all_of(Range&: Operands, P: [](Value *V) { return isa<Constant>(Val: V); });
107 if (!AllLoads && !AllConstants)
108 return false;
109
110 // Vectorizing mixed floats and integers with external uses may not be
111 // profitable on some targets, so save state here.
112 Ctx.save();
113
114 Value *VecOp = nullptr;
115 if (AllLoads) {
116 // TODO: Try to avoid the extra copy to an instruction vector.
117 SmallVector<Instruction *, 8> Loads;
118 Loads.reserve(N: Operands.size());
119 for (Value *Op : Operands)
120 Loads.push_back(Elt: cast<Instruction>(Val: Op));
121
122 bool Consecutive = VecUtils::areConsecutive<LoadInst, Instruction>(
123 Bndl: Loads, SE&: A.getScalarEvolution(), DL: *DL);
124 if (!Consecutive) {
125 Ctx.accept();
126 return false;
127 }
128 if (!canVectorize(Bndl: Loads, Sched)) {
129 Ctx.accept();
130 return false;
131 }
132
133 // Generate vector load.
134 Type *Ty = VecUtils::getCombinedVectorTypeFor(Bndl, DL: *DL);
135 Value *LdPtr = cast<LoadInst>(Val: Loads[0])->getPointerOperand();
136 // TODO: Compute alignment.
137 Align LdAlign(1);
138 auto LdWhereIt = std::next(x: VecUtils::getLowest(Instrs: Loads)->getIterator());
139 VecOp = LoadInst::create(Ty, Ptr: LdPtr, Align: LdAlign, Pos: LdWhereIt, Ctx, Name: "VecIinitL");
140 } else if (AllConstants) {
141 SmallVector<Constant *, 8> Constants;
142 Constants.reserve(N: Operands.size());
143 for (Value *Op : Operands) {
144 auto *COp = cast<Constant>(Val: Op);
145 if (auto *AggrCOp = dyn_cast<ConstantAggregate>(Val: COp)) {
146 // If the operand is a constant aggregate, then append all its elements.
147 for (Value *Elm : AggrCOp->operands())
148 Constants.push_back(Elt: cast<Constant>(Val: Elm));
149 } else if (auto *SeqCOp = dyn_cast<ConstantDataSequential>(Val: COp)) {
150 for (auto ElmIdx : seq<unsigned>(Size: SeqCOp->getNumElements()))
151 Constants.push_back(Elt: SeqCOp->getElementAsConstant(ElmIdx));
152 } else if (auto *Zero = dyn_cast<ConstantAggregateZero>(Val: COp)) {
153 auto *ZeroElm = Zero->getSequentialElement();
154 for ([[maybe_unused]] auto Cnt :
155 seq<unsigned>(Size: Zero->getElementCount().getFixedValue()))
156 Constants.push_back(Elt: ZeroElm);
157 } else if (isa<ConstantInt>(Val: COp) && isa<VectorType>(Val: COp->getType())) {
158 auto *Elm = ConstantInt::get(Ctx, V: cast<ConstantInt>(Val: COp)->getValue());
159 for ([[maybe_unused]] auto Cnt :
160 seq<unsigned>(Size: cast<VectorType>(Val: COp->getType())
161 ->getElementCount()
162 .getFixedValue()))
163 Constants.push_back(Elt: Elm);
164 } else if (isa<ConstantFP>(Val: COp) && isa<VectorType>(Val: COp->getType())) {
165 auto *Elm = ConstantFP::get(V: cast<ConstantFP>(Val: COp)->getValue(), Ctx);
166 for ([[maybe_unused]] auto Cnt :
167 seq<unsigned>(Size: cast<VectorType>(Val: COp->getType())
168 ->getElementCount()
169 .getFixedValue()))
170 Constants.push_back(Elt: Elm);
171 } else {
172 Constants.push_back(Elt: COp);
173 }
174 }
175 VecOp = ConstantVector::get(V: Constants);
176 }
177
178 // Generate vector store.
179 Value *StPtr = cast<StoreInst>(Val: Bndl[0])->getPointerOperand();
180 // TODO: Compute alignment.
181 Align StAlign(1);
182 auto StWhereIt = std::next(x: VecUtils::getLowest(Instrs: Bndl)->getIterator());
183 StoreInst::create(V: VecOp, Ptr: StPtr, Align: StAlign, Pos: StWhereIt, Ctx);
184
185 tryEraseDeadInstrs(Stores: Bndl, Operands);
186
187 // Check the cost.
188 InstructionCost CostAfter = SB.getAfterCost() - SB.getBeforeCost();
189 InstructionCost CostGain = CostAfter - CostBefore;
190 LLVM_DEBUG(dbgs() << DEBUG_PREFIX_LOCAL << "CostGain=" << CostGain
191 << " (After=" << CostAfter << " Before=" << CostBefore
192 << ")\n");
193 if (CostGain > CostThreshold) {
194 LLVM_DEBUG(dbgs() << DEBUG_PREFIX_LOCAL << "Not profitable, reverting.\n");
195 Ctx.revert();
196 return false;
197 }
198 LLVM_DEBUG(dbgs() << DEBUG_PREFIX_LOCAL << "Profitable accepting.\n");
199 Ctx.accept();
200 return true;
201}
202
203} // namespace sandboxir
204
205} // namespace llvm
206