1//===- Legality.cpp -------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h"
10#include "llvm/SandboxIR/Instruction.h"
11#include "llvm/SandboxIR/Operator.h"
12#include "llvm/SandboxIR/Utils.h"
13#include "llvm/SandboxIR/Value.h"
14#include "llvm/Support/Debug.h"
15#include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h"
16#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"
17
18namespace llvm::sandboxir {
19
20#ifndef NDEBUG
21void ShuffleMask::dump() const {
22 print(dbgs());
23 dbgs() << "\n";
24}
25
26void LegalityResult::dump() const {
27 print(dbgs());
28 dbgs() << "\n";
29}
30#endif // NDEBUG
31
32std::optional<ResultReason>
33LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes(
34 ArrayRef<Value *> Bndl) {
35 auto *I0 = cast<Instruction>(Val: Bndl[0]);
36 auto Opcode = I0->getOpcode();
37 // If they have different opcodes, then we cannot form a vector (for now).
38 if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [Opcode](Value *V) {
39 return cast<Instruction>(Val: V)->getOpcode() != Opcode;
40 }))
41 return ResultReason::DiffOpcodes;
42
43 // If not the same scalar type, Pack. This will accept scalars and vectors as
44 // long as the element type is the same.
45 Type *ElmTy0 = VecUtils::getElementType(Ty: Utils::getExpectedType(V: I0));
46 if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [ElmTy0](Value *V) {
47 return VecUtils::getElementType(Ty: Utils::getExpectedType(V)) != ElmTy0;
48 }))
49 return ResultReason::DiffTypes;
50
51 // TODO: Allow vectorization of instrs with different flags as long as we
52 // change them to the least common one.
53 // For now pack if differnt FastMathFlags.
54 if (isa<FPMathOperator>(Val: I0)) {
55 FastMathFlags FMF0 = cast<Instruction>(Val: Bndl[0])->getFastMathFlags();
56 if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [FMF0](auto *V) {
57 return cast<Instruction>(V)->getFastMathFlags() != FMF0;
58 }))
59 return ResultReason::DiffMathFlags;
60 }
61
62 // TODO: Allow vectorization by using common flags.
63 // For now Pack if they don't have the same wrap flags.
64 bool CanHaveWrapFlags =
65 isa<OverflowingBinaryOperator>(Val: I0) || isa<TruncInst>(Val: I0);
66 if (CanHaveWrapFlags) {
67 bool NUW0 = I0->hasNoUnsignedWrap();
68 bool NSW0 = I0->hasNoSignedWrap();
69 if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [NUW0, NSW0](auto *V) {
70 return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 ||
71 cast<Instruction>(V)->hasNoSignedWrap() != NSW0;
72 })) {
73 return ResultReason::DiffWrapFlags;
74 }
75 }
76
77 // Now we need to do further checks for specific opcodes.
78 switch (Opcode) {
79 case Instruction::Opcode::ZExt:
80 case Instruction::Opcode::SExt:
81 case Instruction::Opcode::FPToUI:
82 case Instruction::Opcode::FPToSI:
83 case Instruction::Opcode::FPExt:
84 case Instruction::Opcode::PtrToInt:
85 case Instruction::Opcode::IntToPtr:
86 case Instruction::Opcode::SIToFP:
87 case Instruction::Opcode::UIToFP:
88 case Instruction::Opcode::Trunc:
89 case Instruction::Opcode::FPTrunc:
90 case Instruction::Opcode::BitCast: {
91 // We have already checked that they are of the same opcode.
92 assert(all_of(Bndl,
93 [Opcode](Value *V) {
94 return cast<Instruction>(V)->getOpcode() == Opcode;
95 }) &&
96 "Different opcodes, should have early returned!");
97 // But for these opcodes we should also check the operand type.
98 Type *FromTy0 = Utils::getExpectedType(V: I0->getOperand(OpIdx: 0));
99 if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [FromTy0](Value *V) {
100 return Utils::getExpectedType(V: cast<User>(Val: V)->getOperand(OpIdx: 0)) !=
101 FromTy0;
102 }))
103 return ResultReason::DiffTypes;
104 return std::nullopt;
105 }
106 case Instruction::Opcode::FCmp:
107 case Instruction::Opcode::ICmp: {
108 // We need the same predicate..
109 auto Pred0 = cast<CmpInst>(Val: I0)->getPredicate();
110 bool Same = all_of(Range&: Bndl, P: [Pred0](Value *V) {
111 return cast<CmpInst>(Val: V)->getPredicate() == Pred0;
112 });
113 if (Same)
114 return std::nullopt;
115 return ResultReason::DiffOpcodes;
116 }
117 case Instruction::Opcode::Select: {
118 auto *Sel0 = cast<SelectInst>(Val: Bndl[0]);
119 auto *Cond0 = Sel0->getCondition();
120 if (VecUtils::getNumLanes(V: Cond0) != VecUtils::getNumLanes(V: Sel0))
121 // TODO: For now we don't vectorize if the lanes in the condition don't
122 // match those of the select instruction.
123 return ResultReason::Unimplemented;
124 return std::nullopt;
125 }
126 case Instruction::Opcode::FNeg:
127 case Instruction::Opcode::Add:
128 case Instruction::Opcode::FAdd:
129 case Instruction::Opcode::Sub:
130 case Instruction::Opcode::FSub:
131 case Instruction::Opcode::Mul:
132 case Instruction::Opcode::FMul:
133 case Instruction::Opcode::FRem:
134 case Instruction::Opcode::UDiv:
135 case Instruction::Opcode::SDiv:
136 case Instruction::Opcode::FDiv:
137 case Instruction::Opcode::URem:
138 case Instruction::Opcode::SRem:
139 case Instruction::Opcode::Shl:
140 case Instruction::Opcode::LShr:
141 case Instruction::Opcode::AShr:
142 case Instruction::Opcode::And:
143 case Instruction::Opcode::Or:
144 case Instruction::Opcode::Xor:
145 return std::nullopt;
146 case Instruction::Opcode::Load:
147 if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL))
148 return std::nullopt;
149 return ResultReason::NotConsecutive;
150 case Instruction::Opcode::Store:
151 if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL))
152 return std::nullopt;
153 return ResultReason::NotConsecutive;
154 case Instruction::Opcode::PHI:
155 return ResultReason::Unimplemented;
156 case Instruction::Opcode::Opaque:
157 return ResultReason::Unimplemented;
158 case Instruction::Opcode::Br:
159 case Instruction::Opcode::Ret:
160 case Instruction::Opcode::AddrSpaceCast:
161 case Instruction::Opcode::InsertElement:
162 case Instruction::Opcode::InsertValue:
163 case Instruction::Opcode::ExtractElement:
164 case Instruction::Opcode::ExtractValue:
165 case Instruction::Opcode::ShuffleVector:
166 case Instruction::Opcode::Call:
167 case Instruction::Opcode::GetElementPtr:
168 case Instruction::Opcode::Switch:
169 return ResultReason::Unimplemented;
170 case Instruction::Opcode::VAArg:
171 case Instruction::Opcode::Freeze:
172 case Instruction::Opcode::Fence:
173 case Instruction::Opcode::Invoke:
174 case Instruction::Opcode::CallBr:
175 case Instruction::Opcode::LandingPad:
176 case Instruction::Opcode::CatchPad:
177 case Instruction::Opcode::CleanupPad:
178 case Instruction::Opcode::CatchRet:
179 case Instruction::Opcode::CleanupRet:
180 case Instruction::Opcode::Resume:
181 case Instruction::Opcode::CatchSwitch:
182 case Instruction::Opcode::AtomicRMW:
183 case Instruction::Opcode::AtomicCmpXchg:
184 case Instruction::Opcode::Alloca:
185 case Instruction::Opcode::Unreachable:
186 return ResultReason::Infeasible;
187 }
188
189 return std::nullopt;
190}
191
192CollectDescr
193LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
194 SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
195 Vec.reserve(N: Bndl.size());
196 for (auto [Elm, V] : enumerate(First&: Bndl)) {
197 if (auto *VecOp = IMaps.getVectorForOrig(Orig: V)) {
198 // If there is a vector containing `V`, then get the lane it came from.
199 std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(Vec: VecOp, Orig: V);
200 // This could be a vector, like <2 x float> in which case the mask needs
201 // to enumerate all lanes.
202 for (unsigned Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln)
203 Vec.emplace_back(Args&: VecOp, Args: ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
204 } else {
205 Vec.emplace_back(Args: V);
206 }
207 }
208 return CollectDescr(std::move(Vec));
209}
210
211const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl,
212 bool SkipScheduling) {
213 // If Bndl contains values other than instructions, we need to Pack.
214 if (any_of(Range&: Bndl, P: [](auto *V) { return !isa<Instruction>(V); }))
215 return createLegalityResult<Pack>(Args: ResultReason::NotInstructions);
216 // Pack if not in the same BB.
217 auto *BB = cast<Instruction>(Val: Bndl[0])->getParent();
218 if (any_of(Range: drop_begin(RangeOrContainer&: Bndl),
219 P: [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; }))
220 return createLegalityResult<Pack>(Args: ResultReason::DiffBBs);
221 // Pack if instructions repeat, i.e., require some sort of broadcast.
222 SmallPtrSet<Value *, 8> Unique(llvm::from_range, Bndl);
223 if (Unique.size() != Bndl.size())
224 return createLegalityResult<Pack>(Args: ResultReason::RepeatedInstrs);
225
226 auto CollectDescrs = getHowToCollectValues(Bndl);
227 if (CollectDescrs.hasVectorInputs()) {
228 if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) {
229 auto [Vec, Mask] = *ValueShuffleOpt;
230 if (Mask.isIdentity())
231 return createLegalityResult<DiamondReuse>(Args&: Vec);
232 return createLegalityResult<DiamondReuseWithShuffle>(Args&: Vec, Args&: Mask);
233 }
234 return createLegalityResult<DiamondReuseMultiInput>(
235 Args: std::move(CollectDescrs));
236 }
237
238 if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl))
239 return createLegalityResult<Pack>(Args&: *ReasonOpt);
240
241 if (!SkipScheduling) {
242 // TODO: Try to remove the IBndl vector.
243 SmallVector<Instruction *, 8> IBndl;
244 IBndl.reserve(N: Bndl.size());
245 for (auto *V : Bndl)
246 IBndl.push_back(Elt: cast<Instruction>(Val: V));
247 if (!Sched.trySchedule(Instrs: IBndl))
248 return createLegalityResult<Pack>(Args: ResultReason::CantSchedule);
249 }
250
251 return createLegalityResult<Widen>();
252}
253
254void LegalityAnalysis::clear() {
255 Sched.clear();
256 IMaps.clear();
257}
258} // namespace llvm::sandboxir
259