//===- MVELaneInterleaving.cpp - Interleave for MVE instructions ----------===//
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This pass interleaves around sext/zext/trunc instructions. MVE does not have |
| 10 | // a single sext/zext or trunc instruction that takes the bottom half of a |
| 11 | // vector and extends to a full width, like NEON has with MOVL. Instead it is |
| 12 | // expected that this happens through top/bottom instructions. So the MVE |
| 13 | // equivalent VMOVLT/B instructions take either the even or odd elements of the |
| 14 | // input and extend them to the larger type, producing a vector with half the |
| 15 | // number of elements each of double the bitwidth. As there is no simple |
| 16 | // instruction, we often have to turn sext/zext/trunc into a series of lane |
| 17 | // moves (or stack loads/stores, which we do not do yet). |
| 18 | // |
| 19 | // This pass takes vector code that starts at truncs, looks for interconnected |
| 20 | // blobs of operations that end with sext/zext (or constants/splats) of the |
| 21 | // form: |
| 22 | // %sa = sext v8i16 %a to v8i32 |
| 23 | // %sb = sext v8i16 %b to v8i32 |
| 24 | // %add = add v8i32 %sa, %sb |
| 25 | // %r = trunc %add to v8i16 |
// And adds shuffles to allow the use of VMOVL/VMOVN instructions:
| 27 | // %sha = shuffle v8i16 %a, undef, <0, 2, 4, 6, 1, 3, 5, 7> |
| 28 | // %sa = sext v8i16 %sha to v8i32 |
| 29 | // %shb = shuffle v8i16 %b, undef, <0, 2, 4, 6, 1, 3, 5, 7> |
| 30 | // %sb = sext v8i16 %shb to v8i32 |
| 31 | // %add = add v8i32 %sa, %sb |
| 32 | // %r = trunc %add to v8i16 |
| 33 | // %shr = shuffle v8i16 %r, undef, <0, 4, 1, 5, 2, 6, 3, 7> |
| 34 | // Which can then be split and lowered to MVE instructions efficiently: |
| 35 | // %sa_b = VMOVLB.s16 %a |
| 36 | // %sa_t = VMOVLT.s16 %a |
| 37 | // %sb_b = VMOVLB.s16 %b |
| 38 | // %sb_t = VMOVLT.s16 %b |
| 39 | // %add_b = VADD.i32 %sa_b, %sb_b |
| 40 | // %add_t = VADD.i32 %sa_t, %sb_t |
| 41 | // %r = VMOVNT.i16 %add_b, %add_t |
| 42 | // |
| 43 | //===----------------------------------------------------------------------===// |
| 44 | |
| 45 | #include "ARM.h" |
| 46 | #include "ARMBaseInstrInfo.h" |
| 47 | #include "ARMSubtarget.h" |
| 48 | #include "llvm/ADT/SetVector.h" |
| 49 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 50 | #include "llvm/CodeGen/TargetLowering.h" |
| 51 | #include "llvm/CodeGen/TargetPassConfig.h" |
| 52 | #include "llvm/IR/BasicBlock.h" |
| 53 | #include "llvm/IR/DerivedTypes.h" |
| 54 | #include "llvm/IR/Function.h" |
| 55 | #include "llvm/IR/IRBuilder.h" |
| 56 | #include "llvm/IR/InstIterator.h" |
| 57 | #include "llvm/IR/InstrTypes.h" |
| 58 | #include "llvm/IR/Instruction.h" |
| 59 | #include "llvm/IR/Instructions.h" |
| 60 | #include "llvm/IR/IntrinsicInst.h" |
| 61 | #include "llvm/IR/Intrinsics.h" |
| 62 | #include "llvm/IR/Type.h" |
| 63 | #include "llvm/IR/Value.h" |
| 64 | #include "llvm/InitializePasses.h" |
| 65 | #include "llvm/Pass.h" |
| 66 | #include "llvm/Support/Casting.h" |
| 67 | #include <cassert> |
| 68 | |
| 69 | using namespace llvm; |
| 70 | |
| 71 | #define DEBUG_TYPE "mve-laneinterleave" |
| 72 | |
| 73 | static cl::opt<bool> EnableInterleave( |
| 74 | "enable-mve-interleave" , cl::Hidden, cl::init(Val: true), |
| 75 | cl::desc("Enable interleave MVE vector operation lowering" )); |
| 76 | |
| 77 | namespace { |
| 78 | |
| 79 | class MVELaneInterleaving : public FunctionPass { |
| 80 | public: |
| 81 | static char ID; // Pass identification, replacement for typeid |
| 82 | |
| 83 | explicit MVELaneInterleaving() : FunctionPass(ID) { |
| 84 | initializeMVELaneInterleavingPass(*PassRegistry::getPassRegistry()); |
| 85 | } |
| 86 | |
| 87 | bool runOnFunction(Function &F) override; |
| 88 | |
| 89 | StringRef getPassName() const override { return "MVE lane interleaving" ; } |
| 90 | |
| 91 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 92 | AU.setPreservesCFG(); |
| 93 | AU.addRequired<TargetPassConfig>(); |
| 94 | FunctionPass::getAnalysisUsage(AU); |
| 95 | } |
| 96 | }; |
| 97 | |
| 98 | } // end anonymous namespace |
| 99 | |
| 100 | char MVELaneInterleaving::ID = 0; |
| 101 | |
| 102 | INITIALIZE_PASS(MVELaneInterleaving, DEBUG_TYPE, "MVE lane interleaving" , false, |
| 103 | false) |
| 104 | |
| 105 | Pass *llvm::createMVELaneInterleavingPass() { |
| 106 | return new MVELaneInterleaving(); |
| 107 | } |
| 108 | |
| 109 | static bool isProfitableToInterleave(SmallSetVector<Instruction *, 4> &Exts, |
| 110 | SmallSetVector<Instruction *, 4> &Truncs) { |
| 111 | // This is not always beneficial to transform. Exts can be incorporated into |
| 112 | // loads, Truncs can be folded into stores. |
| 113 | // Truncs are usually the same number of instructions, |
| 114 | // VSTRH.32(A);VSTRH.32(B) vs VSTRH.16(VMOVNT A, B) with interleaving |
| 115 | // Exts are unfortunately more instructions in the general case: |
| 116 | // A=VLDRH.32; B=VLDRH.32; |
| 117 | // vs with interleaving: |
| 118 | // T=VLDRH.16; A=VMOVNB T; B=VMOVNT T |
| 119 | // But those VMOVL may be folded into a VMULL. |
| 120 | |
| 121 | // But expensive extends/truncs are always good to remove. FPExts always |
| 122 | // involve extra VCVT's so are always considered to be beneficial to convert. |
| 123 | for (auto *E : Exts) { |
| 124 | if (isa<FPExtInst>(Val: E) || !isa<LoadInst>(Val: E->getOperand(i: 0))) { |
| 125 | LLVM_DEBUG(dbgs() << "Beneficial due to " << *E << "\n" ); |
| 126 | return true; |
| 127 | } |
| 128 | } |
| 129 | for (auto *T : Truncs) { |
| 130 | if (T->hasOneUse() && !isa<StoreInst>(Val: *T->user_begin())) { |
| 131 | LLVM_DEBUG(dbgs() << "Beneficial due to " << *T << "\n" ); |
| 132 | return true; |
| 133 | } |
| 134 | } |
| 135 | |
| 136 | // Otherwise, we know we have a load(ext), see if any of the Extends are a |
| 137 | // vmull. This is a simple heuristic and certainly not perfect. |
| 138 | for (auto *E : Exts) { |
| 139 | if (!E->hasOneUse() || |
| 140 | cast<Instruction>(Val: *E->user_begin())->getOpcode() != Instruction::Mul) { |
| 141 | LLVM_DEBUG(dbgs() << "Not beneficial due to " << *E << "\n" ); |
| 142 | return false; |
| 143 | } |
| 144 | } |
| 145 | return true; |
| 146 | } |
| 147 | |
| 148 | static bool tryInterleave(Instruction *Start, |
| 149 | SmallPtrSetImpl<Instruction *> &Visited) { |
| 150 | LLVM_DEBUG(dbgs() << "tryInterleave from " << *Start << "\n" ); |
| 151 | |
| 152 | if (!isa<Instruction>(Val: Start->getOperand(i: 0))) |
| 153 | return false; |
| 154 | |
| 155 | // Look for connected operations starting from Ext's, terminating at Truncs. |
| 156 | std::vector<Instruction *> Worklist; |
| 157 | Worklist.push_back(x: Start); |
| 158 | Worklist.push_back(x: cast<Instruction>(Val: Start->getOperand(i: 0))); |
| 159 | |
| 160 | SmallSetVector<Instruction *, 4> Truncs; |
| 161 | SmallSetVector<Instruction *, 4> Reducts; |
| 162 | SmallSetVector<Instruction *, 4> Exts; |
| 163 | SmallSetVector<Use *, 4> OtherLeafs; |
| 164 | SmallSetVector<Instruction *, 4> Ops; |
| 165 | |
| 166 | while (!Worklist.empty()) { |
| 167 | Instruction *I = Worklist.back(); |
| 168 | Worklist.pop_back(); |
| 169 | |
| 170 | switch (I->getOpcode()) { |
| 171 | // Truncs |
| 172 | case Instruction::Trunc: |
| 173 | case Instruction::FPTrunc: |
| 174 | if (!Truncs.insert(X: I)) |
| 175 | continue; |
| 176 | Visited.insert(Ptr: I); |
| 177 | break; |
| 178 | |
| 179 | // Extend leafs |
| 180 | case Instruction::SExt: |
| 181 | case Instruction::ZExt: |
| 182 | case Instruction::FPExt: |
| 183 | if (Exts.count(key: I)) |
| 184 | continue; |
| 185 | for (auto *Use : I->users()) |
| 186 | Worklist.push_back(x: cast<Instruction>(Val: Use)); |
| 187 | Exts.insert(X: I); |
| 188 | break; |
| 189 | |
| 190 | case Instruction::Call: { |
| 191 | IntrinsicInst *II = dyn_cast<IntrinsicInst>(Val: I); |
| 192 | if (!II) |
| 193 | return false; |
| 194 | |
| 195 | if (II->getIntrinsicID() == Intrinsic::vector_reduce_add) { |
| 196 | if (!Reducts.insert(X: I)) |
| 197 | continue; |
| 198 | Visited.insert(Ptr: I); |
| 199 | break; |
| 200 | } |
| 201 | |
| 202 | switch (II->getIntrinsicID()) { |
| 203 | case Intrinsic::abs: |
| 204 | case Intrinsic::smin: |
| 205 | case Intrinsic::smax: |
| 206 | case Intrinsic::umin: |
| 207 | case Intrinsic::umax: |
| 208 | case Intrinsic::sadd_sat: |
| 209 | case Intrinsic::ssub_sat: |
| 210 | case Intrinsic::uadd_sat: |
| 211 | case Intrinsic::usub_sat: |
| 212 | case Intrinsic::minnum: |
| 213 | case Intrinsic::maxnum: |
| 214 | case Intrinsic::fabs: |
| 215 | case Intrinsic::fma: |
| 216 | case Intrinsic::ceil: |
| 217 | case Intrinsic::floor: |
| 218 | case Intrinsic::rint: |
| 219 | case Intrinsic::round: |
| 220 | case Intrinsic::trunc: |
| 221 | break; |
| 222 | default: |
| 223 | return false; |
| 224 | } |
| 225 | [[fallthrough]]; // Fall through to treating these like an operator below. |
| 226 | } |
| 227 | // Binary/tertiary ops |
| 228 | case Instruction::Add: |
| 229 | case Instruction::Sub: |
| 230 | case Instruction::Mul: |
| 231 | case Instruction::AShr: |
| 232 | case Instruction::LShr: |
| 233 | case Instruction::Shl: |
| 234 | case Instruction::ICmp: |
| 235 | case Instruction::FCmp: |
| 236 | case Instruction::FAdd: |
| 237 | case Instruction::FMul: |
| 238 | case Instruction::Select: |
| 239 | if (!Ops.insert(X: I)) |
| 240 | continue; |
| 241 | |
| 242 | for (Use &Op : I->operands()) { |
| 243 | if (!isa<FixedVectorType>(Val: Op->getType())) |
| 244 | continue; |
| 245 | if (isa<Instruction>(Val: Op)) |
| 246 | Worklist.push_back(x: cast<Instruction>(Val: &Op)); |
| 247 | else |
| 248 | OtherLeafs.insert(X: &Op); |
| 249 | } |
| 250 | |
| 251 | for (auto *Use : I->users()) |
| 252 | Worklist.push_back(x: cast<Instruction>(Val: Use)); |
| 253 | break; |
| 254 | |
| 255 | case Instruction::ShuffleVector: |
| 256 | // A shuffle of a splat is a splat. |
| 257 | if (cast<ShuffleVectorInst>(Val: I)->isZeroEltSplat()) |
| 258 | continue; |
| 259 | [[fallthrough]]; |
| 260 | |
| 261 | default: |
| 262 | LLVM_DEBUG(dbgs() << " Unhandled instruction: " << *I << "\n" ); |
| 263 | return false; |
| 264 | } |
| 265 | } |
| 266 | |
| 267 | if (Exts.empty() && OtherLeafs.empty()) |
| 268 | return false; |
| 269 | |
| 270 | LLVM_DEBUG({ |
| 271 | dbgs() << "Found group:\n Exts:\n" ; |
| 272 | for (auto *I : Exts) |
| 273 | dbgs() << " " << *I << "\n" ; |
| 274 | dbgs() << " Ops:\n" ; |
| 275 | for (auto *I : Ops) |
| 276 | dbgs() << " " << *I << "\n" ; |
| 277 | dbgs() << " OtherLeafs:\n" ; |
| 278 | for (auto *I : OtherLeafs) |
| 279 | dbgs() << " " << *I->get() << " of " << *I->getUser() << "\n" ; |
| 280 | dbgs() << " Truncs:\n" ; |
| 281 | for (auto *I : Truncs) |
| 282 | dbgs() << " " << *I << "\n" ; |
| 283 | dbgs() << " Reducts:\n" ; |
| 284 | for (auto *I : Reducts) |
| 285 | dbgs() << " " << *I << "\n" ; |
| 286 | }); |
| 287 | |
| 288 | assert((!Truncs.empty() || !Reducts.empty()) && |
| 289 | "Expected some truncs or reductions" ); |
| 290 | if (Truncs.empty() && Exts.empty()) |
| 291 | return false; |
| 292 | |
| 293 | auto *VT = !Truncs.empty() |
| 294 | ? cast<FixedVectorType>(Val: Truncs[0]->getType()) |
| 295 | : cast<FixedVectorType>(Val: Exts[0]->getOperand(i: 0)->getType()); |
| 296 | LLVM_DEBUG(dbgs() << "Using VT:" << *VT << "\n" ); |
| 297 | |
| 298 | // Check types |
| 299 | unsigned NumElts = VT->getNumElements(); |
| 300 | unsigned BaseElts = VT->getScalarSizeInBits() == 16 |
| 301 | ? 8 |
| 302 | : (VT->getScalarSizeInBits() == 8 ? 16 : 0); |
| 303 | if (BaseElts == 0 || NumElts % BaseElts != 0) { |
| 304 | LLVM_DEBUG(dbgs() << " Type is unsupported\n" ); |
| 305 | return false; |
| 306 | } |
| 307 | if (Start->getOperand(i: 0)->getType()->getScalarSizeInBits() != |
| 308 | VT->getScalarSizeInBits() * 2) { |
| 309 | LLVM_DEBUG(dbgs() << " Type not double sized\n" ); |
| 310 | return false; |
| 311 | } |
| 312 | for (Instruction *I : Exts) |
| 313 | if (I->getOperand(i: 0)->getType() != VT) { |
| 314 | LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n" ); |
| 315 | return false; |
| 316 | } |
| 317 | for (Instruction *I : Truncs) |
| 318 | if (I->getType() != VT) { |
| 319 | LLVM_DEBUG(dbgs() << " Wrong type on " << *I << "\n" ); |
| 320 | return false; |
| 321 | } |
| 322 | |
| 323 | // Check that it looks beneficial |
| 324 | if (!isProfitableToInterleave(Exts, Truncs)) |
| 325 | return false; |
| 326 | if (!Reducts.empty() && (Ops.empty() || all_of(Range&: Ops, P: [](Instruction *I) { |
| 327 | return I->getOpcode() == Instruction::Mul || |
| 328 | I->getOpcode() == Instruction::Select || |
| 329 | I->getOpcode() == Instruction::ICmp; |
| 330 | }))) { |
| 331 | LLVM_DEBUG(dbgs() << "Reduction does not look profitable\n" ); |
| 332 | return false; |
| 333 | } |
| 334 | |
| 335 | // Create new shuffles around the extends / truncs / other leaves. |
| 336 | IRBuilder<> Builder(Start); |
| 337 | |
| 338 | SmallVector<int, 16> LeafMask; |
| 339 | SmallVector<int, 16> TruncMask; |
| 340 | // LeafMask : 0, 2, 4, 6, 1, 3, 5, 7 8, 10, 12, 14, 9, 11, 13, 15 |
| 341 | // TruncMask: 0, 4, 1, 5, 2, 6, 3, 7 8, 12, 9, 13, 10, 14, 11, 15 |
| 342 | for (unsigned Base = 0; Base < NumElts; Base += BaseElts) { |
| 343 | for (unsigned i = 0; i < BaseElts / 2; i++) |
| 344 | LeafMask.push_back(Elt: Base + i * 2); |
| 345 | for (unsigned i = 0; i < BaseElts / 2; i++) |
| 346 | LeafMask.push_back(Elt: Base + i * 2 + 1); |
| 347 | } |
| 348 | for (unsigned Base = 0; Base < NumElts; Base += BaseElts) { |
| 349 | for (unsigned i = 0; i < BaseElts / 2; i++) { |
| 350 | TruncMask.push_back(Elt: Base + i); |
| 351 | TruncMask.push_back(Elt: Base + i + BaseElts / 2); |
| 352 | } |
| 353 | } |
| 354 | |
| 355 | for (Instruction *I : Exts) { |
| 356 | LLVM_DEBUG(dbgs() << "Replacing ext " << *I << "\n" ); |
| 357 | Builder.SetInsertPoint(I); |
| 358 | Value *Shuffle = Builder.CreateShuffleVector(V: I->getOperand(i: 0), Mask: LeafMask); |
| 359 | bool FPext = isa<FPExtInst>(Val: I); |
| 360 | bool Sext = isa<SExtInst>(Val: I); |
| 361 | Value *Ext = FPext ? Builder.CreateFPExt(V: Shuffle, DestTy: I->getType()) |
| 362 | : Sext ? Builder.CreateSExt(V: Shuffle, DestTy: I->getType()) |
| 363 | : Builder.CreateZExt(V: Shuffle, DestTy: I->getType()); |
| 364 | I->replaceAllUsesWith(V: Ext); |
| 365 | LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n" ); |
| 366 | } |
| 367 | |
| 368 | for (Use *I : OtherLeafs) { |
| 369 | LLVM_DEBUG(dbgs() << "Replacing leaf " << *I << "\n" ); |
| 370 | Builder.SetInsertPoint(cast<Instruction>(Val: I->getUser())); |
| 371 | Value *Shuffle = Builder.CreateShuffleVector(V: I->get(), Mask: LeafMask); |
| 372 | I->getUser()->setOperand(i: I->getOperandNo(), Val: Shuffle); |
| 373 | LLVM_DEBUG(dbgs() << " with " << *Shuffle << "\n" ); |
| 374 | } |
| 375 | |
| 376 | for (Instruction *I : Truncs) { |
| 377 | LLVM_DEBUG(dbgs() << "Replacing trunc " << *I << "\n" ); |
| 378 | |
| 379 | Builder.SetInsertPoint(TheBB: I->getParent(), IP: ++I->getIterator()); |
| 380 | Value *Shuf = Builder.CreateShuffleVector(V: I, Mask: TruncMask); |
| 381 | I->replaceAllUsesWith(V: Shuf); |
| 382 | cast<Instruction>(Val: Shuf)->setOperand(i: 0, Val: I); |
| 383 | |
| 384 | LLVM_DEBUG(dbgs() << " with " << *Shuf << "\n" ); |
| 385 | } |
| 386 | |
| 387 | return true; |
| 388 | } |
| 389 | |
| 390 | // Add reductions are fairly common and associative, meaning we can start the |
| 391 | // interleaving from them and don't need to emit a shuffle. |
| 392 | static bool isAddReduction(Instruction &I) { |
| 393 | if (auto *II = dyn_cast<IntrinsicInst>(Val: &I)) |
| 394 | return II->getIntrinsicID() == Intrinsic::vector_reduce_add; |
| 395 | return false; |
| 396 | } |
| 397 | |
| 398 | bool MVELaneInterleaving::runOnFunction(Function &F) { |
| 399 | if (!EnableInterleave) |
| 400 | return false; |
| 401 | auto &TPC = getAnalysis<TargetPassConfig>(); |
| 402 | auto &TM = TPC.getTM<TargetMachine>(); |
| 403 | auto *ST = &TM.getSubtarget<ARMSubtarget>(F); |
| 404 | if (!ST->hasMVEIntegerOps()) |
| 405 | return false; |
| 406 | |
| 407 | bool Changed = false; |
| 408 | |
| 409 | SmallPtrSet<Instruction *, 16> Visited; |
| 410 | for (Instruction &I : reverse(C: instructions(F))) { |
| 411 | if (((I.getType()->isVectorTy() && |
| 412 | (isa<TruncInst>(Val: I) || isa<FPTruncInst>(Val: I))) || |
| 413 | isAddReduction(I)) && |
| 414 | !Visited.count(Ptr: &I)) |
| 415 | Changed |= tryInterleave(Start: &I, Visited); |
| 416 | } |
| 417 | |
| 418 | return Changed; |
| 419 | } |
| 420 | |