//===- SMEPeepholeOpt.cpp - SME peephole optimization pass-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This pass tries to remove back-to-back (smstart, smstop) and
// (smstop, smstart) sequences. The pass is conservative when it cannot
// determine that it is safe to remove these sequences.
//===----------------------------------------------------------------------===//

#include "AArch64InstrInfo.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64Subtarget.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"

using namespace llvm;

#define DEBUG_TYPE "aarch64-sme-peephole-opt"

26namespace {
27
28struct SMEPeepholeOpt : public MachineFunctionPass {
29 static char ID;
30
31 SMEPeepholeOpt() : MachineFunctionPass(ID) {}
32
33 bool runOnMachineFunction(MachineFunction &MF) override;
34
35 StringRef getPassName() const override {
36 return "SME Peephole Optimization pass";
37 }
38
39 void getAnalysisUsage(AnalysisUsage &AU) const override {
40 AU.setPreservesCFG();
41 MachineFunctionPass::getAnalysisUsage(AU);
42 }
43
44 bool optimizeStartStopPairs(MachineBasicBlock &MBB,
45 bool &HasRemovedAllSMChanges) const;
46 bool visitRegSequence(MachineInstr &MI);
47};
48
49char SMEPeepholeOpt::ID = 0;
50
51} // end anonymous namespace
52
53static bool isConditionalStartStop(const MachineInstr *MI) {
54 return MI->getOpcode() == AArch64::MSRpstatePseudo;
55}
56
57static bool isMatchingStartStopPair(const MachineInstr *MI1,
58 const MachineInstr *MI2) {
59 // We only consider the same type of streaming mode change here, i.e.
60 // start/stop SM, or start/stop ZA pairs.
61 if (MI1->getOperand(i: 0).getImm() != MI2->getOperand(i: 0).getImm())
62 return false;
63
64 // One must be 'start', the other must be 'stop'
65 if (MI1->getOperand(i: 1).getImm() == MI2->getOperand(i: 1).getImm())
66 return false;
67
68 bool IsConditional = isConditionalStartStop(MI: MI2);
69 if (isConditionalStartStop(MI: MI1) != IsConditional)
70 return false;
71
72 if (!IsConditional)
73 return true;
74
75 // Check to make sure the conditional start/stop pairs are identical.
76 if (MI1->getOperand(i: 2).getImm() != MI2->getOperand(i: 2).getImm())
77 return false;
78
79 // Ensure reg masks are identical.
80 if (MI1->getOperand(i: 4).getRegMask() != MI2->getOperand(i: 4).getRegMask())
81 return false;
82
83 // Only consider conditional start/stop pairs which read the same register
84 // holding the original value of pstate.sm. This is somewhat over conservative
85 // as all conditional streaming mode changes only look at the state on entry
86 // to the function.
87 if (MI1->getOperand(i: 3).isReg() && MI2->getOperand(i: 3).isReg()) {
88 Register Reg1 = MI1->getOperand(i: 3).getReg();
89 Register Reg2 = MI2->getOperand(i: 3).getReg();
90 if (Reg1.isPhysical() || Reg2.isPhysical() || Reg1 != Reg2)
91 return false;
92 }
93
94 return true;
95}
96
97static bool ChangesStreamingMode(const MachineInstr *MI) {
98 assert((MI->getOpcode() == AArch64::MSRpstatesvcrImm1 ||
99 MI->getOpcode() == AArch64::MSRpstatePseudo) &&
100 "Expected MI to be a smstart/smstop instruction");
101 return MI->getOperand(i: 0).getImm() == AArch64SVCR::SVCRSM ||
102 MI->getOperand(i: 0).getImm() == AArch64SVCR::SVCRSMZA;
103}
104
105static bool isSVERegOp(const TargetRegisterInfo &TRI,
106 const MachineRegisterInfo &MRI,
107 const MachineOperand &MO) {
108 if (!MO.isReg())
109 return false;
110
111 Register R = MO.getReg();
112 if (R.isPhysical())
113 return llvm::any_of(Range: TRI.subregs_inclusive(Reg: R), P: [](const MCPhysReg &SR) {
114 return AArch64::ZPRRegClass.contains(Reg: SR) ||
115 AArch64::PPRRegClass.contains(Reg: SR);
116 });
117
118 const TargetRegisterClass *RC = MRI.getRegClass(Reg: R);
119 return TRI.getCommonSubClass(A: &AArch64::ZPRRegClass, B: RC) ||
120 TRI.getCommonSubClass(A: &AArch64::PPRRegClass, B: RC);
121}
122
123bool SMEPeepholeOpt::optimizeStartStopPairs(
124 MachineBasicBlock &MBB, bool &HasRemovedAllSMChanges) const {
125 const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
126 const TargetRegisterInfo &TRI =
127 *MBB.getParent()->getSubtarget().getRegisterInfo();
128
129 bool Changed = false;
130 MachineInstr *Prev = nullptr;
131
132 // Walk through instructions in the block trying to find pairs of smstart
133 // and smstop nodes that cancel each other out. We only permit a limited
134 // set of instructions to appear between them, otherwise we reset our
135 // tracking.
136 unsigned NumSMChanges = 0;
137 unsigned NumSMChangesRemoved = 0;
138 for (MachineInstr &MI : make_early_inc_range(Range&: MBB)) {
139 switch (MI.getOpcode()) {
140 case AArch64::MSRpstatesvcrImm1:
141 case AArch64::MSRpstatePseudo: {
142 if (ChangesStreamingMode(MI: &MI))
143 NumSMChanges++;
144
145 if (!Prev)
146 Prev = &MI;
147 else if (isMatchingStartStopPair(MI1: Prev, MI2: &MI)) {
148 // If they match, we can remove them, and possibly any instructions
149 // that we marked for deletion in between.
150 Prev->eraseFromParent();
151 MI.eraseFromParent();
152 Prev = nullptr;
153 Changed = true;
154 NumSMChangesRemoved += 2;
155 } else {
156 Prev = &MI;
157 }
158 continue;
159 }
160 default:
161 if (!Prev)
162 // Avoid doing expensive checks when Prev is nullptr.
163 continue;
164 break;
165 }
166
167 // Test if the instructions in between the start/stop sequence are agnostic
168 // of streaming mode. If not, the algorithm should reset.
169 switch (MI.getOpcode()) {
170 default:
171 Prev = nullptr;
172 break;
173 case AArch64::COALESCER_BARRIER_FPR16:
174 case AArch64::COALESCER_BARRIER_FPR32:
175 case AArch64::COALESCER_BARRIER_FPR64:
176 case AArch64::COALESCER_BARRIER_FPR128:
177 case AArch64::COPY:
178 // These instructions should be safe when executed on their own, but
179 // the code remains conservative when SVE registers are used. There may
180 // exist subtle cases where executing a COPY in a different mode results
181 // in different behaviour, even if we can't yet come up with any
182 // concrete example/test-case.
183 if (isSVERegOp(TRI, MRI, MO: MI.getOperand(i: 0)) ||
184 isSVERegOp(TRI, MRI, MO: MI.getOperand(i: 1)))
185 Prev = nullptr;
186 break;
187 case AArch64::RestoreZAPseudo:
188 case AArch64::InOutZAUsePseudo:
189 case AArch64::CommitZASavePseudo:
190 case AArch64::SMEStateAllocPseudo:
191 case AArch64::RequiresZASavePseudo:
192 // These instructions only depend on the ZA state, not the streaming mode,
193 // so if the pair of smstart/stop is only changing the streaming mode, we
194 // can permit these instructions.
195 if (Prev->getOperand(i: 0).getImm() != AArch64SVCR::SVCRSM)
196 Prev = nullptr;
197 break;
198 case AArch64::ADJCALLSTACKDOWN:
199 case AArch64::ADJCALLSTACKUP:
200 case AArch64::ANDXri:
201 case AArch64::ADDXri:
202 // We permit these as they don't generate SVE/NEON instructions.
203 break;
204 case AArch64::MSRpstatesvcrImm1:
205 case AArch64::MSRpstatePseudo:
206 llvm_unreachable("Should have been handled");
207 }
208 }
209
210 HasRemovedAllSMChanges =
211 NumSMChanges && (NumSMChanges == NumSMChangesRemoved);
212 return Changed;
213}
214
215// Using the FORM_TRANSPOSED_REG_TUPLE pseudo can improve register allocation
216// of multi-vector intrinsics. However, the pseudo should only be emitted if
217// the input registers of the REG_SEQUENCE are copy nodes where the source
218// register is in a StridedOrContiguous class. For example:
219//
220// %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
221// %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
222// %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
223// %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
224// %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
225// %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
226// %9:zpr2mul2 = REG_SEQUENCE %5:zpr, %subreg.zsub0, %8:zpr, %subreg.zsub1
227//
228// -> %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
229//
230bool SMEPeepholeOpt::visitRegSequence(MachineInstr &MI) {
231 assert(MI.getMF()->getRegInfo().isSSA() && "Expected to be run on SSA form!");
232
233 MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
234 switch (MRI.getRegClass(Reg: MI.getOperand(i: 0).getReg())->getID()) {
235 case AArch64::ZPR2RegClassID:
236 case AArch64::ZPR4RegClassID:
237 case AArch64::ZPR2Mul2RegClassID:
238 case AArch64::ZPR4Mul4RegClassID:
239 break;
240 default:
241 return false;
242 }
243
244 // The first operand is the register class created by the REG_SEQUENCE.
245 // Each operand pair after this consists of a vreg + subreg index, so
246 // for example a sequence of 2 registers will have a total of 5 operands.
247 if (MI.getNumOperands() != 5 && MI.getNumOperands() != 9)
248 return false;
249
250 MCRegister SubReg = MCRegister::NoRegister;
251 for (unsigned I = 1; I < MI.getNumOperands(); I += 2) {
252 MachineOperand &MO = MI.getOperand(i: I);
253
254 MachineOperand *Def = MRI.getOneDef(Reg: MO.getReg());
255 if (!Def || !Def->getParent()->isCopy())
256 return false;
257
258 const MachineOperand &CopySrc = Def->getParent()->getOperand(i: 1);
259 unsigned OpSubReg = CopySrc.getSubReg();
260 if (SubReg == MCRegister::NoRegister)
261 SubReg = OpSubReg;
262
263 MachineOperand *CopySrcOp = MRI.getOneDef(Reg: CopySrc.getReg());
264 if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
265 CopySrcOp->getReg().isPhysical())
266 return false;
267
268 const TargetRegisterClass *CopySrcClass =
269 MRI.getRegClass(Reg: CopySrcOp->getReg());
270 if (CopySrcClass != &AArch64::ZPR2StridedOrContiguousRegClass &&
271 CopySrcClass != &AArch64::ZPR4StridedOrContiguousRegClass)
272 return false;
273 }
274
275 unsigned Opc = MI.getNumOperands() == 5
276 ? AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO
277 : AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO;
278
279 const TargetInstrInfo *TII =
280 MI.getMF()->getSubtarget<AArch64Subtarget>().getInstrInfo();
281 MachineInstrBuilder MIB = BuildMI(BB&: *MI.getParent(), I&: MI, MIMD: MI.getDebugLoc(),
282 MCID: TII->get(Opcode: Opc), DestReg: MI.getOperand(i: 0).getReg());
283 for (unsigned I = 1; I < MI.getNumOperands(); I += 2)
284 MIB.addReg(RegNo: MI.getOperand(i: I).getReg());
285
286 MI.eraseFromParent();
287 return true;
288}
289
290INITIALIZE_PASS(SMEPeepholeOpt, "aarch64-sme-peephole-opt",
291 "SME Peephole Optimization", false, false)
292
293bool SMEPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
294 if (skipFunction(F: MF.getFunction()))
295 return false;
296
297 if (!MF.getSubtarget<AArch64Subtarget>().hasSME())
298 return false;
299
300 assert(MF.getRegInfo().isSSA() && "Expected to be run on SSA form!");
301
302 bool Changed = false;
303 bool FunctionHasAllSMChangesRemoved = false;
304
305 // Even if the block lives in a function with no SME attributes attached we
306 // still have to analyze all the blocks because we may call a streaming
307 // function that requires smstart/smstop pairs.
308 for (MachineBasicBlock &MBB : MF) {
309 bool BlockHasAllSMChangesRemoved;
310 Changed |= optimizeStartStopPairs(MBB, HasRemovedAllSMChanges&: BlockHasAllSMChangesRemoved);
311 FunctionHasAllSMChangesRemoved |= BlockHasAllSMChangesRemoved;
312
313 if (MF.getSubtarget<AArch64Subtarget>().isStreaming()) {
314 for (MachineInstr &MI : make_early_inc_range(Range&: MBB))
315 if (MI.getOpcode() == AArch64::REG_SEQUENCE)
316 Changed |= visitRegSequence(MI);
317 }
318 }
319
320 AArch64FunctionInfo *AFI = MF.getInfo<AArch64FunctionInfo>();
321 if (FunctionHasAllSMChangesRemoved)
322 AFI->setHasStreamingModeChanges(false);
323
324 return Changed;
325}
326
327FunctionPass *llvm::createSMEPeepholeOptPass() { return new SMEPeepholeOpt(); }
328