1 | //===- Legality.cpp -------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/Legality.h" |
10 | #include "llvm/SandboxIR/Instruction.h" |
11 | #include "llvm/SandboxIR/Operator.h" |
12 | #include "llvm/SandboxIR/Utils.h" |
13 | #include "llvm/SandboxIR/Value.h" |
14 | #include "llvm/Support/Debug.h" |
15 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/InstrMaps.h" |
16 | #include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h" |
17 | |
18 | namespace llvm::sandboxir { |
19 | |
20 | #ifndef NDEBUG |
21 | void ShuffleMask::dump() const { |
22 | print(dbgs()); |
23 | dbgs() << "\n" ; |
24 | } |
25 | |
26 | void LegalityResult::dump() const { |
27 | print(dbgs()); |
28 | dbgs() << "\n" ; |
29 | } |
30 | #endif // NDEBUG |
31 | |
32 | std::optional<ResultReason> |
33 | LegalityAnalysis::notVectorizableBasedOnOpcodesAndTypes( |
34 | ArrayRef<Value *> Bndl) { |
35 | auto *I0 = cast<Instruction>(Val: Bndl[0]); |
36 | auto Opcode = I0->getOpcode(); |
37 | // If they have different opcodes, then we cannot form a vector (for now). |
38 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [Opcode](Value *V) { |
39 | return cast<Instruction>(Val: V)->getOpcode() != Opcode; |
40 | })) |
41 | return ResultReason::DiffOpcodes; |
42 | |
43 | // If not the same scalar type, Pack. This will accept scalars and vectors as |
44 | // long as the element type is the same. |
45 | Type *ElmTy0 = VecUtils::getElementType(Ty: Utils::getExpectedType(V: I0)); |
46 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [ElmTy0](Value *V) { |
47 | return VecUtils::getElementType(Ty: Utils::getExpectedType(V)) != ElmTy0; |
48 | })) |
49 | return ResultReason::DiffTypes; |
50 | |
51 | // TODO: Allow vectorization of instrs with different flags as long as we |
52 | // change them to the least common one. |
53 | // For now pack if differnt FastMathFlags. |
54 | if (isa<FPMathOperator>(Val: I0)) { |
55 | FastMathFlags FMF0 = cast<Instruction>(Val: Bndl[0])->getFastMathFlags(); |
56 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [FMF0](auto *V) { |
57 | return cast<Instruction>(V)->getFastMathFlags() != FMF0; |
58 | })) |
59 | return ResultReason::DiffMathFlags; |
60 | } |
61 | |
62 | // TODO: Allow vectorization by using common flags. |
63 | // For now Pack if they don't have the same wrap flags. |
64 | bool CanHaveWrapFlags = |
65 | isa<OverflowingBinaryOperator>(Val: I0) || isa<TruncInst>(Val: I0); |
66 | if (CanHaveWrapFlags) { |
67 | bool NUW0 = I0->hasNoUnsignedWrap(); |
68 | bool NSW0 = I0->hasNoSignedWrap(); |
69 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [NUW0, NSW0](auto *V) { |
70 | return cast<Instruction>(V)->hasNoUnsignedWrap() != NUW0 || |
71 | cast<Instruction>(V)->hasNoSignedWrap() != NSW0; |
72 | })) { |
73 | return ResultReason::DiffWrapFlags; |
74 | } |
75 | } |
76 | |
77 | // Now we need to do further checks for specific opcodes. |
78 | switch (Opcode) { |
79 | case Instruction::Opcode::ZExt: |
80 | case Instruction::Opcode::SExt: |
81 | case Instruction::Opcode::FPToUI: |
82 | case Instruction::Opcode::FPToSI: |
83 | case Instruction::Opcode::FPExt: |
84 | case Instruction::Opcode::PtrToInt: |
85 | case Instruction::Opcode::IntToPtr: |
86 | case Instruction::Opcode::SIToFP: |
87 | case Instruction::Opcode::UIToFP: |
88 | case Instruction::Opcode::Trunc: |
89 | case Instruction::Opcode::FPTrunc: |
90 | case Instruction::Opcode::BitCast: { |
91 | // We have already checked that they are of the same opcode. |
92 | assert(all_of(Bndl, |
93 | [Opcode](Value *V) { |
94 | return cast<Instruction>(V)->getOpcode() == Opcode; |
95 | }) && |
96 | "Different opcodes, should have early returned!" ); |
97 | // But for these opcodes we should also check the operand type. |
98 | Type *FromTy0 = Utils::getExpectedType(V: I0->getOperand(OpIdx: 0)); |
99 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), P: [FromTy0](Value *V) { |
100 | return Utils::getExpectedType(V: cast<User>(Val: V)->getOperand(OpIdx: 0)) != |
101 | FromTy0; |
102 | })) |
103 | return ResultReason::DiffTypes; |
104 | return std::nullopt; |
105 | } |
106 | case Instruction::Opcode::FCmp: |
107 | case Instruction::Opcode::ICmp: { |
108 | // We need the same predicate.. |
109 | auto Pred0 = cast<CmpInst>(Val: I0)->getPredicate(); |
110 | bool Same = all_of(Range&: Bndl, P: [Pred0](Value *V) { |
111 | return cast<CmpInst>(Val: V)->getPredicate() == Pred0; |
112 | }); |
113 | if (Same) |
114 | return std::nullopt; |
115 | return ResultReason::DiffOpcodes; |
116 | } |
117 | case Instruction::Opcode::Select: { |
118 | auto *Sel0 = cast<SelectInst>(Val: Bndl[0]); |
119 | auto *Cond0 = Sel0->getCondition(); |
120 | if (VecUtils::getNumLanes(V: Cond0) != VecUtils::getNumLanes(V: Sel0)) |
121 | // TODO: For now we don't vectorize if the lanes in the condition don't |
122 | // match those of the select instruction. |
123 | return ResultReason::Unimplemented; |
124 | return std::nullopt; |
125 | } |
126 | case Instruction::Opcode::FNeg: |
127 | case Instruction::Opcode::Add: |
128 | case Instruction::Opcode::FAdd: |
129 | case Instruction::Opcode::Sub: |
130 | case Instruction::Opcode::FSub: |
131 | case Instruction::Opcode::Mul: |
132 | case Instruction::Opcode::FMul: |
133 | case Instruction::Opcode::FRem: |
134 | case Instruction::Opcode::UDiv: |
135 | case Instruction::Opcode::SDiv: |
136 | case Instruction::Opcode::FDiv: |
137 | case Instruction::Opcode::URem: |
138 | case Instruction::Opcode::SRem: |
139 | case Instruction::Opcode::Shl: |
140 | case Instruction::Opcode::LShr: |
141 | case Instruction::Opcode::AShr: |
142 | case Instruction::Opcode::And: |
143 | case Instruction::Opcode::Or: |
144 | case Instruction::Opcode::Xor: |
145 | return std::nullopt; |
146 | case Instruction::Opcode::Load: |
147 | if (VecUtils::areConsecutive<LoadInst>(Bndl, SE, DL)) |
148 | return std::nullopt; |
149 | return ResultReason::NotConsecutive; |
150 | case Instruction::Opcode::Store: |
151 | if (VecUtils::areConsecutive<StoreInst>(Bndl, SE, DL)) |
152 | return std::nullopt; |
153 | return ResultReason::NotConsecutive; |
154 | case Instruction::Opcode::PHI: |
155 | return ResultReason::Unimplemented; |
156 | case Instruction::Opcode::Opaque: |
157 | return ResultReason::Unimplemented; |
158 | case Instruction::Opcode::Br: |
159 | case Instruction::Opcode::Ret: |
160 | case Instruction::Opcode::AddrSpaceCast: |
161 | case Instruction::Opcode::InsertElement: |
162 | case Instruction::Opcode::InsertValue: |
163 | case Instruction::Opcode::ExtractElement: |
164 | case Instruction::Opcode::ExtractValue: |
165 | case Instruction::Opcode::ShuffleVector: |
166 | case Instruction::Opcode::Call: |
167 | case Instruction::Opcode::GetElementPtr: |
168 | case Instruction::Opcode::Switch: |
169 | return ResultReason::Unimplemented; |
170 | case Instruction::Opcode::VAArg: |
171 | case Instruction::Opcode::Freeze: |
172 | case Instruction::Opcode::Fence: |
173 | case Instruction::Opcode::Invoke: |
174 | case Instruction::Opcode::CallBr: |
175 | case Instruction::Opcode::LandingPad: |
176 | case Instruction::Opcode::CatchPad: |
177 | case Instruction::Opcode::CleanupPad: |
178 | case Instruction::Opcode::CatchRet: |
179 | case Instruction::Opcode::CleanupRet: |
180 | case Instruction::Opcode::Resume: |
181 | case Instruction::Opcode::CatchSwitch: |
182 | case Instruction::Opcode::AtomicRMW: |
183 | case Instruction::Opcode::AtomicCmpXchg: |
184 | case Instruction::Opcode::Alloca: |
185 | case Instruction::Opcode::Unreachable: |
186 | return ResultReason::Infeasible; |
187 | } |
188 | |
189 | return std::nullopt; |
190 | } |
191 | |
192 | CollectDescr |
193 | LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const { |
194 | SmallVector<CollectDescr::ExtractElementDescr, 4> Vec; |
195 | Vec.reserve(N: Bndl.size()); |
196 | for (auto [Elm, V] : enumerate(First&: Bndl)) { |
197 | if (auto *VecOp = IMaps.getVectorForOrig(Orig: V)) { |
198 | // If there is a vector containing `V`, then get the lane it came from. |
199 | std::optional<int> = IMaps.getOrigLane(Vec: VecOp, Orig: V); |
200 | // This could be a vector, like <2 x float> in which case the mask needs |
201 | // to enumerate all lanes. |
202 | for (unsigned Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln) |
203 | Vec.emplace_back(Args&: VecOp, Args: ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1); |
204 | } else { |
205 | Vec.emplace_back(Args: V); |
206 | } |
207 | } |
208 | return CollectDescr(std::move(Vec)); |
209 | } |
210 | |
211 | const LegalityResult &LegalityAnalysis::canVectorize(ArrayRef<Value *> Bndl, |
212 | bool SkipScheduling) { |
213 | // If Bndl contains values other than instructions, we need to Pack. |
214 | if (any_of(Range&: Bndl, P: [](auto *V) { return !isa<Instruction>(V); })) |
215 | return createLegalityResult<Pack>(Args: ResultReason::NotInstructions); |
216 | // Pack if not in the same BB. |
217 | auto *BB = cast<Instruction>(Val: Bndl[0])->getParent(); |
218 | if (any_of(Range: drop_begin(RangeOrContainer&: Bndl), |
219 | P: [BB](auto *V) { return cast<Instruction>(V)->getParent() != BB; })) |
220 | return createLegalityResult<Pack>(Args: ResultReason::DiffBBs); |
221 | // Pack if instructions repeat, i.e., require some sort of broadcast. |
222 | SmallPtrSet<Value *, 8> Unique(llvm::from_range, Bndl); |
223 | if (Unique.size() != Bndl.size()) |
224 | return createLegalityResult<Pack>(Args: ResultReason::RepeatedInstrs); |
225 | |
226 | auto CollectDescrs = getHowToCollectValues(Bndl); |
227 | if (CollectDescrs.hasVectorInputs()) { |
228 | if (auto ValueShuffleOpt = CollectDescrs.getSingleInput()) { |
229 | auto [Vec, Mask] = *ValueShuffleOpt; |
230 | if (Mask.isIdentity()) |
231 | return createLegalityResult<DiamondReuse>(Args&: Vec); |
232 | return createLegalityResult<DiamondReuseWithShuffle>(Args&: Vec, Args&: Mask); |
233 | } |
234 | return createLegalityResult<DiamondReuseMultiInput>( |
235 | Args: std::move(CollectDescrs)); |
236 | } |
237 | |
238 | if (auto ReasonOpt = notVectorizableBasedOnOpcodesAndTypes(Bndl)) |
239 | return createLegalityResult<Pack>(Args&: *ReasonOpt); |
240 | |
241 | if (!SkipScheduling) { |
242 | // TODO: Try to remove the IBndl vector. |
243 | SmallVector<Instruction *, 8> IBndl; |
244 | IBndl.reserve(N: Bndl.size()); |
245 | for (auto *V : Bndl) |
246 | IBndl.push_back(Elt: cast<Instruction>(Val: V)); |
247 | if (!Sched.trySchedule(Instrs: IBndl)) |
248 | return createLegalityResult<Pack>(Args: ResultReason::CantSchedule); |
249 | } |
250 | |
251 | return createLegalityResult<Widen>(); |
252 | } |
253 | |
254 | void LegalityAnalysis::clear() { |
255 | Sched.clear(); |
256 | IMaps.clear(); |
257 | } |
258 | } // namespace llvm::sandboxir |
259 | |