//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass interleaves around sext/zext/trunc instructions. MVE does not have
// a single sext/zext or trunc instruction that takes the bottom half of a
// vector and extends it to full width, the way NEON has with MOVL. Instead it
// is expected that this happens through top/bottom instructions. So the MVE
// equivalent VMOVLT/B instructions take either the even or odd elements of the
// input and extend them to the larger type, producing a vector with half the
// number of elements, each of double the bitwidth. As there is no simple
// instruction, we often have to turn sext/zext/trunc into a series of lane
// moves (or stack loads/stores, which we do not do yet).
//
// This pass takes vector code that starts at truncs, looks for interconnected
// blobs of operations that end with sext/zext (or constants/splats) of the
// form:
//   %sa = sext v8i16 %a to v8i32
//   %sb = sext v8i16 %b to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
// and adds shuffles to allow the use of VMOVL/VMOVN instructions:
//   %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sa = sext v8i16 %sha to v8i32
//   %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7>
//   %sb = sext v8i16 %shb to v8i32
//   %add = add v8i32 %sa, %sb
//   %r = trunc %add to v8i16
//   %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7>
// which can then be split and lowered to MVE instructions efficiently:
//   %sa_b = VMOVLB.s16 %a
//   %sa_t = VMOVLT.s16 %a
//   %sb_b = VMOVLB.s16 %b
//   %sb_t = VMOVLT.s16 %b
//   %add_b = VADD.i32 %sa_b, %sb_b
//   %add_t = VADD.i32 %sa_t, %sb_t
//   %r = VMOVNT.i16 %add_b, %add_t
//
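// Vector add reductions (vector_reduce_add) can also act as the sink of a
// group in place of a trunc: a reduction sums all lanes, so it is insensitive
// to the interleaved lane order and needs no final shuffle.
//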
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-laneinterleave"

static cl::opt<bool> EnableInterleave(
    "enable-mve-interleave", cl::Hidden, cl::init(true),
    cl::desc("Enable interleave MVE vector operation lowering"));

namespace {

class MVELaneInterleaving : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVELaneInterleaving() : FunctionPass(ID) {
    initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override { return "MVE lane interleaving"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    FunctionPass::getAnalysisUsage(AU);
  }
};

} // end anonymous namespace

char MVELaneInterleaving::ID = 0;

INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving", false,
                false)

Pass *llvm::createMVELaneInterleavingPass() {
  return new MVELaneInterleaving();
}

static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts,
                                     SmallSetVector<Instruction *, 4> &Truncs) {
  // This is not always beneficial to transform. Exts can be incorporated into
  // loads, Truncs can be folded into stores.
  // Truncs are usually the same number of instructions,
  //   VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving
  // Exts are unfortunately more instructions in the general case:
  //   A=VLDRH.32; B=VLDRH.32;
  // vs with interleaving:
  //   T=VLDRH.16; A=VMOVLB T; B=VMOVLT T
  // But those VMOVL may be folded into a VMULL.

  // But expensive extends/truncs are always good to remove. FPExts always
  // involve extra VCVTs so are always considered beneficial to convert.
  for (auto *E : Exts) {
    if (isa<FPExtInst>(E) || !isa<LoadInst>(E->getOperand(0))) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n");
      return true;
    }
  }
  for (auto *T : Truncs) {
    if (T->hasOneUse() && !isa<StoreInst>(*T->user_begin())) {
      LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n");
      return true;
    }
  }

  // Otherwise, we know we have a load(ext), see if any of the Extends are a
  // vmull. This is a simple heuristic and certainly not perfect.
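  // For example, a group such as:
  //   %m = mul (sext v8i16 %a to v8i32), (sext v8i16 %b to v8i32)
  // can lower the extending multiplies directly to a VMULLB.s16/VMULLT.s16
  // pair, so the extends cost no extra instructions.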
  for (auto *E : Exts) {
    if (!E->hasOneUse() ||
        cast<Instruction>(*E->user_begin())->getOpcode() != Instruction::Mul) {
      LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n");
      return false;
    }
  }
  return true;
}

static bool tryInterleave(Instruction *Start,
                          SmallPtrSetImpl<Instruction *> &Visited) {
  LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n");

  if (!isa<Instruction>(Start->getOperand(0)))
    return false;

  // Look for connected operations starting from Exts, terminating at Truncs,
  // walking both the operands and the users of each instruction along the way.
  std::vector<Instruction *> Worklist;
  Worklist.push_back(Start);
  Worklist.push_back(cast<Instruction>(Start->getOperand(0)));

  SmallSetVector<Instruction *, 4> Truncs;
  SmallSetVector<Instruction *, 4> Reducts;
  SmallSetVector<Instruction *, 4> Exts;
  SmallSetVector<Use *, 4> OtherLeafs;
  SmallSetVector<Instruction *, 4> Ops;

  while (!Worklist.empty()) {
    Instruction *I = Worklist.back();
    Worklist.pop_back();

    switch (I->getOpcode()) {
    // Truncs
    case Instruction::Trunc:
    case Instruction::FPTrunc:
      if (!Truncs.insert(I))
        continue;
      Visited.insert(I);
      break;

    // Extend leafs
    case Instruction::SExt:
    case Instruction::ZExt:
    case Instruction::FPExt:
      if (Exts.count(I))
        continue;
      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      Exts.insert(I);
      break;

    case Instruction::Call: {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
      if (!II)
        return false;

      if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) {
        if (!Reducts.insert(I))
          continue;
        Visited.insert(I);
        break;
      }

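      // The remaining intrinsics all operate lane-wise on their vector
      // operands, so reordering lanes around them is safe.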
      switch (II->getIntrinsicID()) {
      case Intrinsic::abs:
      case Intrinsic::smin:
      case Intrinsic::smax:
      case Intrinsic::umin:
      case Intrinsic::umax:
      case Intrinsic::sadd_sat:
      case Intrinsic::ssub_sat:
      case Intrinsic::uadd_sat:
      case Intrinsic::usub_sat:
      case Intrinsic::minnum:
      case Intrinsic::maxnum:
      case Intrinsic::fabs:
      case Intrinsic::fma:
      case Intrinsic::ceil:
      case Intrinsic::floor:
      case Intrinsic::rint:
      case Intrinsic::round:
      case Intrinsic::trunc:
        break;
      default:
        return false;
      }
      [[fallthrough]]; // Fall through to treating these like an operator below.
    }
    // Binary/ternary ops
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::AShr:
    case Instruction::LShr:
    case Instruction::Shl:
    case Instruction::ICmp:
    case Instruction::FCmp:
    case Instruction::FAdd:
    case Instruction::FMul:
    case Instruction::Select:
      if (!Ops.insert(I))
        continue;

      for (Use &Op : I->operands()) {
        if (!isa<FixedVectorType>(Op->getType()))
          continue;
        if (isa<Instruction>(Op))
          Worklist.push_back(cast<Instruction>(&Op));
        else
          OtherLeafs.insert(&Op);
      }

      for (auto *Use : I->users())
        Worklist.push_back(cast<Instruction>(Use));
      break;

    case Instruction::ShuffleVector:
      // A shuffle of a splat is a splat.
      if (cast<ShuffleVectorInst>(I)->isZeroEltSplat())
        continue;
      [[fallthrough]];

    default:
      LLVM_DEBUG(dbgs() << "  Unhandled instruction: " << *I << "\n");
      return false;
    }
  }

  if (Exts.empty() && OtherLeafs.empty())
    return false;

  LLVM_DEBUG({
    dbgs() << "Found group:\n  Exts:\n";
    for (auto *I : Exts)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Ops:\n";
    for (auto *I : Ops)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  OtherLeafs:\n";
    for (auto *I : OtherLeafs)
      dbgs() << "  " << *I->get() << " of " << *I->getUser() << "\n";
    dbgs() << "  Truncs:\n";
    for (auto *I : Truncs)
      dbgs() << "  " << *I << "\n";
    dbgs() << "  Reducts:\n";
    for (auto *I : Reducts)
      dbgs() << "  " << *I << "\n";
  });

  assert((!Truncs.empty() || !Reducts.empty()) &&
         "Expected some truncs or reductions");
  if (Truncs.empty() && Exts.empty())
    return false;

  auto *VT = !Truncs.empty()
                 ? cast<FixedVectorType>(Truncs[0]->getType())
                 : cast<FixedVectorType>(Exts[0]->getOperand(0)->getType());
  LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n");

  // Check types
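  // MVE Q registers are 128 bits, so the narrow side of the group must divide
  // into whole v8i16 or v16i8 blocks.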
  unsigned NumElts = VT->getNumElements();
  unsigned BaseElts = VT->getScalarSizeInBits() == 16
                          ? 8
                          : (VT->getScalarSizeInBits() == 8 ? 16 : 0);
  if (BaseElts == 0 || NumElts % BaseElts != 0) {
    LLVM_DEBUG(dbgs() << "  Type is unsupported\n");
    return false;
  }
  if (Start->getOperand(0)->getType()->getScalarSizeInBits() !=
      VT->getScalarSizeInBits() * 2) {
    LLVM_DEBUG(dbgs() << "  Type not double sized\n");
    return false;
  }
  for (Instruction *I : Exts)
    if (I->getOperand(0)->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }
  for (Instruction *I : Truncs)
    if (I->getType() != VT) {
      LLVM_DEBUG(dbgs() << "  Wrong type on " << *I << "\n");
      return false;
    }

  // Check that it looks beneficial
  if (!isProfitableToInterleave(Exts, Truncs))
    return false;
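  // A reduction of a plain extend, or of nothing but mul/select/icmp ops, can
  // already be matched to MVE's reducing instructions (VADDLV, VMLAV and
  // friends), so interleaving would only add shuffles.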
  if (!Reducts.empty() && (Ops.empty() || all_of(Ops, [](Instruction *I) {
        return I->getOpcode() == Instruction::Mul ||
               I->getOpcode() == Instruction::Select ||
               I->getOpcode() == Instruction::ICmp;
      }))) {
    LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n");
    return false;
  }

  // Create new shuffles around the extends / truncs / other leaves.
  IRBuilder<> Builder(Start);

  SmallVector<int, 16> LeafMask;
  SmallVector<int, 16> TruncMask;
  // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7   8, 10, 12, 14, 9, 11, 13, 15
  // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7   8, 12, 9, 13, 10, 14, 11, 15
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2);
    for (unsigned i = 0; i < BaseElts / 2; i++)
      LeafMask.push_back(Base + i * 2 + 1);
  }
  for (unsigned Base = 0; Base < NumElts; Base += BaseElts) {
    for (unsigned i = 0; i < BaseElts / 2; i++) {
      TruncMask.push_back(Base + i);
      TruncMask.push_back(Base + i + BaseElts / 2);
    }
  }
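  // Note that TruncMask is the inverse permutation of LeafMask: shuffling the
  // leaves by LeafMask and the trunc results by TruncMask leaves the overall
  // lane order of the group unchanged.

  // Shuffle the input of each extend into even-then-odd order. The bottom
  // half of the extended value is then the even lanes and the top half the
  // odd lanes, which is exactly what VMOVLB/VMOVLT produce.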
  for (Instruction *I : Exts) {
    LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n");
    Builder.SetInsertPoint(I);
    Value *Shuffle = Builder.CreateShuffleVector(I->getOperand(0), LeafMask);
    bool FPext = isa<FPExtInst>(I);
    bool Sext = isa<SExtInst>(I);
    Value *Ext = FPext ? Builder.CreateFPExt(Shuffle, I->getType())
                 : Sext ? Builder.CreateSExt(Shuffle, I->getType())
                        : Builder.CreateZExt(Shuffle, I->getType());
    I->replaceAllUsesWith(Ext);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

  for (Use *I : OtherLeafs) {
    LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n");
    Builder.SetInsertPoint(cast<Instruction>(I->getUser()));
    Value *Shuffle = Builder.CreateShuffleVector(I->get(), LeafMask);
    I->getUser()->setOperand(I->getOperandNo(), Shuffle);
    LLVM_DEBUG(dbgs() << "  with " << *Shuffle << "\n");
  }

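  // Shuffle each trunc result back with the inverse (TruncMask) permutation
  // to restore the original lane order; the trunc itself can then be lowered
  // as a top/bottom VMOVN narrowing.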
  for (Instruction *I : Truncs) {
    LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n");

    Builder.SetInsertPoint(I->getParent(), ++I->getIterator());
    Value *Shuf = Builder.CreateShuffleVector(I, TruncMask);
    I->replaceAllUsesWith(Shuf);
    cast<Instruction>(Shuf)->setOperand(0, I);

    LLVM_DEBUG(dbgs() << "  with " << *Shuf << "\n");
  }

  return true;
}

// Add reductions are fairly common and associative, meaning we can start the
// interleaving from them and don't need to emit a shuffle.
static bool isAddReduction(Instruction &I) {
  if (auto *II = dyn_cast<IntrinsicInst>(&I))
    return II->getIntrinsicID() == Intrinsic::vector_reduce_add;
  return false;
}

bool MVELaneInterleaving::runOnFunction(Function &F) {
  if (!EnableInterleave)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;

  bool Changed = false;

  SmallPtrSet<Instruction *, 16> Visited;
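  // Candidate sinks are vector truncs and add reductions. tryInterleave
  // records the truncs/reductions it absorbs in Visited, so each connected
  // group is only processed once.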
  for (Instruction &I : reverse(instructions(F))) {
    if (((I.getType()->isVectorTy() &&
          (isa<TruncInst>(I) || isa<FPTruncInst>(I))) ||
         isAddReduction(I)) &&
        !Visited.count(&I))
      Changed |= tryInterleave(&I, Visited);
  }

  return Changed;
}