1//===- Legality.h -----------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Legality checks for the Sandbox Vectorizer.
10//
11
12#ifndef LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
13#define LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
14
15#include "llvm/ADT/ArrayRef.h"
16#include "llvm/Analysis/ScalarEvolution.h"
17#include "llvm/IR/DataLayout.h"
18#include "llvm/Support/Casting.h"
19#include "llvm/Support/Compiler.h"
20#include "llvm/Support/raw_ostream.h"
21#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
22#include "llvm/Transforms/Vectorize/SandboxVectorizer/Scheduler.h"
23
24namespace llvm::sandboxir {
25
26class LegalityAnalysis;
27class Value;
28class InstrMaps;
29
30class ShuffleMask {
31public:
32 using IndicesVecT = SmallVector<int, 8>;
33
34private:
35 IndicesVecT Indices;
36
37public:
38 ShuffleMask(SmallVectorImpl<int> &&Indices) : Indices(std::move(Indices)) {}
39 ShuffleMask(std::initializer_list<int> Indices) : Indices(Indices) {}
40 explicit ShuffleMask(ArrayRef<int> Indices) : Indices(Indices) {}
41 operator ArrayRef<int>() const { return Indices; }
42 /// Creates and returns an identity shuffle mask of size \p Sz.
43 /// For example if Sz == 4 the returned mask is {0, 1, 2, 3}.
44 static ShuffleMask getIdentity(unsigned Sz) {
45 IndicesVecT Indices;
46 Indices.reserve(N: Sz);
47 llvm::append_range(C&: Indices, R: seq<int>(Begin: 0, End: (int)Sz));
48 return ShuffleMask(std::move(Indices));
49 }
50 /// \Returns true if the mask is a perfect identity mask with consecutive
51 /// indices, i.e., performs no lane shuffling, like 0,1,2,3...
52 bool isIdentity() const {
53 for (auto [Idx, Elm] : enumerate(First: Indices)) {
54 if ((int)Idx != Elm)
55 return false;
56 }
57 return true;
58 }
59 bool operator==(const ShuffleMask &Other) const {
60 return Indices == Other.Indices;
61 }
62 bool operator!=(const ShuffleMask &Other) const { return !(*this == Other); }
63 size_t size() const { return Indices.size(); }
64 int operator[](int Idx) const { return Indices[Idx]; }
65 using const_iterator = IndicesVecT::const_iterator;
66 const_iterator begin() const { return Indices.begin(); }
67 const_iterator end() const { return Indices.end(); }
68#ifndef NDEBUG
69 friend raw_ostream &operator<<(raw_ostream &OS, const ShuffleMask &Mask) {
70 Mask.print(OS);
71 return OS;
72 }
73 void print(raw_ostream &OS) const {
74 interleave(Indices, OS, [&OS](auto Elm) { OS << Elm; }, ",");
75 }
76 LLVM_DUMP_METHOD void dump() const;
77#endif
78};
79
/// Identifies the concrete LegalityResult subclass. Each subclass's classof()
/// compares getSubclassID() against its own enumerator, which is what the
/// LLVM-style isa<>/cast<> on LegalityResult rely on.
enum class LegalityResultID {
  Pack,                    ///< Collect scalar values.
  Widen,                   ///< Vectorize by combining scalars to a vector.
  DiamondReuse,            ///< Don't generate new code, reuse existing vector.
  DiamondReuseWithShuffle, ///< Reuse the existing vector but add a shuffle.
  DiamondReuseMultiInput,  ///< Reuse more than one vector and/or scalars.
};
87
/// The reason for vectorizing or not vectorizing.
/// Stored by LegalityResultWithReason (and therefore by Pack) and printed by
/// ToStr::getVecReason() in debug builds.
// NOTE(review): the per-enumerator notes below are inferred from the names;
// the checks that produce each reason live outside this header.
enum class ResultReason : int {
  NotInstructions,       ///< Presumably: bundle has non-instruction values.
  DiffOpcodes,           ///< Presumably: opcodes differ within the bundle.
  DiffTypes,             ///< Presumably: types differ within the bundle.
  DiffMathFlags,         ///< Presumably: math flags differ.
  DiffWrapFlags,         ///< Presumably: wrap flags differ.
  DiffBBs,               ///< Presumably: values are in different blocks.
  RepeatedInstrs,        ///< Presumably: an instruction appears twice.
  NotConsecutive,        ///< Presumably: accesses are not consecutive.
  CantSchedule,          ///< Presumably: the scheduler rejected the bundle.
  Unimplemented,         ///< Presumably: case not supported yet.
  Infeasible,            ///< Presumably: vectorization is not feasible.
  ForcePackForDebugging, ///< See getForcedPackForDebugging().
};
103
104#ifndef NDEBUG
105struct ToStr {
106 static const char *getLegalityResultID(LegalityResultID ID) {
107 switch (ID) {
108 case LegalityResultID::Pack:
109 return "Pack";
110 case LegalityResultID::Widen:
111 return "Widen";
112 case LegalityResultID::DiamondReuse:
113 return "DiamondReuse";
114 case LegalityResultID::DiamondReuseWithShuffle:
115 return "DiamondReuseWithShuffle";
116 case LegalityResultID::DiamondReuseMultiInput:
117 return "DiamondReuseMultiInput";
118 }
119 llvm_unreachable("Unknown LegalityResultID enum");
120 }
121
122 static const char *getVecReason(ResultReason Reason) {
123 switch (Reason) {
124 case ResultReason::NotInstructions:
125 return "NotInstructions";
126 case ResultReason::DiffOpcodes:
127 return "DiffOpcodes";
128 case ResultReason::DiffTypes:
129 return "DiffTypes";
130 case ResultReason::DiffMathFlags:
131 return "DiffMathFlags";
132 case ResultReason::DiffWrapFlags:
133 return "DiffWrapFlags";
134 case ResultReason::DiffBBs:
135 return "DiffBBs";
136 case ResultReason::RepeatedInstrs:
137 return "RepeatedInstrs";
138 case ResultReason::NotConsecutive:
139 return "NotConsecutive";
140 case ResultReason::CantSchedule:
141 return "CantSchedule";
142 case ResultReason::Unimplemented:
143 return "Unimplemented";
144 case ResultReason::Infeasible:
145 return "Infeasible";
146 case ResultReason::ForcePackForDebugging:
147 return "ForcePackForDebugging";
148 }
149 llvm_unreachable("Unknown ResultReason enum");
150 }
151};
152#endif // NDEBUG
153
154/// The legality outcome is represented by a class rather than an enum class
155/// because in some cases the legality checks are expensive and look for a
156/// particular instruction that can be passed along to the vectorizer to avoid
157/// repeating the same expensive computation.
158class LegalityResult {
159protected:
160 LegalityResultID ID;
161 /// Only Legality can create LegalityResults.
162 LegalityResult(LegalityResultID ID) : ID(ID) {}
163 friend class LegalityAnalysis;
164
165 /// We shouldn't need copies.
166 LegalityResult(const LegalityResult &) = delete;
167 LegalityResult &operator=(const LegalityResult &) = delete;
168
169public:
170 virtual ~LegalityResult() = default;
171 LegalityResultID getSubclassID() const { return ID; }
172#ifndef NDEBUG
173 virtual void print(raw_ostream &OS) const {
174 OS << ToStr::getLegalityResultID(ID);
175 }
176 LLVM_DUMP_METHOD void dump() const;
177 friend raw_ostream &operator<<(raw_ostream &OS, const LegalityResult &LR) {
178 LR.print(OS);
179 return OS;
180 }
181#endif // NDEBUG
182};
183
184/// Base class for results with reason.
185class LegalityResultWithReason : public LegalityResult {
186 [[maybe_unused]] ResultReason Reason;
187 LegalityResultWithReason(LegalityResultID ID, ResultReason Reason)
188 : LegalityResult(ID), Reason(Reason) {}
189 friend class Pack; // For constructor.
190
191public:
192 ResultReason getReason() const { return Reason; }
193#ifndef NDEBUG
194 void print(raw_ostream &OS) const override {
195 LegalityResult::print(OS);
196 OS << " Reason: " << ToStr::getVecReason(Reason);
197 }
198#endif
199};
200
201class Widen final : public LegalityResult {
202 friend class LegalityAnalysis;
203 Widen() : LegalityResult(LegalityResultID::Widen) {}
204
205public:
206 static bool classof(const LegalityResult *From) {
207 return From->getSubclassID() == LegalityResultID::Widen;
208 }
209};
210
211class DiamondReuse final : public LegalityResult {
212 friend class LegalityAnalysis;
213 Action *Vec;
214 DiamondReuse(Action *Vec)
215 : LegalityResult(LegalityResultID::DiamondReuse), Vec(Vec) {}
216
217public:
218 static bool classof(const LegalityResult *From) {
219 return From->getSubclassID() == LegalityResultID::DiamondReuse;
220 }
221 Action *getVector() const { return Vec; }
222};
223
224class DiamondReuseWithShuffle final : public LegalityResult {
225 friend class LegalityAnalysis;
226 Action *Vec;
227 ShuffleMask Mask;
228 DiamondReuseWithShuffle(Action *Vec, const ShuffleMask &Mask)
229 : LegalityResult(LegalityResultID::DiamondReuseWithShuffle), Vec(Vec),
230 Mask(Mask) {}
231
232public:
233 static bool classof(const LegalityResult *From) {
234 return From->getSubclassID() == LegalityResultID::DiamondReuseWithShuffle;
235 }
236 Action *getVector() const { return Vec; }
237 const ShuffleMask &getMask() const { return Mask; }
238};
239
240class Pack final : public LegalityResultWithReason {
241 Pack(ResultReason Reason)
242 : LegalityResultWithReason(LegalityResultID::Pack, Reason) {}
243 friend class LegalityAnalysis; // For constructor.
244
245public:
246 static bool classof(const LegalityResult *From) {
247 return From->getSubclassID() == LegalityResultID::Pack;
248 }
249};
250
251/// Describes how to collect the values needed by each lane.
252class CollectDescr {
253public:
254 /// Describes how to get a value element. If the value is a vector then it
255 /// also provides the index to extract it from.
256 class ExtractElementDescr {
257 PointerUnion<Action *, Value *> V = nullptr;
258 /// The index in `V` that the value can be extracted from.
259 int ExtractIdx = 0;
260
261 public:
262 ExtractElementDescr(Action *V, int ExtractIdx)
263 : V(V), ExtractIdx(ExtractIdx) {}
264 ExtractElementDescr(Value *V) : V(V) {}
265 Action *getValue() const { return cast<Action *>(Val: V); }
266 Value *getScalar() const { return cast<Value *>(Val: V); }
267 bool needsExtract() const { return isa<Action *>(Val: V); }
268 int getExtractIdx() const { return ExtractIdx; }
269 };
270
271 using DescrVecT = SmallVector<ExtractElementDescr, 4>;
272 DescrVecT Descrs;
273
274public:
275 CollectDescr(SmallVectorImpl<ExtractElementDescr> &&Descrs)
276 : Descrs(std::move(Descrs)) {}
277 /// If all elements come from a single vector input, then return that vector
278 /// and also the shuffle mask required to get them in order.
279 std::optional<std::pair<Action *, ShuffleMask>> getSingleInput() const {
280 const auto &Descr0 = *Descrs.begin();
281 if (!Descr0.needsExtract())
282 return std::nullopt;
283 auto *V0 = Descr0.getValue();
284 ShuffleMask::IndicesVecT MaskIndices;
285 MaskIndices.push_back(Elt: Descr0.getExtractIdx());
286 for (const auto &Descr : drop_begin(RangeOrContainer: Descrs)) {
287 if (!Descr.needsExtract())
288 return std::nullopt;
289 if (Descr.getValue() != V0)
290 return std::nullopt;
291 MaskIndices.push_back(Elt: Descr.getExtractIdx());
292 }
293 return std::make_pair(x&: V0, y: ShuffleMask(std::move(MaskIndices)));
294 }
295 bool hasVectorInputs() const {
296 return any_of(Range: Descrs, P: [](const auto &D) { return D.needsExtract(); });
297 }
298 const SmallVector<ExtractElementDescr, 4> &getDescrs() const {
299 return Descrs;
300 }
301};
302
303class DiamondReuseMultiInput final : public LegalityResult {
304 friend class LegalityAnalysis;
305 CollectDescr Descr;
306 DiamondReuseMultiInput(CollectDescr &&Descr)
307 : LegalityResult(LegalityResultID::DiamondReuseMultiInput),
308 Descr(std::move(Descr)) {}
309
310public:
311 static bool classof(const LegalityResult *From) {
312 return From->getSubclassID() == LegalityResultID::DiamondReuseMultiInput;
313 }
314 const CollectDescr &getCollectDescr() const { return Descr; }
315};
316
317/// Performs the legality analysis and returns a LegalityResult object.
318class LegalityAnalysis {
319 Scheduler Sched;
320 /// Owns the legality result objects created by createLegalityResult().
321 SmallVector<std::unique_ptr<LegalityResult>> ResultPool;
322 /// Checks opcodes, types and other IR-specifics and returns a ResultReason
323 /// object if not vectorizable, or nullptr otherwise.
324 std::optional<ResultReason>
325 notVectorizableBasedOnOpcodesAndTypes(ArrayRef<Value *> Bndl);
326
327 ScalarEvolution &SE;
328 const DataLayout &DL;
329 InstrMaps &IMaps;
330
331 /// Finds how we can collect the values in \p Bndl from the vectorized or
332 /// non-vectorized code. It returns a map of the value we should extract from
333 /// and the corresponding shuffle mask we need to use.
334 CollectDescr getHowToCollectValues(ArrayRef<Value *> Bndl) const;
335
336public:
337 LegalityAnalysis(AAResults &AA, ScalarEvolution &SE, const DataLayout &DL,
338 Context &Ctx, InstrMaps &IMaps)
339 : Sched(AA, Ctx), SE(SE), DL(DL), IMaps(IMaps) {}
340 /// A LegalityResult factory.
341 template <typename ResultT, typename... ArgsT>
342 ResultT &createLegalityResult(ArgsT &&...Args) {
343 ResultPool.push_back(
344 std::unique_ptr<ResultT>(new ResultT(std::move(Args)...)));
345 return cast<ResultT>(*ResultPool.back());
346 }
347 /// Checks if it's legal to vectorize the instructions in \p Bndl.
348 /// \Returns a LegalityResult object owned by LegalityAnalysis.
349 /// \p SkipScheduling skips the scheduler check and is only meant for testing.
350 // TODO: Try to remove the SkipScheduling argument by refactoring the tests.
351 LLVM_ABI const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
352 bool SkipScheduling = false);
353 /// \Returns a Pack with reason 'ForcePackForDebugging'.
354 const LegalityResult &getForcedPackForDebugging() {
355 return createLegalityResult<Pack>(Args: ResultReason::ForcePackForDebugging);
356 }
357 LLVM_ABI void clear();
358};
359
360} // namespace llvm::sandboxir
361
362#endif // LLVM_TRANSFORMS_VECTORIZE_SANDBOXVECTORIZER_LEGALITY_H
363