| 1 | //===------- SVEShuffleOpts - SVE Shuffle Optimization --------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // Tries to pattern match and combine scalable vector shuffles that could |
| 10 | // be more efficiently performed by tbl instructions. |
| 11 | // |
| 12 | // An example would be a loop with 4 multiply-accumulate reductions, where the |
| 13 | // new data in each vector iteration comes from a 4-way deinterleaving of |
| 14 | // smaller datatypes loaded from memory which are then zero extended. |
| 15 | // |
| 16 | // Something like the following: |
| 17 | // %bgra = call ... @llvm.masked.load |
| 18 | // %deinterleave = call ... @llvm.vector.deinterleave4(%bgra) |
| 19 | // If the load was of a <vscale x 8 x i16>, we now have 4 deinterleaved |
| 20 | // <vscale x 2 x i16> values. |
| 21 | // %b.i16 = extractvalue %deinterleave, 0 |
| 22 | // %b.i64 = zext <vscale x 2 x i16> %b.i16 to <vscale x 2 x i64> |
| 23 | // %acc.b.next = add <vscale x 2 x i64> %acc.b, %b.i64 |
| 24 | // <repeat for the other 3 subvectors> |
| 25 | // |
| 26 | // If the initial load is a legal vector rather than 4x the size (generating a |
| 27 | // structured ld4 instead), we would see multiple uunpkhi/lo instructions for |
| 28 | // the extensions, followed by uzp1/2 instructions for the deinterleave. |
| 29 | // Instead, we can replace all of those with 4 tbl instructions. The tradeoff, |
| 30 | // of course, is that we now have 4 mask values to maintain which may increase |
| 31 | // register pressure. |
| 32 | // |
| 33 | // This basic transform could be performed in CodeGenPrepare (as the equivalent |
| 34 | // for NEON is), or in a DAG Combine. However, we hope to extend it to detect |
| 35 | // other shuffles that we can fold into the tbl. Extending the above example, |
| 36 | // if instead of directly adding to the accumulator we multiplied it by a |
| 37 | // common term for all 4 components that had been reversed: |
| 38 | // %common.load = call @llvm.masked.load |
| 39 | // %common.reverse = call @llvm.vector.reverse |
| 40 | // These would be loaded at the extended size, <vscale x 2 x i64> in our |
| 41 | // example. |
| 42 | // %b.mul = mul <vscale x 2 x i64> %b.i64, %common.reverse |
| 43 | // %acc.b.next = add <vscale x 2 x i64> %acc.b, %b.mul |
| 44 | // <repeat for the other 3 subvectors, using %common.reverse for each> |
| 45 | // |
| 46 | // In this case, the reverse isn't applied to the deinterleaved data in the |
| 47 | // original IR, but to the common term multiplied by the individual bgra |
| 48 | // elements. If the order of the elements in the accumulator is important, we |
| 49 | // cannot change that. If, however, we know that the accumulator is reduced to |
| 50 | // a single scalar after the loop and the data is either integers or floating |
| 51 | // point with reassociation allowed, we could instead choose a different mask |
| 52 | // for the tbls to reverse the individual bgra elements instead, removing an |
| 53 | // additional instruction from the loop. This does require looking beyond the |
| 54 | // blocks in the loop, so DAGCombine won't help. |
| 55 | // |
| 56 | // We should also be able to introduce new shuffles in order to balance out |
| 57 | // SVE's bottom/top instruction pairs, which act on even/odd lanes instead of |
| 58 | // the high or low half of a register. |
| 59 | // |
| 60 | // This pass may end up being a temporary solution that is removed if we can |
| 61 | // create a generic vector shuffle intrinsic and move this feature to |
| 62 | // LoopVectorize itself, as that would allow for better cost modelling. |
| 63 | // |
| 64 | //===----------------------------------------------------------------------===// |
| 65 | |
| 66 | #include "AArch64.h" |
| 67 | #include "AArch64Subtarget.h" |
| 68 | #include "AArch64TargetMachine.h" |
| 69 | #include "llvm/Analysis/AssumptionCache.h" |
| 70 | #include "llvm/Analysis/LoopInfo.h" |
| 71 | #include "llvm/Analysis/LoopPass.h" |
| 72 | #include "llvm/Analysis/MemorySSA.h" |
| 73 | #include "llvm/Analysis/TargetTransformInfo.h" |
| 74 | #include "llvm/Analysis/ValueTracking.h" |
| 75 | #include "llvm/CodeGen/TargetLowering.h" |
| 76 | #include "llvm/CodeGen/TargetPassConfig.h" |
| 77 | #include "llvm/CodeGen/TargetSubtargetInfo.h" |
| 78 | #include "llvm/IR/Constants.h" |
| 79 | #include "llvm/IR/IRBuilder.h" |
| 80 | #include "llvm/IR/Instructions.h" |
| 81 | #include "llvm/IR/IntrinsicInst.h" |
| 82 | #include "llvm/IR/IntrinsicsAArch64.h" |
| 83 | #include "llvm/IR/LLVMContext.h" |
| 84 | #include "llvm/IR/PassManager.h" |
| 85 | #include "llvm/IR/PatternMatch.h" |
| 86 | #include "llvm/InitializePasses.h" |
| 87 | #include <array> |
| 88 | |
| 89 | using namespace llvm; |
| 90 | using namespace llvm::PatternMatch; |
| 91 | |
| 92 | #define DEBUG_TYPE "aarch64-sve-shuffle-opts" |
| 93 | |
| 94 | /// A mapping between a vector_deinterleaveN intrinsic and extending cast |
| 95 | /// instructions used on the resulting subvectors. |
| 96 | using DeinterleaveMap = SmallDenseMap<CallInst *, std::array<CastInst *, 4>>; |
| 97 | |
| 98 | /// Evaluate a deinterleave and see what the uses are. If we find other |
| 99 | /// operations that we can combine into a tbl shuffle, add the deinterleave and |
| 100 | /// the operations (currently only zext or uitofp) to the candidates map. |
| 101 | static void evaluateDeinterleave(IntrinsicInst *I, DeinterleaveMap &Candidates, |
| 102 | Loop &L, const AArch64TargetLowering &TL, |
| 103 | const DataLayout DL) { |
| 104 | assert(I->getIntrinsicID() == Intrinsic::vector_deinterleave4 && |
| 105 | "Only deinterleave4 supported currently" ); |
| 106 | |
| 107 | ConstantRange VScaleRange = getVScaleRange(F: I->getFunction(), BitWidth: 64); |
| 108 | // TBL zeroes elements with an out-of-bounds index, but for the largest |
| 109 | // possible SVE vector (2048b) the maximum value for i8 elements (255) is not |
| 110 | // large enough to encode an 'out of bounds' value. So we can only perform |
| 111 | // this optimization for i8 elements if we know vscale is < 16. |
| 112 | EVT InputVT = TL.getValueType(DL, Ty: I->getOperand(i_nocapture: 0)->getType()); |
| 113 | if (!InputVT.isScalableVector() || |
| 114 | (InputVT.getScalarSizeInBits() < 16 && |
| 115 | (!VScaleRange.getUpper().ult(RHS: 16) || VScaleRange.isUpperWrapped())) || |
| 116 | TL.getTypeConversion(Context&: I->getContext(), VT: InputVT).first != |
| 117 | TargetLoweringBase::TypeLegal) |
| 118 | return; |
| 119 | |
| 120 | std::array<CastInst *, 4> Extends = {}; |
| 121 | unsigned Opcode = 0; |
| 122 | Type *DestTy = nullptr; |
| 123 | for (User *U : I->users()) { |
| 124 | auto * = dyn_cast<ExtractValueInst>(Val: U); |
| 125 | if (!Extract || !Extract->hasOneUse()) |
| 126 | return; |
| 127 | |
| 128 | // We expect only a single cast instruction as a user for the extract. |
| 129 | auto *Extend = dyn_cast_if_present<CastInst>(Val: *Extract->users().begin()); |
| 130 | if (!Extend || (!isa<ZExtInst>(Val: Extend) && !isa<UIToFPInst>(Val: Extend))) |
| 131 | return; |
| 132 | |
| 133 | // We're only interested if the uses are in the loop. This is almost |
| 134 | // certainly the case. |
| 135 | if (!L.contains(Inst: Extend)) |
| 136 | return; |
| 137 | |
| 138 | Opcode = Extend->getOpcode(); |
| 139 | DestTy = Extend->getDestTy(); |
| 140 | |
| 141 | // Make sure DestTy matches the input size. |
| 142 | if (DestTy->getPrimitiveSizeInBits() != InputVT.getSizeInBits()) |
| 143 | return; |
| 144 | |
| 145 | Extends[Extract->getIndices().front()] = Extend; |
| 146 | } |
| 147 | |
| 148 | // Check that all extracted values are being extended the same way, and that |
| 149 | // we have the expected number of extensions. |
| 150 | if (!all_of(Range&: Extends, P: [DestTy, Opcode](CastInst *CI) { |
| 151 | return !CI || (CI->getDestTy() == DestTy && CI->getOpcode() == Opcode); |
| 152 | })) |
| 153 | return; |
| 154 | |
| 155 | Candidates.try_emplace(Key: I, Args&: Extends); |
| 156 | } |
| 157 | |
| 158 | /// Given a map of deinterleaves to zext or uitofp casts, remove the operations |
| 159 | /// and replace them with tbl shuffles. |
| 160 | static void optimizeSVEDeinterleavedExtends(DeinterleaveMap Deinterleaves) { |
| 161 | for (auto &[Deinterleave, Extends] : Deinterleaves) { |
| 162 | VectorType *DestTy = cast<VectorType>(Val: Extends[0]->getDestTy()); |
| 163 | VectorType *SrcTy = cast<VectorType>(Val: Extends[0]->getSrcTy()); |
| 164 | unsigned DstBits = DestTy->getScalarSizeInBits(); |
| 165 | unsigned SrcBits = SrcTy->getScalarSizeInBits(); |
| 166 | bool IsUIToFP = isa<UIToFPInst>(Val: Extends[0]); |
| 167 | VectorType *StepVecTy = VectorType::getInteger(VTy: DestTy); |
| 168 | Value *Input = Deinterleave->getOperand(i_nocapture: 0); |
| 169 | Type *InputTy = Input->getType(); |
| 170 | |
| 171 | APInt Invalid = APInt::getAllOnes(numBits: DstBits); |
| 172 | for (auto [Idx, Extend] : enumerate(First&: Extends)) { |
| 173 | // If not all lanes were extracted, we can have gaps. Skip over them. |
| 174 | if (!Extend) |
| 175 | continue; |
| 176 | // Build the mask using stepvectors and casting. |
| 177 | // We want to select the Idx'th element, and every 4 elements after that. |
| 178 | // Each element needs to be zero extended; we can do that by providing |
| 179 | // tbl index values that are out of range. We can't do that nicely with |
| 180 | // a stepvector of the same element type as the input type, but we can |
| 181 | // do it with elements the size of the output type. |
| 182 | // E.g. for element 0 of a 16b -> 64b zext, we would start with a mask of |
| 183 | // 0xFFFF_FFFF_FFFF_0000 + Idx for the start of the stepvector, and use a |
| 184 | // step of 4. We then cast that back to an element size of 16b, yielding |
| 185 | // <0x0000 + Idx, 0xFFFF, 0xFFFF, 0xFFFF, 0x0004 + Idx, 0xFFFF...>. |
| 186 | APInt StartIdx = Invalid << SrcBits; |
| 187 | StartIdx += Idx; |
| 188 | IRBuilder<> Builder(Extend); |
| 189 | Value *StepVector = Builder.CreateStepVector(DstType: StepVecTy); |
| 190 | Value *ScaledSteps = |
| 191 | Builder.CreateNUWMul(LHS: StepVector, RHS: ConstantInt::get(Ty: StepVecTy, V: 4)); |
| 192 | Value *ZextTbl = Builder.CreateNUWAdd( |
| 193 | LHS: ScaledSteps, RHS: ConstantInt::get(Ty: StepVecTy, V: StartIdx)); |
| 194 | Value *FinalMask = Builder.CreateBitCast(V: ZextTbl, DestTy: InputTy); |
| 195 | |
| 196 | // Replace the deinterleave, extractvalue, and extension chain with |
| 197 | // a tbl directly on the input value. |
| 198 | Value *Tbl = Builder.CreateIntrinsic(ID: Intrinsic::aarch64_sve_tbl, |
| 199 | OverloadTypes: {InputTy}, Args: {Input, FinalMask}); |
| 200 | Value *Widen = Builder.CreateBitCast(V: Tbl, DestTy: StepVecTy); |
| 201 | if (IsUIToFP) |
| 202 | Widen = Builder.CreateUIToFP(V: Widen, DestTy); |
| 203 | LLVM_DEBUG(dbgs() << "SVETBLOPT: Replaced " << *Extend << " with " |
| 204 | << *Widen << "\n" ); |
| 205 | Extend->replaceAllUsesWith(V: Widen); |
| 206 | Extend->eraseFromParent(); |
| 207 | } |
| 208 | |
| 209 | // Delete the unused extracts and deinterleave. |
| 210 | for (User *U : make_early_inc_range(Range: Deinterleave->users())) |
| 211 | cast<Instruction>(Val: U)->eraseFromParent(); |
| 212 | Deinterleave->eraseFromParent(); |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | static bool processLoop(Loop &L, const AArch64Subtarget &ST, DataLayout DL) { |
| 217 | // At present, we only want to do this for innermost loops when SVE |
| 218 | // is available. |
| 219 | if (!L.isInnermost() || !ST.isSVEorStreamingSVEAvailable()) |
| 220 | return false; |
| 221 | |
| 222 | // TODO: Pull other shuffles into the tbl where possible. |
| 223 | // TODO: Add more advanced cases, such as introducing shuffles so that |
| 224 | // the SVE odd/even BT narrowing instructions can be used. |
| 225 | // TODO: Support other deinterleaves. |
| 226 | const AArch64TargetLowering &TL = *ST.getTargetLowering(); |
| 227 | assert(DL.isLittleEndian() && |
| 228 | "Shuffle optimizations unsupported for big endian targets." ); |
| 229 | DeinterleaveMap Candidates; |
| 230 | for (auto *BB : L.blocks()) |
| 231 | for (auto &I : *BB) |
| 232 | if (match(V: &I, P: m_Intrinsic<Intrinsic::vector_deinterleave4>(Op0: m_Value()))) |
| 233 | evaluateDeinterleave(I: cast<IntrinsicInst>(Val: &I), Candidates, L, TL, DL); |
| 234 | |
| 235 | if (Candidates.empty()) |
| 236 | return false; |
| 237 | |
| 238 | optimizeSVEDeinterleavedExtends(Deinterleaves: Candidates); |
| 239 | return true; |
| 240 | } |
| 241 | |
| 242 | namespace { |
| 243 | struct SVEShuffleOpts : public LoopPass { |
| 244 | static char ID; // Pass identification, replacement for typeid |
| 245 | SVEShuffleOpts() : LoopPass(ID) {} |
| 246 | |
| 247 | bool runOnLoop(Loop *L, LPPassManager &PM) override { |
| 248 | if (skipLoop(L)) |
| 249 | return false; |
| 250 | |
| 251 | TargetPassConfig &TPC = getAnalysis<TargetPassConfig>(); |
| 252 | const AArch64TargetMachine &TM = TPC.getTM<AArch64TargetMachine>(); |
| 253 | const AArch64Subtarget &ST = |
| 254 | *TM.getSubtargetImpl(F: *L->getHeader()->getParent()); |
| 255 | |
| 256 | return processLoop(L&: *L, ST, DL: TM.createDataLayout()); |
| 257 | } |
| 258 | |
| 259 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 260 | AU.addRequired<TargetPassConfig>(); |
| 261 | AU.setPreservesCFG(); |
| 262 | } |
| 263 | |
| 264 | StringRef getPassName() const override { return "SVE Shuffle Optimizations" ; } |
| 265 | }; |
| 266 | } // end anonymous namespace |
| 267 | |
| 268 | char SVEShuffleOpts::ID = 0; |
| 269 | static const char *name = "SVE Shuffle Optimizations" ; |
| 270 | INITIALIZE_PASS_BEGIN(SVEShuffleOpts, DEBUG_TYPE, name, false, false) |
| 271 | INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) |
| 272 | INITIALIZE_PASS_END(SVEShuffleOpts, DEBUG_TYPE, name, false, false) |
| 273 | |
| 274 | Pass *llvm::createSVEShuffleOptsPass() { return new SVEShuffleOpts(); } |
| 275 | |
| 276 | PreservedAnalyses SVEShuffleOptsPass::run(Loop &L, LoopAnalysisManager &AM, |
| 277 | LoopStandardAnalysisResults &AR, |
| 278 | LPMUpdater &U) { |
| 279 | const AArch64Subtarget &ST = |
| 280 | *TM.getSubtargetImpl(F: *L.getHeader()->getParent()); |
| 281 | |
| 282 | if (processLoop(L, ST, DL: TM.createDataLayout())) { |
| 283 | PreservedAnalyses PA; |
| 284 | PA.preserveSet<CFGAnalyses>(); |
| 285 | PA.preserve<TargetIRAnalysis>(); |
| 286 | PA.preserve<AssumptionAnalysis>(); |
| 287 | PA.preserve<MemorySSAAnalysis>(); |
| 288 | return PA; |
| 289 | } |
| 290 | |
| 291 | return PreservedAnalyses::all(); |
| 292 | } |
| 293 | |