//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends to a full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
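// For example, given a v8i16 input %a, VMOVLB.s16 sign-extends the even lanes
// (%a[0], %a[2], ...) and VMOVLT.s16 the odd lanes (%a[1], %a[3], ...), each
// producing a v4i32.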
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// Which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
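// The two inserted shuffles use inverse permutations of each other, so the
// transformed IR computes the same values as the original; both shuffles can
// then be absorbed into the top/bottom forms of VMOVL and VMOVN.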
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

static cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs are usually the same number of instructions,
  //  VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //  A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //  T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVL may be folded into a VMULL.
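  // So the transform is taken if some ext is not a simple extending load (or
  // is an FPExt), or some trunc has a single use that is not a store; failing
  // that, it is only taken when every ext feeds a multiply that can become a
  // VMULL.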

  // But expensive extends/truncs are always good to remove. FPExts always
  // involve extra VCVT's so are always considered to be beneficial to convert.
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }

  // Otherwise, we know we have a load(ext), see if any of the Extends are a
  // vmull. This is a simple heuristic and certainly not perfect.
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for the connected group of operations around Start, with Exts at the
  // leaves and terminating at Truncs/reductions.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

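  // Truncs and Reducts collect the roots where the graph terminates, Exts and
  // OtherLeafs (non-instruction vector operands such as constants and splats)
  // collect the leaves, and Ops collects the lanewise operations in between.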
  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Reducts;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (!Truncs.insert(I))
        continue;
      Visited.insert(I);
      break;

    // Extend leafs
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
        if (!Reducts.insert(I))
          continue;
        Visited.insert(I);
        break;
      }

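      // Any remaining intrinsic must be purely lanewise for the interleaving
      // shuffles to be safe to move through it.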
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      [[fallthrough]]; // Fall through to treating these like an operator below.
    }
    // Binary/ternary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (!Ops.insert(I))
        continue;

      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      [[fallthrough]];

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n  Exts:\n";
    for (auto *I : Exts)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Ops:\n";
    for (auto *I : Ops)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  OtherLeafs:\n";
    for (auto *I : OtherLeafs)
      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "  Truncs:\n";
    for (auto *I : Truncs)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Reducts:\n";
    for (auto *I : Reducts)
      dbgs() << "  " << *I << "\n";
  });

  assert((!Truncs.empty() || !Reducts.empty()) &&
         "Expected some truncs or reductions");
  if (Truncs.empty() && Exts.empty())
    return false;

  auto *VT = !Truncs.empty()
                 ? cast<FixedVectorType>(Truncs[0]->getType())
                 : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
  LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");

  // Check types
  unsigned NumElts = VT->getNumElements();
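  // BaseElts is the number of elements of this width that fit in one 128-bit
  // MVE register; wider vectors are interleaved one 128-bit chunk at a time.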
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;
  if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
        return I->getOpcode() == Instruction::Mul ||
               I->getOpcode() == Instruction::Select ||
               I->getOpcode() == Instruction::ICmp;
      }))) {
    LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
    return false;
  }

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
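  // Within each chunk TruncMask is the inverse of LeafMask
  // (LeafMask[TruncMask[i]] == i), so the shuffles inserted at the leaves and
  // at the truncs cancel out, preserving the computed values.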
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }

  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                 : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                        : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I->get() << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
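    // replaceAllUsesWith below also rewrites Shuf's own use of I, so the
    // operand is reset afterwards to keep the shuffle reading from the trunc.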
    I->replaceAllUsesWith(Shuf);
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}

// Add reductions are fairly common and associative, meaning we can start the
// interleaving from them and don't need to emit a shuffle.
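// Summing across all lanes is invariant under lane permutations, so no
// compensating shuffle is needed at the reduction itself.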
static bool isAddReduction(Instruction &I) {
  if (auto *II = dyn_cast<IntrinsicInst>(&I))
    return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
  return false;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
  for (Instruction &I : reverse(instructions(F))) {
    if (((I.getType()->isVectorTy() &&
          (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
         isAddReduction(I)) &&
        !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}