//===- MVELaneInterleaving.cpp - Interleaving for MVE instructions --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends it to full width, like NEON has with MOVL. Instead it is
// expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements, each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// Which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs usually take the same number of instructions either way,
  //  VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //  A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //  T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVL may be folded into a VMULL.

  // But expensive extends/truncs are always good to remove. FPExts always
  // involve extra VCVTs, so they are always considered beneficial to convert.
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }

  // Otherwise, we know we have a load(ext). Check that each extend has a
  // single use which is a mul (and so can become a vmull). This is a simple
  // heuristic and certainly not perfect.
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Exts, terminating at Truncs.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

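  // The sets built up while walking the blob:
  //  Truncs     - trunc/fptrunc instructions narrowing to the base type.
  //  Reducts    - vector_reduce_add calls terminating a chain.
  //  Exts       - sext/zext/fpext leaves widening from the base type.
  //  OtherLeafs - other vector operands (constants, splats, ...) that will
  //               need their own interleaving shuffle.
  //  Ops        - the operations connecting the leaves to the truncs and
  //               reductions.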
  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Reducts;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (!Truncs.insert(I))
        continue;
      Visited.insert(I);
      break;

    // Extend leafs
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

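      // A vector add reduction terminates a chain in the same way a trunc
      // does, but as it sums all lanes it does not care about lane order and
      // will not need a re-interleaving shuffle of its own.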
      if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
        if (!Reducts.insert(I))
          continue;
        Visited.insert(I);
        break;
      }

      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      [[fallthrough]]; // Fall through to treating these like an operator below.
    }
    // Binary/ternary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (!Ops.insert(I))
        continue;

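      // Vector operands that are instructions join the worklist; any other
      // vector operand (e.g. a constant or splat) becomes an OtherLeaf that
      // will later be given its own interleaving shuffle.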
      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      [[fallthrough]];

    default:
      LLVM_DEBUG(dbgs() << " Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n Exts:\n";
    for (auto *I : Exts)
      dbgs() << " " << *I << "\n";
    dbgs() << " Ops:\n";
    for (auto *I : Ops)
      dbgs() << " " << *I << "\n";
    dbgs() << " OtherLeafs:\n";
    for (auto *I : OtherLeafs)
      dbgs() << " " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << " Truncs:\n";
    for (auto *I : Truncs)
      dbgs() << " " << *I << "\n";
    dbgs() << " Reducts:\n";
    for (auto *I : Reducts)
      dbgs() << " " << *I << "\n";
  });

  assert((!Truncs.empty() || !Reducts.empty()) &&
         "Expected some truncs or reductions");
  if (Truncs.empty() && Exts.empty())
    return false;

  auto *VT = !Truncs.empty()
                 ? cast<FixedVectorType>(Truncs[0]->getType())
                 : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
  LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");

  // Check types
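  // BaseElts is the number of narrow elements in one 128-bit MVE vector: 8 for
  // a 16-bit base type, 16 for an 8-bit base type. Wider IR vectors are
  // handled one 128-bit chunk at a time.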
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << " Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << " Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;
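  // A reduction whose chain only contains muls, selects and compares can
  // usually be selected directly to a (predicated) VMLAV/VMLALV-style
  // multiply-accumulate, so interleaving it would only add shuffles.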
  if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
        return I->getOpcode() == Instruction::Mul ||
               I->getOpcode() == Instruction::Select ||
               I->getOpcode() == Instruction::ICmp;
      }))) {
    LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
    return false;
  }

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
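  // Within each 128-bit chunk, LeafMask gathers the even lanes followed by the
  // odd lanes (so an extend of the shuffle splits into VMOVLB/VMOVLT), while
  // TruncMask interleaves the two halves back into the original lane order
  // (matching a VMOVNB/VMOVNT pair).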
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }

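  // Replace each extend with a shuffle of its input followed by an extend of
  // the same kind, which the backend can then select as a VMOVLB/VMOVLT pair.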
  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                 : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                        : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
  }

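  // Non-extend leaves (constants, splats and other vector values) get the same
  // de-interleaving shuffle so that every operand of the group agrees on the
  // new lane order.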
  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n");
  }

  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

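    // Insert the re-interleaving shuffle directly after the trunc.
    // replaceAllUsesWith also rewrites the operand of the shuffle we have just
    // created, so point it back at the trunc afterwards.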
    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << " with " << *Shuf << "\n");
  }

  return true;
}

// Add reductions are fairly common and associative, meaning we can start the
// interleaving from them and don't need to emit a shuffle.
static bool isAddReduction(Instruction &I) {
  if (auto *II = dyn_cast<IntrinsicInst>(&I))
    return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
  return false;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

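  // Walk the function looking for vector truncs and add reductions to act as
  // the roots of interleaving groups. Visited stops a group being re-examined
  // from more than one of its truncs or reductions.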
  SmallPtrSet<Instruction *, 16> Visited;
  for (Instruction &I : reverse(instructions(F))) {
    if (((I.getType()->isVectorTy() &&
          (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
         isAddReduction(I)) &&
        !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}