//===-- AMDGPURewriteAGPRCopyMFMA.cpp -------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file \brief Try to replace MFMA instructions using VGPRs with MFMA
/// instructions using AGPRs. We expect MFMAs to be selected using VGPRs, and
/// only use AGPRs if it helps avoid spilling. In this case, the MFMA will have
/// copies between AGPRs and VGPRs and the AGPR variant of an MFMA pseudo. This
/// pass will attempt to delete the cross register bank copy and replace the
/// MFMA opcode.
///
/// TODO:
///  - Handle non-tied dst+src2 cases. We need to try to find a copy from an
///    AGPR from src2, or reassign src2 to an available AGPR (which should work
///    in the common case of a load).
///
///  - Handle multiple MFMA uses of the same register, e.g. chained MFMAs that
///    can be rewritten as a set.
///
///  - Update LiveIntervals incrementally instead of recomputing from scratch.
///
//===----------------------------------------------------------------------===//
27
#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "SIMachineFunctionInfo.h"
#include "SIRegisterInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
37
38using namespace llvm;
39
40#define DEBUG_TYPE "amdgpu-rewrite-agpr-copy-mfma"
41
42namespace {
43
44class AMDGPURewriteAGPRCopyMFMAImpl {
45 const GCNSubtarget &ST;
46 const SIInstrInfo &TII;
47 const SIRegisterInfo &TRI;
48 MachineRegisterInfo &MRI;
49 VirtRegMap &VRM;
50 LiveRegMatrix ‎
51 LiveIntervals &LIS;
52
53public:
54 AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM,
55 LiveRegMatrix &LRM, LiveIntervals &LIS)
56 : ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()),
57 TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM),
58 LIS(LIS) {}
59
60 /// Compute the register class constraints based on the uses of \p Reg,
61 /// excluding uses from \p ExceptMI. This should be nearly identical to
62 /// MachineRegisterInfo::recomputeRegClass.
63 const TargetRegisterClass *
64 recomputeRegClassExcept(Register Reg, const TargetRegisterClass *OldRC,
65 const TargetRegisterClass *NewRC,
66 const MachineInstr *ExceptMI) const;
67
68 bool run(MachineFunction &MF) const;
69};
70
71const TargetRegisterClass *
72AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExcept(
73 Register Reg, const TargetRegisterClass *OldRC,
74 const TargetRegisterClass *NewRC, const MachineInstr *ExceptMI) const {
75
76 // Accumulate constraints from all uses.
77 for (MachineOperand &MO : MRI.reg_nodbg_operands(Reg)) {
78 // Apply the effect of the given operand to NewRC.
79 MachineInstr *MI = MO.getParent();
80 if (MI == ExceptMI)
81 continue;
82
83 unsigned OpNo = &MO - &MI->getOperand(i: 0);
84 NewRC = MI->getRegClassConstraintEffect(OpIdx: OpNo, CurRC: NewRC, TII: &TII, TRI: &TRI);
85 if (!NewRC || NewRC == OldRC)
86 return nullptr;
87 }
88
89 return NewRC;
90}
91
92bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const {
93 // This only applies on subtargets that have a configurable AGPR vs. VGPR
94 // allocation.
95 if (!ST.hasGFX90AInsts())
96 return false;
97
98 // Early exit if no AGPRs were assigned.
99 if (!LRM.isPhysRegUsed(PhysReg: AMDGPU::AGPR0))
100 return false;
101
102 bool MadeChange = false;
103
104 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
105 Register VReg = Register::index2VirtReg(Index: I);
106 Register PhysReg = VRM.getPhys(virtReg: VReg);
107 if (!PhysReg)
108 continue;
109
110 // Find AV_* registers assigned to AGPRs.
111 const TargetRegisterClass *VirtRegRC = MRI.getRegClass(Reg: VReg);
112 if (!TRI.isVectorSuperClass(RC: VirtRegRC))
113 continue;
114
115 const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(Reg: PhysReg);
116 if (!TRI.isAGPRClass(RC: AssignedRC))
117 continue;
118
119 LiveInterval &LI = LIS.getInterval(Reg: VReg);
120
121 // TODO: Test multiple uses
122 for (VNInfo *VNI : LI.vnis()) {
123 MachineInstr *DefMI = LIS.getInstructionFromIndex(index: VNI->def);
124
125 // TODO: Handle SplitKit produced copy bundles for partially defined
126 // registers.
127 if (!DefMI || !DefMI->isFullCopy())
128 continue;
129
130 Register CopySrcReg = DefMI->getOperand(i: 1).getReg();
131 if (!CopySrcReg.isVirtual())
132 continue;
133
134 LiveInterval &CopySrcLI = LIS.getInterval(Reg: CopySrcReg);
135 LiveQueryResult LRQ = CopySrcLI.Query(Idx: VNI->def.getRegSlot());
136 MachineInstr *CopySrcMI = LIS.getInstructionFromIndex(index: LRQ.valueIn()->def);
137 if (!CopySrcMI)
138 continue;
139
140 int AGPROp = AMDGPU::getMFMASrcCVDstAGPROp(Opcode: CopySrcMI->getOpcode());
141 if (AGPROp == -1)
142 continue;
143
144 MachineOperand *Src2 =
145 TII.getNamedOperand(MI&: *CopySrcMI, OperandName: AMDGPU::OpName::src2);
146
147 // FIXME: getMinimalPhysRegClass returns a nonsense AV_* subclass instead
148 // of an AGPR or VGPR subclass, so we can't simply use the result on the
149 // assignment.
150
151 LLVM_DEBUG({
152 Register Src2PhysReg = VRM.getPhys(Src2->getReg());
153 dbgs() << "Attempting to replace VGPR MFMA with AGPR version:"
154 << " Dst=[" << printReg(VReg) << " => "
155 << printReg(PhysReg, &TRI) << "], Src2=["
156 << printReg(Src2->getReg(), &TRI) << " => "
157 << printReg(Src2PhysReg, &TRI) << "]: " << *CopySrcMI;
158 });
159
160 // If the inputs are tied and the same register, we can shortcut and
161 // directly replace the register.
162 if (Src2->getReg() != CopySrcReg) {
163 LLVM_DEBUG(
164 dbgs()
165 << "Replacing untied VGPR MFMAs with AGPR form not yet handled\n");
166 // TODO: Only handles the tied case for now. If the input operand is a
167 // different register, we need to also reassign it (either by looking
168 // for a compatible copy-from-AGPR, or by seeing if an available AGPR is
169 // compatible with all other uses.
170
171 // If we can't reassign it, we'd need to introduce a different copy
172 // which is likely worse than the copy we'd be saving.
173 continue;
174 }
175
176 const TargetRegisterClass *Src2VirtRegRC =
177 MRI.getRegClass(Reg: Src2->getReg());
178
179 // We've found av = COPY (MFMA), and need to verify that we can trivially
180 // rewrite src2 to use the new AGPR. If we can't trivially replace it,
181 // we're going to induce as many copies as we would have emitted in the
182 // first place, as well as need to assign another register, and need to
183 // figure out where to put them. The live range splitting is smarter than
184 // anything we're doing here, so trust it did something reasonable.
185 const TargetRegisterClass *Src2ExceptRC = recomputeRegClassExcept(
186 Reg: Src2->getReg(), OldRC: Src2VirtRegRC, NewRC: VirtRegRC, ExceptMI: CopySrcMI);
187 if (!Src2ExceptRC)
188 continue;
189
190 const TargetRegisterClass *NewSrc2ConstraintRC =
191 TII.getRegClass(TID: TII.get(Opcode: AGPROp), OpNum: Src2->getOperandNo(), TRI: &TRI, MF);
192
193 // Try to constrain src2 to the replacement instruction candidate's
194 // register class.
195 const TargetRegisterClass *NewSrc2RC =
196 TRI.getCommonSubClass(A: Src2ExceptRC, B: NewSrc2ConstraintRC);
197 if (!NewSrc2RC) {
198 // TODO: This is ignoring ther rewritable uses. e.g. a rewritable MFMA
199 // using a rewritable MFMA can be rewritten as a pair.
200 LLVM_DEBUG(dbgs() << "Other uses of " << printReg(Src2->getReg(), &TRI)
201 << " are incompatible with replacement class\n");
202 continue;
203 }
204
205 MRI.setRegClass(Reg: VReg, RC: AssignedRC);
206 MRI.setRegClass(Reg: Src2->getReg(), RC: NewSrc2RC);
207
208 CopySrcMI->setDesc(TII.get(Opcode: AGPROp));
209
210 // TODO: Is replacing too aggressive, fixup these instructions only?
211 MRI.replaceRegWith(FromReg: CopySrcReg, ToReg: VReg);
212
213 LLVM_DEBUG(dbgs() << "Replaced VGPR MFMA with AGPR: " << *CopySrcMI);
214
215 // We left behind an identity copy, so delete it.
216 LIS.RemoveMachineInstrFromMaps(MI&: *DefMI);
217 DefMI->eraseFromParent();
218
219 LRM.unassign(VirtReg: CopySrcLI);
220
221 // We don't need the liveness information anymore, so don't bother
222 // updating the intervals. Just delete the stale information.
223 // TODO: Is it worth preserving these?
224 LIS.removeInterval(Reg: CopySrcReg);
225 LIS.removeInterval(Reg: VReg);
226 LIS.createAndComputeVirtRegInterval(Reg: VReg);
227
228 MadeChange = true;
229 }
230 }
231
232 return MadeChange;
233}
234
235class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass {
236public:
237 static char ID;
238
239 AMDGPURewriteAGPRCopyMFMALegacy() : MachineFunctionPass(ID) {
240 initializeAMDGPURewriteAGPRCopyMFMALegacyPass(
241 *PassRegistry::getPassRegistry());
242 }
243
244 bool runOnMachineFunction(MachineFunction &MF) override;
245
246 StringRef getPassName() const override {
247 return "AMDGPU Rewrite AGPR-Copy-MFMA";
248 }
249
250 void getAnalysisUsage(AnalysisUsage &AU) const override {
251 AU.addRequired<LiveIntervalsWrapperPass>();
252 AU.addRequired<VirtRegMapWrapperLegacy>();
253 AU.addRequired<LiveRegMatrixWrapperLegacy>();
254
255 AU.addPreserved<LiveIntervalsWrapperPass>();
256 AU.addPreserved<VirtRegMapWrapperLegacy>();
257 AU.addPreserved<LiveRegMatrixWrapperLegacy>();
258 AU.setPreservesAll();
259 MachineFunctionPass::getAnalysisUsage(AU);
260 }
261};
262
263} // End anonymous namespace.
264
265INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
266 "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
267INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass)
268INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy)
269INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy)
270INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE,
271 "AMDGPU Rewrite AGPR-Copy-MFMA", false, false)
272
273char AMDGPURewriteAGPRCopyMFMALegacy::ID = 0;
274
275char &llvm::AMDGPURewriteAGPRCopyMFMALegacyID =
276 AMDGPURewriteAGPRCopyMFMALegacy::ID;
277
278bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction(
279 MachineFunction &MF) {
280 if (skipFunction(F: MF.getFunction()))
281 return false;
282
283 auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM();
284 auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM();
285 auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
286
287 AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS);
288 return Impl.run(MF);
289}
290
291PreservedAnalyses
292AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF,
293 MachineFunctionAnalysisManager &MFAM) {
294 VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(IR&: MF);
295 LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(IR&: MF);
296 LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(IR&: MF);
297
298 AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS);
299 if (!Impl.run(MF))
300 return PreservedAnalyses::all();
301 auto PA = getMachineFunctionPassPreservedAnalyses();
302 PA.preserveSet<CFGAnalyses>();
303 return PA;
304}
305