1//===- SeedCollector.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
10#include "llvm/Analysis/LoopAccessAnalysis.h"
11#include "llvm/Analysis/ValueTracking.h"
12#include "llvm/IR/Type.h"
13#include "llvm/SandboxIR/Instruction.h"
14#include "llvm/SandboxIR/Utils.h"
15#include "llvm/Support/Compiler.h"
16#include "llvm/Support/Debug.h"
17
18using namespace llvm;
19namespace llvm::sandboxir {
20
21static cl::opt<unsigned> SeedBundleSizeLimit(
22 "sbvec-seed-bundle-size-limit", cl::init(Val: 32), cl::Hidden,
23 cl::desc("Limit the size of the seed bundle to cap compilation time."));
24
25static cl::opt<unsigned> SeedGroupsLimit(
26 "sbvec-seed-groups-limit", cl::init(Val: 256), cl::Hidden,
27 cl::desc("Limit the number of collected seeds groups in a BB to "
28 "cap compilation time."));
29
30ArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
31 unsigned MaxVecRegBits,
32 bool ForcePowerOf2) {
33 // Use uint32_t here for compatibility with IsPowerOf2_32
34
35 // BitCount tracks the size of the working slice. From that we can tell
36 // when the working slice's size is a power-of-two and when it exceeds
37 // the legal size in MaxVecBits.
38 uint32_t BitCount = 0;
39 uint32_t NumElements = 0;
40 // Tracks the most recent slice where NumElements gave a power-of-2 BitCount
41 uint32_t NumElementsPowerOfTwo = 0;
42 uint32_t BitCountPowerOfTwo = 0;
43 // Can't start a slice with a used instruction.
44 assert(!isUsed(StartIdx) && "Expected unused at StartIdx");
45 for (Instruction *S : drop_begin(RangeOrContainer&: Seeds, N: StartIdx)) {
46 // Stop if this instruction is used. This needs to be done before
47 // getNumBits() because a "used" instruction may have been erased.
48 if (isUsed(Element: StartIdx + NumElements))
49 break;
50 uint32_t InstBits = Utils::getNumBits(I: S);
51 // Stop if adding it puts the slice over the limit.
52 if (BitCount + InstBits > MaxVecRegBits)
53 break;
54 NumElements++;
55 BitCount += InstBits;
56 if (ForcePowerOf2 && isPowerOf2_32(Value: BitCount)) {
57 NumElementsPowerOfTwo = NumElements;
58 BitCountPowerOfTwo = BitCount;
59 }
60 }
61 if (ForcePowerOf2) {
62 NumElements = NumElementsPowerOfTwo;
63 BitCount = BitCountPowerOfTwo;
64 }
65
66 // Return any non-empty slice
67 if (NumElements > 1) {
68 assert((!ForcePowerOf2 || isPowerOf2_32(BitCount)) &&
69 "Must be a power of two");
70 return ArrayRef<Instruction *>(&Seeds[StartIdx], NumElements);
71 }
72 return {};
73}
74
75template <typename LoadOrStoreT>
76SeedContainer::KeyT SeedContainer::getKey(LoadOrStoreT *LSI) const {
77 assert((isa<LoadInst>(LSI) || isa<StoreInst>(LSI)) &&
78 "Expected Load or Store!");
79 Value *Ptr = Utils::getMemInstructionBase(LSI);
80 Instruction::Opcode Op = LSI->getOpcode();
81 Type *Ty = Utils::getExpectedType(V: LSI);
82 if (auto *VTy = dyn_cast<VectorType>(Val: Ty))
83 Ty = VTy->getElementType();
84 return {Ptr, Ty, Op};
85}
86
87// Explicit instantiations
88template SeedContainer::KeyT
89SeedContainer::getKey<LoadInst>(LoadInst *LSI) const;
90template SeedContainer::KeyT
91SeedContainer::getKey<StoreInst>(StoreInst *LSI) const;
92
93bool SeedContainer::erase(Instruction *I) {
94 assert((isa<LoadInst>(I) || isa<StoreInst>(I)) && "Expected Load or Store!");
95 auto It = SeedLookupMap.find(Val: I);
96 if (It == SeedLookupMap.end())
97 return false;
98 SeedBundle *Bndl = It->second;
99 Bndl->setUsed(I);
100 return true;
101}
102
103template <typename LoadOrStoreT> void SeedContainer::insert(LoadOrStoreT *LSI) {
104 // Find the bundle containing seeds for this symbol and type-of-access.
105 auto &BundleVec = Bundles[getKey(LSI)];
106 // Fill this vector of bundles front to back so that only the last bundle in
107 // the vector may have available space. This avoids iteration to find one with
108 // space.
109 if (BundleVec.empty() || BundleVec.back()->size() == SeedBundleSizeLimit)
110 BundleVec.emplace_back(std::make_unique<MemSeedBundle<LoadOrStoreT>>(LSI));
111 else
112 BundleVec.back()->insert(LSI, SE);
113
114 SeedLookupMap[LSI] = BundleVec.back().get();
115}
116
117// Explicit instantiations
118template LLVM_EXPORT_TEMPLATE void SeedContainer::insert<LoadInst>(LoadInst *);
119template LLVM_EXPORT_TEMPLATE void
120SeedContainer::insert<StoreInst>(StoreInst *);
121
122#ifndef NDEBUG
123void SeedContainer::print(raw_ostream &OS) const {
124 for (const auto &Pair : Bundles) {
125 auto [I, Ty, Opc] = Pair.first;
126 const auto &SeedsVec = Pair.second;
127 std::string RefType = dyn_cast<LoadInst>(I) ? "Load"
128 : dyn_cast<StoreInst>(I) ? "Store"
129 : "Other";
130 OS << "[Inst=" << *I << " Ty=" << Ty << " " << RefType << "]\n";
131 for (const auto &SeedPtr : SeedsVec) {
132 SeedPtr->dump(OS);
133 OS << "\n";
134 }
135 }
136 OS << "\n";
137}
138
139LLVM_DUMP_METHOD void SeedContainer::dump() const { print(dbgs()); }
140#endif // NDEBUG
141
142template <typename LoadOrStoreT> static bool isValidMemSeed(LoadOrStoreT *LSI) {
143 if (!LSI->isSimple())
144 return false;
145 auto *Ty = Utils::getExpectedType(V: LSI);
146 // Omit types that are architecturally unvectorizable
147 if (Ty->isX86_FP80Ty() || Ty->isPPC_FP128Ty())
148 return false;
149 // Omit vector types without compile-time-known lane counts
150 if (isa<ScalableVectorType>(Ty))
151 return false;
152 if (auto *VTy = dyn_cast<FixedVectorType>(Ty))
153 return VectorType::isValidElementType(ElemTy: VTy->getElementType());
154 return VectorType::isValidElementType(ElemTy: Ty);
155}
156
157template bool isValidMemSeed<LoadInst>(LoadInst *LSI);
158template bool isValidMemSeed<StoreInst>(StoreInst *LSI);
159
160SeedCollector::SeedCollector(BasicBlock *BB, ScalarEvolution &SE,
161 bool CollectStores, bool CollectLoads)
162 : StoreSeeds(SE), LoadSeeds(SE), Ctx(BB->getContext()) {
163
164 if (!CollectStores && !CollectLoads)
165 return;
166
167 EraseCallbackID = Ctx.registerEraseInstrCallback(CB: [this](Instruction *I) {
168 if (auto SI = dyn_cast<StoreInst>(Val: I))
169 StoreSeeds.erase(I: SI);
170 else if (auto LI = dyn_cast<LoadInst>(Val: I))
171 LoadSeeds.erase(I: LI);
172 });
173
174 // Actually collect the seeds.
175 for (auto &I : *BB) {
176 if (StoreInst *SI = dyn_cast<StoreInst>(Val: &I))
177 if (CollectStores && isValidMemSeed(LSI: SI))
178 StoreSeeds.insert(LSI: SI);
179 if (LoadInst *LI = dyn_cast<LoadInst>(Val: &I))
180 if (CollectLoads && isValidMemSeed(LSI: LI))
181 LoadSeeds.insert(LSI: LI);
182 // Cap compilation time.
183 if (totalNumSeedGroups() > SeedGroupsLimit)
184 break;
185 }
186}
187
188SeedCollector::~SeedCollector() {
189 Ctx.unregisterEraseInstrCallback(ID: EraseCallbackID);
190}
191
192#ifndef NDEBUG
193void SeedCollector::print(raw_ostream &OS) const {
194 OS << "=== StoreSeeds ===\n";
195 StoreSeeds.print(OS);
196 OS << "=== LoadSeeds ===\n";
197 LoadSeeds.print(OS);
198}
199
200void SeedCollector::dump() const { print(dbgs()); }
201#endif
202
203} // namespace llvm::sandboxir
204