1//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8// This pass tries to remove back-to-back (smstart, smstop) and
9// (smstop, smstart) sequences. The pass is conservative when it cannot
10// determine that it is safe to remove these sequences.
11//===----------------------------------------------------------------------===//
12
13#include "AArch64InstrInfo.h"
14#include "AArch64MachineFunctionInfo.h"
15#include "AArch64Subtarget.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/CodeGen/MachineBasicBlock.h"
18#include "llvm/CodeGen/MachineFunctionPass.h"
19#include "llvm/CodeGen/MachineRegisterInfo.h"
20#include "llvm/CodeGen/TargetRegisterInfo.h"
21
22using namespace llvm;
23
24#define DEBUG_TYPE "aarch64-sme-peephole-opt"
25
namespace {

/// Machine-function pass that removes redundant SME state changes: it erases
/// matching back-to-back smstart/smstop (or smstop/smstart) pairs, and forms
/// FORM_TRANSPOSED_REG_TUPLE pseudos from suitable REG_SEQUENCE nodes.
struct SMEPeepholeOpt : public MachineFunctionPass {
  static char ID;

  SMEPeepholeOpt() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;

  StringRef getPassName() const override {
    return "SME Peephole Optimization pass";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The pass only erases/rewrites instructions within blocks; the CFG is
    // never modified.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  // Removes cancelling smstart/smstop pairs in \p MBB. Sets
  // \p HasRemovedAllSMChanges when every streaming-mode change in the block
  // has been removed. Returns true if anything changed.
  bool optimizeStartStopPairs(MachineBasicBlock &MBB,
                              bool &HasRemovedAllSMChanges) const;
  // Tries to replace a REG_SEQUENCE with a FORM_TRANSPOSED_REG_TUPLE pseudo.
  // Returns true on success.
  bool visitRegSequence(MachineInstr &MI);
};

char SMEPeepholeOpt::ID = 0;

} // end anonymous namespace
52
53static bool isConditionalStartStop(const MachineInstr *MI) {
54 return MI->getOpcode() == AArch64::MSRpstatePseudo;
55}
56
57static bool isMatchingStartStopPair(const MachineInstr *MI1,
58 const MachineInstr *MI2) {
59 // We only consider the same type of streaming mode change here, i.e.
60 // start/stop SM, or start/stop ZA pairs.
61 if (MI1->getOperand(i: 0).getImm() != MI2->getOperand(i: 0).getImm())
62 return false;
63
64 // One must be 'start', the other must be 'stop'
65 if (MI1->getOperand(i: 1).getImm() == MI2->getOperand(i: 1).getImm())
66 return false;
67
68 bool IsConditional = isConditionalStartStop(MI: MI2);
69 if (isConditionalStartStop(MI: MI1) != IsConditional)
70 return false;
71
72 if (!IsConditional)
73 return true;
74
75 // Check to make sure the conditional start/stop pairs are identical.
76 if (MI1->getOperand(i: 2).getImm() != MI2->getOperand(i: 2).getImm())
77 return false;
78
79 // Ensure reg masks are identical.
80 if (MI1->getOperand(i: 4).getRegMask() != MI2->getOperand(i: 4).getRegMask())
81 return false;
82
83 // This optimisation is unlikely to happen in practice for conditional
84 // smstart/smstop pairs as the virtual registers for pstate.sm will always
85 // be different.
86 // TODO: For this optimisation to apply to conditional smstart/smstop,
87 // this pass will need to do more work to remove redundant calls to
88 // __arm_sme_state.
89
90 // Only consider conditional start/stop pairs which read the same register
91 // holding the original value of pstate.sm, as some conditional start/stops
92 // require the state on entry to the function.
93 if (MI1->getOperand(i: 3).isReg() && MI2->getOperand(i: 3).isReg()) {
94 Register Reg1 = MI1->getOperand(i: 3).getReg();
95 Register Reg2 = MI2->getOperand(i: 3).getReg();
96 if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
97 return false;
98 }
99
100 return true;
101}
102
103static bool ChangesStreamingMode(const MachineInstr *MI) {
104 assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
105 MI->getOpcode() == AArch64::MSRpstatePseudo) &&
106 "Expected MI to be a smstart/smstop instruction");
107 return MI->getOperand(i: 0).getImm() == AArch64SVCR::SVCRSM ||
108 MI->getOperand(i: 0).getImm() == AArch64SVCR::SVCRSMZA;
109}
110
111static bool isSVERegOp(const TargetRegisterInfo &TRI,
112 const MachineRegisterInfo &MRI,
113 const MachineOperand &MO) {
114 if (!MO.isReg())
115 return false;
116
117 Register R = MO.getReg();
118 if (R.isPhysical())
119 return llvm::any_of(Range: TRI.subregs_inclusive(Reg: R), P: [](const MCPhysReg &SR) {
120 return AArch64::ZPRRegClass.contains(Reg: SR) ||
121 AArch64::PPRRegClass.contains(Reg: SR);
122 });
123
124 const TargetRegisterClass *RC = MRI.getRegClass(Reg: R);
125 return TRI.getCommonSubClass(A: &AArch64::ZPRRegClass, B: RC) ||
126 TRI.getCommonSubClass(A: &AArch64::PPRRegClass, B: RC);
127}
128
129bool SMEPeepholeOpt::optimizeStartStopPairs(
130 MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
131 const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
132 const TargetRegisterInfo &TRI =
133 *MBB.getParent()->getSubtarget().getRegisterInfo();
134
135 bool Changed = false;
136 MachineInstr *Prev = nullptr;
137 SmallVector<MachineInstr *, 4> ToBeRemoved;
138
139 // Convenience function to reset the matching of a sequence.
140 auto Reset = [&]() {
141 Prev = nullptr;
142 ToBeRemoved.clear();
143 };
144
145 // Walk through instructions in the block trying to find pairs of smstart
146 // and smstop nodes that cancel each other out. We only permit a limited
147 // set of instructions to appear between them, otherwise we reset our
148 // tracking.
149 unsigned NumSMChanges = 0;
150 unsigned NumSMChangesRemoved = 0;
151 for (MachineInstr &MI : make_early_inc_range(Range&: MBB)) {
152 switch (MI.getOpcode()) {
153 case AArch64::MSRpstatesvcrImm1:
154 case AArch64::MSRpstatePseudo: {
155 if (ChangesStreamingMode(MI: &MI))
156 NumSMChanges++;
157
158 if (!Prev)
159 Prev = &MI;
160 else if (isMatchingStartStopPair(MI1: Prev, MI2: &MI)) {
161 // If they match, we can remove them, and possibly any instructions
162 // that we marked for deletion in between.
163 Prev->eraseFromParent();
164 MI.eraseFromParent();
165 for (MachineInstr *TBR : ToBeRemoved)
166 TBR->eraseFromParent();
167 ToBeRemoved.clear();
168 Prev = nullptr;
169 Changed = true;
170 NumSMChangesRemoved += 2;
171 } else {
172 Reset();
173 Prev = &MI;
174 }
175 continue;
176 }
177 default:
178 if (!Prev)
179 // Avoid doing expensive checks when Prev is nullptr.
180 continue;
181 break;
182 }
183
184 // Test if the instructions in between the start/stop sequence are agnostic
185 // of streaming mode. If not, the algorithm should reset.
186 switch (MI.getOpcode()) {
187 default:
188 Reset();
189 break;
190 case AArch64::COALESCER_BARRIER_FPR16:
191 case AArch64::COALESCER_BARRIER_FPR32:
192 case AArch64::COALESCER_BARRIER_FPR64:
193 case AArch64::COALESCER_BARRIER_FPR128:
194 case AArch64::COPY:
195 // These instructions should be safe when executed on their own, but
196 // the code remains conservative when SVE registers are used. There may
197 // exist subtle cases where executing a COPY in a different mode results
198 // in different behaviour, even if we can't yet come up with any
199 // concrete example/test-case.
200 if (isSVERegOp(TRI, MRI, MO: MI.getOperand(i: 0)) ||
201 isSVERegOp(TRI, MRI, MO: MI.getOperand(i: 1)))
202 Reset();
203 break;
204 case AArch64::ADJCALLSTACKDOWN:
205 case AArch64::ADJCALLSTACKUP:
206 case AArch64::ANDXri:
207 case AArch64::ADDXri:
208 // We permit these as they don't generate SVE/NEON instructions.
209 break;
210 case AArch64::VGRestorePseudo:
211 case AArch64::VGSavePseudo:
212 // When the smstart/smstop are removed, we should also remove
213 // the pseudos that save/restore the VG value for CFI info.
214 ToBeRemoved.push_back(Elt: &MI);
215 break;
216 case AArch64::MSRpstatesvcrImm1:
217 case AArch64::MSRpstatePseudo:
218 llvm_unreachable("Should have been handled");
219 }
220 }
221
222 HasRemovedAllSMChanges =
223 NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
224 return Changed;
225}
226
227// Using the FORM_TRANSPOSED_REG_TUPLE pseudo can improve register allocation
228// of multi-vector intrinsics. However, the pseudo should only be emitted if
229// the input registers of the REG_SEQUENCE are copy nodes where the source
230// register is in a StridedOrContiguous class. For example:
231//
232// %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
233// %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
234// %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
235// %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
236// %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
237// %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
238// %9:zpr2mul2 = REG_SEQUENCE %5:zpr, %subreg.zsub0, %8:zpr, %subreg.zsub1
239//
240// -> %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
241//
242bool SMEPeepholeOpt::visitRegSequence(MachineInstr &MI) {
243 assert(MI.getMF()->getRegInfo().isSSA() && "Expected to be run on SSA form!");
244
245 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
246 switch (MRI.getRegClass(Reg: MI.getOperand(i: 0).getReg())->getID()) {
247 case AArch64::ZPR2RegClassID:
248 case AArch64::ZPR4RegClassID:
249 case AArch64::ZPR2Mul2RegClassID:
250 case AArch64::ZPR4Mul4RegClassID:
251 break;
252 default:
253 return false;
254 }
255
256 // The first operand is the register class created by the REG_SEQUENCE.
257 // Each operand pair after this consists of a vreg + subreg index, so
258 // for example a sequence of 2 registers will have a total of 5 operands.
259 if (MI.getNumOperands() != 5 && MI.getNumOperands() != 9)
260 return false;
261
262 MCRegister SubReg = MCRegister::NoRegister;
263 for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
264 MachineOperand &MO = MI.getOperand(i: I);
265
266 MachineOperand *Def = MRI.getOneDef(Reg: MO.getReg());
267 if (!Def || !Def->getParent()->isCopy())
268 return false;
269
270 const MachineOperand &CopySrc = Def->getParent()->getOperand(i: 1);
271 unsigned OpSubReg = CopySrc.getSubReg();
272 if (SubReg == MCRegister::NoRegister)
273 SubReg = OpSubReg;
274
275 MachineOperand *CopySrcOp = MRI.getOneDef(Reg: CopySrc.getReg());
276 if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
277 CopySrcOp->getReg().isPhysical())
278 return false;
279
280 const TargetRegisterClass *CopySrcClass =
281 MRI.getRegClass(Reg: CopySrcOp->getReg());
282 if (CopySrcClass != &AArch64::ZPR2StridedOrContiguousRegClass &&
283 CopySrcClass != &AArch64::ZPR4StridedOrContiguousRegClass)
284 return false;
285 }
286
287 unsigned Opc = MI.getNumOperands() == 5
288 ? AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO
289 : AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
290
291 const TargetInstrInfo *TII =
292 MI.getMF()->getSubtarget<AArch64Subtarget>().getInstrInfo();
293 MachineInstrBuilder MIB = BuildMI(BB&: *MI.getParent(), I&: MI, MIMD: MI.getDebugLoc(),
294 MCID: TII->get(Opcode: Opc), DestReg: MI.getOperand(i: 0).getReg());
295 for (unsigned I = 1; I < MI.getNumOperands(); I += 2)
296 MIB.addReg(RegNo: MI.getOperand(i: I).getReg());
297
298 MI.eraseFromParent();
299 return true;
300}
301
// Register the pass with the legacy pass manager under the
// "aarch64-sme-peephole-opt" command-line name.
INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
                "SME Peephole Optimization", false, false)
304
305bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
306 if (skipFunction(F: MF.getFunction()))
307 return false;
308
309 if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
310 return false;
311
312 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
313
314 bool Changed = false;
315 bool FunctionHasAllSMChangesRemoved = false;
316
317 // Even if the block lives in a function with no SME attributes attached we
318 // still have to analyze all the blocks because we may call a streaming
319 // function that requires smstart/smstop pairs.
320 for (MachineBasicBlock &MBB : MF) {
321 bool BlockHasAllSMChangesRemoved;
322 Changed |= optimizeStartStopPairs(MBB, HasRemovedAllSMChanges&: BlockHasAllSMChangesRemoved);
323 FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
324
325 if (MF.getSubtarget<AArch64Subtarget>().isStreaming()) {
326 for (MachineInstr &MI : make_early_inc_range(Range&: MBB))
327 if (MI.getOpcode() == AArch64::REG_SEQUENCE)
328 Changed |= visitRegSequence(MI);
329 }
330 }
331
332 AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
333 if (FunctionHasAllSMChangesRemoved)
334 AFI->setHasStreamingModeChanges(false);
335
336 return Changed;
337}
338
339FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
340