1//===------- SVEShuffleOpts - SVE Shuffle Optimization --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Tries to pattern match and combine scalable vector shuffles that could
10// be more efficiently performed by tbl instructions.
11//
// An example would be a loop with 4 multiply-accumulate reductions, where the
// new data in each vector iteration comes from a 4-way deinterleaving of
// smaller datatypes loaded from memory which are then zero extended.
15//
16// Something like the following:
17// %bgra = call ... @llvm.masked.load
18// %deinterleave = call ... @llvm.vector.deinterleave4(%bgra)
19// If the load was of a <vscale x 8 x i16>, we now have 4 deinterleaved
20// <vscale x 2 x i16> values.
21// %b.i16 = extractvalue %deinterleave, 0
22// %b.i64 = zext <vscale x 2 x i16> %b.i16 to <vscale x 2 x i64>
23// %acc.b.next = add <vscale x 2 x i64> %acc.b, %b.i64
24// <repeat for the other 3 subvectors>
25//
26// If the initial load is a legal vector rather than 4x the size (generating a
27// structured ld4 instead), we would see multiple uunpkhi/lo instructions for
28// the extensions, followed by uzp1/2 instructions for the deinterleave.
29// Instead, we can replace all of those with 4 tbl instructions. The tradeoff,
30// of course, is that we now have 4 mask values to maintain which may increase
31// register pressure.
32//
33// This basic transform could be performed in CodeGenPrepare (as the equivalent
34// for NEON is), or in a DAG Combine. However, we hope to extend it to detect
35// other shuffles that we can fold into the tbl. Extending the above example,
36// if instead of directly adding to the accumulator we multiplied it by a
37// common term for all 4 components that had been reversed:
38// %common.load = call @llvm.masked.load
39// %common.reverse = call @llvm.vector.reverse
40// These would be loaded at the extended size, <vscale x 2 x i64> in our
41// example.
42// %b.mul = mul <vscale x 2 x i64> %b.i64, %common.reverse
43// %acc.b.next = add <vscale x 2 x i64> %acc.b, %b.mul
//   <repeat for the other 3 subvectors, using %common.reverse for each>
45//
46// In this case, the reverse isn't applied to the deinterleaved data in the
47// original IR, but to the common term multiplied by the individual bgra
48// elements. If the order of the elements in the accumulator is important, we
49// cannot change that. If, however, we know that the accumulator is reduced to
50// a single scalar after the loop and the data is either integers or floating
51// point with reassociation allowed, we could instead choose a different mask
52// for the tbls to reverse the individual bgra elements instead, removing an
53// additional instruction from the loop. This does require looking beyond the
54// blocks in the loop, so DAGCombine won't help.
55//
56// We should also be able to introduce new shuffles in order to balance out
57// SVE's bottom/top instruction pairs, which act on even/odd lanes instead of
58// the high or low half of a register.
59//
60// This pass may end up being a temporary solution that is removed if we can
61// create a generic vector shuffle intrinsic and move this feature to
62// LoopVectorize itself, as that would allow for better cost modelling.
63//
64//===----------------------------------------------------------------------===//
65
66#include "AArch64.h"
67#include "AArch64Subtarget.h"
68#include "AArch64TargetMachine.h"
69#include "llvm/Analysis/AssumptionCache.h"
70#include "llvm/Analysis/LoopInfo.h"
71#include "llvm/Analysis/LoopPass.h"
72#include "llvm/Analysis/MemorySSA.h"
73#include "llvm/Analysis/TargetTransformInfo.h"
74#include "llvm/Analysis/ValueTracking.h"
75#include "llvm/CodeGen/TargetLowering.h"
76#include "llvm/CodeGen/TargetPassConfig.h"
77#include "llvm/CodeGen/TargetSubtargetInfo.h"
78#include "llvm/IR/Constants.h"
79#include "llvm/IR/IRBuilder.h"
80#include "llvm/IR/Instructions.h"
81#include "llvm/IR/IntrinsicInst.h"
82#include "llvm/IR/IntrinsicsAArch64.h"
83#include "llvm/IR/LLVMContext.h"
84#include "llvm/IR/PassManager.h"
85#include "llvm/IR/PatternMatch.h"
86#include "llvm/InitializePasses.h"
87#include <array>
88
89using namespace llvm;
90using namespace llvm::PatternMatch;
91
92#define DEBUG_TYPE "aarch64-sve-shuffle-opts"
93
/// A mapping between a vector_deinterleaveN intrinsic and extending cast
/// instructions used on the resulting subvectors.
///
/// The array is indexed by the deinterleaved lane (the extractvalue index).
/// A lane that is never extracted is left as nullptr, so consumers must be
/// prepared to skip null entries.
using DeinterleaveMap = SmallDenseMap<CallInst *, std::array<CastInst *, 4>>;
97
98/// Evaluate a deinterleave and see what the uses are. If we find other
99/// operations that we can combine into a tbl shuffle, add the deinterleave and
100/// the operations (currently only zext or uitofp) to the candidates map.
101static void evaluateDeinterleave(IntrinsicInst *I, DeinterleaveMap &Candidates,
102 Loop &L, const AArch64TargetLowering &TL,
103 const DataLayout DL) {
104 assert(I->getIntrinsicID() == Intrinsic::vector_deinterleave4 &&
105 "Only deinterleave4 supported currently");
106
107 ConstantRange VScaleRange = getVScaleRange(F: I->getFunction(), BitWidth: 64);
108 // TBL zeroes elements with an out-of-bounds index, but for the largest
109 // possible SVE vector (2048b) the maximum value for i8 elements (255) is not
110 // large enough to encode an 'out of bounds' value. So we can only perform
111 // this optimization for i8 elements if we know vscale is < 16.
112 EVT InputVT = TL.getValueType(DL, Ty: I->getOperand(i_nocapture: 0)->getType());
113 if (!InputVT.isScalableVector() ||
114 (InputVT.getScalarSizeInBits() < 16 &&
115 (!VScaleRange.getUpper().ult(RHS: 16) || VScaleRange.isUpperWrapped())) ||
116 TL.getTypeConversion(Context&: I->getContext(), VT: InputVT).first !=
117 TargetLoweringBase::TypeLegal)
118 return;
119
120 std::array<CastInst *, 4> Extends = {};
121 unsigned Opcode = 0;
122 Type *DestTy = nullptr;
123 for (User *U : I->users()) {
124 auto *Extract = dyn_cast<ExtractValueInst>(Val: U);
125 if (!Extract || !Extract->hasOneUse())
126 return;
127
128 // We expect only a single cast instruction as a user for the extract.
129 auto *Extend = dyn_cast_if_present<CastInst>(Val: *Extract->users().begin());
130 if (!Extend || (!isa<ZExtInst>(Val: Extend) && !isa<UIToFPInst>(Val: Extend)))
131 return;
132
133 // We're only interested if the uses are in the loop. This is almost
134 // certainly the case.
135 if (!L.contains(Inst: Extend))
136 return;
137
138 Opcode = Extend->getOpcode();
139 DestTy = Extend->getDestTy();
140
141 // Make sure DestTy matches the input size.
142 if (DestTy->getPrimitiveSizeInBits() != InputVT.getSizeInBits())
143 return;
144
145 Extends[Extract->getIndices().front()] = Extend;
146 }
147
148 // Check that all extracted values are being extended the same way, and that
149 // we have the expected number of extensions.
150 if (!all_of(Range&: Extends, P: [DestTy, Opcode](CastInst *CI) {
151 return !CI || (CI->getDestTy() == DestTy && CI->getOpcode() == Opcode);
152 }))
153 return;
154
155 Candidates.try_emplace(Key: I, Args&: Extends);
156}
157
158/// Given a map of deinterleaves to zext or uitofp casts, remove the operations
159/// and replace them with tbl shuffles.
160static void optimizeSVEDeinterleavedExtends(DeinterleaveMap Deinterleaves) {
161 for (auto &[Deinterleave, Extends] : Deinterleaves) {
162 VectorType *DestTy = cast<VectorType>(Val: Extends[0]->getDestTy());
163 VectorType *SrcTy = cast<VectorType>(Val: Extends[0]->getSrcTy());
164 unsigned DstBits = DestTy->getScalarSizeInBits();
165 unsigned SrcBits = SrcTy->getScalarSizeInBits();
166 bool IsUIToFP = isa<UIToFPInst>(Val: Extends[0]);
167 VectorType *StepVecTy = VectorType::getInteger(VTy: DestTy);
168 Value *Input = Deinterleave->getOperand(i_nocapture: 0);
169 Type *InputTy = Input->getType();
170
171 APInt Invalid = APInt::getAllOnes(numBits: DstBits);
172 for (auto [Idx, Extend] : enumerate(First&: Extends)) {
173 // If not all lanes were extracted, we can have gaps. Skip over them.
174 if (!Extend)
175 continue;
176 // Build the mask using stepvectors and casting.
177 // We want to select the Idx'th element, and every 4 elements after that.
178 // Each element needs to be zero extended; we can do that by providing
179 // tbl index values that are out of range. We can't do that nicely with
180 // a stepvector of the same element type as the input type, but we can
181 // do it with elements the size of the output type.
182 // E.g. for element 0 of a 16b -> 64b zext, we would start with a mask of
183 // 0xFFFF_FFFF_FFFF_0000 + Idx for the start of the stepvector, and use a
184 // step of 4. We then cast that back to an element size of 16b, yielding
185 // <0x0000 + Idx, 0xFFFF, 0xFFFF, 0xFFFF, 0x0004 + Idx, 0xFFFF...>.
186 APInt StartIdx = Invalid << SrcBits;
187 StartIdx += Idx;
188 IRBuilder<> Builder(Extend);
189 Value *StepVector = Builder.CreateStepVector(DstType: StepVecTy);
190 Value *ScaledSteps =
191 Builder.CreateNUWMul(LHS: StepVector, RHS: ConstantInt::get(Ty: StepVecTy, V: 4));
192 Value *ZextTbl = Builder.CreateNUWAdd(
193 LHS: ScaledSteps, RHS: ConstantInt::get(Ty: StepVecTy, V: StartIdx));
194 Value *FinalMask = Builder.CreateBitCast(V: ZextTbl, DestTy: InputTy);
195
196 // Replace the deinterleave, extractvalue, and extension chain with
197 // a tbl directly on the input value.
198 Value *Tbl = Builder.CreateIntrinsic(ID: Intrinsic::aarch64_sve_tbl,
199 OverloadTypes: {InputTy}, Args: {Input, FinalMask});
200 Value *Widen = Builder.CreateBitCast(V: Tbl, DestTy: StepVecTy);
201 if (IsUIToFP)
202 Widen = Builder.CreateUIToFP(V: Widen, DestTy);
203 LLVM_DEBUG(dbgs() << "SVETBLOPT: Replaced " << *Extend << " with "
204 << *Widen << "\n");
205 Extend->replaceAllUsesWith(V: Widen);
206 Extend->eraseFromParent();
207 }
208
209 // Delete the unused extracts and deinterleave.
210 for (User *U : make_early_inc_range(Range: Deinterleave->users()))
211 cast<Instruction>(Val: U)->eraseFromParent();
212 Deinterleave->eraseFromParent();
213 }
214}
215
216static bool processLoop(Loop &L, const AArch64Subtarget &ST, DataLayout DL) {
217 // At present, we only want to do this for innermost loops when SVE
218 // is available.
219 if (!L.isInnermost() || !ST.isSVEorStreamingSVEAvailable())
220 return false;
221
222 // TODO: Pull other shuffles into the tbl where possible.
223 // TODO: Add more advanced cases, such as introducing shuffles so that
224 // the SVE odd/even BT narrowing instructions can be used.
225 // TODO: Support other deinterleaves.
226 const AArch64TargetLowering &TL = *ST.getTargetLowering();
227 assert(DL.isLittleEndian() &&
228 "Shuffle optimizations unsupported for big endian targets.");
229 DeinterleaveMap Candidates;
230 for (auto *BB : L.blocks())
231 for (auto &I : *BB)
232 if (match(V: &I, P: m_Intrinsic<Intrinsic::vector_deinterleave4>(Op0: m_Value())))
233 evaluateDeinterleave(I: cast<IntrinsicInst>(Val: &I), Candidates, L, TL, DL);
234
235 if (Candidates.empty())
236 return false;
237
238 optimizeSVEDeinterleavedExtends(Deinterleaves: Candidates);
239 return true;
240}
241
242namespace {
243struct SVEShuffleOpts : public LoopPass {
244 static char ID; // Pass identification, replacement for typeid
245 SVEShuffleOpts() : LoopPass(ID) {}
246
247 bool runOnLoop(Loop *L, LPPassManager &PM) override {
248 if (skipLoop(L))
249 return false;
250
251 TargetPassConfig &TPC = getAnalysis<TargetPassConfig>();
252 const AArch64TargetMachine &TM = TPC.getTM<AArch64TargetMachine>();
253 const AArch64Subtarget &ST =
254 *TM.getSubtargetImpl(F: *L->getHeader()->getParent());
255
256 return processLoop(L&: *L, ST, DL: TM.createDataLayout());
257 }
258
259 void getAnalysisUsage(AnalysisUsage &AU) const override {
260 AU.addRequired<TargetPassConfig>();
261 AU.setPreservesCFG();
262 }
263
264 StringRef getPassName() const override { return "SVE Shuffle Optimizations"; }
265};
266} // end anonymous namespace
267
268char SVEShuffleOpts::ID = 0;
269static const char *name = "SVE Shuffle Optimizations";
270INITIALIZE_PASS_BEGIN(SVEShuffleOpts, DEBUG_TYPE, name, false, false)
271INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
272INITIALIZE_PASS_END(SVEShuffleOpts, DEBUG_TYPE, name, false, false)
273
274Pass *llvm::createSVEShuffleOptsPass() { return new SVEShuffleOpts(); }
275
276PreservedAnalyses SVEShuffleOptsPass::run(Loop &L, LoopAnalysisManager &AM,
277 LoopStandardAnalysisResults &AR,
278 LPMUpdater &U) {
279 const AArch64Subtarget &ST =
280 *TM.getSubtargetImpl(F: *L.getHeader()->getParent());
281
282 if (processLoop(L, ST, DL: TM.createDataLayout())) {
283 PreservedAnalyses PA;
284 PA.preserveSet<CFGAnalyses>();
285 PA.preserve<TargetIRAnalysis>();
286 PA.preserve<AssumptionAnalysis>();
287 PA.preserve<MemorySSAAnalysis>();
288 return PA;
289 }
290
291 return PreservedAnalyses::all();
292}
293