| 1 | //===-- AMDGPURewriteAGPRCopyMFMA.cpp -------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | /// \file \brief Try to replace MFMA instructions using VGPRs with MFMA |
| 10 | /// instructions using AGPRs. We expect MFMAs to be selected using VGPRs, and |
| 11 | /// only use AGPRs if it helps avoid spilling. In this case, the MFMA will have |
| 12 | /// copies between AGPRs and VGPRs and the AGPR variant of an MFMA pseudo. This |
| 13 | /// pass will attempt to delete the cross register bank copy and replace the |
| 14 | /// MFMA opcode. |
| 15 | /// |
| 16 | /// TODO: |
| 17 | /// - Handle non-tied dst+src2 cases. We need to try to find a copy from an |
| 18 | /// AGPR from src2, or reassign src2 to an available AGPR (which should work |
| 19 | /// in the common case of a load). |
| 20 | /// |
| 21 | /// - Handle multiple MFMA uses of the same register. e.g. chained MFMAs that |
| 22 | /// can be rewritten as a set |
| 23 | /// |
| 24 | /// - Update LiveIntervals incrementally instead of recomputing from scratch |
| 25 | /// |
| 26 | //===----------------------------------------------------------------------===// |
| 27 | |
#include "AMDGPU.h"
#include "GCNSubtarget.h"
#include "SIMachineFunctionInfo.h"
#include "SIRegisterInfo.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/LiveIntervals.h"
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/SlotIndexes.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
| 37 | |
| 38 | using namespace llvm; |
| 39 | |
| 40 | #define DEBUG_TYPE "amdgpu-rewrite-agpr-copy-mfma" |
| 41 | |
| 42 | namespace { |
| 43 | |
| 44 | class AMDGPURewriteAGPRCopyMFMAImpl { |
| 45 | const GCNSubtarget &ST; |
| 46 | const SIInstrInfo &TII; |
| 47 | const SIRegisterInfo &TRI; |
| 48 | MachineRegisterInfo &MRI; |
| 49 | VirtRegMap &VRM; |
| 50 | LiveRegMatrix ‎ |
| 51 | LiveIntervals &LIS; |
| 52 | |
| 53 | public: |
| 54 | AMDGPURewriteAGPRCopyMFMAImpl(MachineFunction &MF, VirtRegMap &VRM, |
| 55 | LiveRegMatrix &LRM, LiveIntervals &LIS) |
| 56 | : ST(MF.getSubtarget<GCNSubtarget>()), TII(*ST.getInstrInfo()), |
| 57 | TRI(*ST.getRegisterInfo()), MRI(MF.getRegInfo()), VRM(VRM), LRM(LRM), |
| 58 | LIS(LIS) {} |
| 59 | |
| 60 | /// Compute the register class constraints based on the uses of \p Reg, |
| 61 | /// excluding uses from \p ExceptMI. This should be nearly identical to |
| 62 | /// MachineRegisterInfo::recomputeRegClass. |
| 63 | const TargetRegisterClass * |
| 64 | recomputeRegClassExcept(Register Reg, const TargetRegisterClass *OldRC, |
| 65 | const TargetRegisterClass *NewRC, |
| 66 | const MachineInstr *ExceptMI) const; |
| 67 | |
| 68 | bool run(MachineFunction &MF) const; |
| 69 | }; |
| 70 | |
| 71 | const TargetRegisterClass * |
| 72 | AMDGPURewriteAGPRCopyMFMAImpl::recomputeRegClassExcept( |
| 73 | Register Reg, const TargetRegisterClass *OldRC, |
| 74 | const TargetRegisterClass *NewRC, const MachineInstr *ExceptMI) const { |
| 75 | |
| 76 | // Accumulate constraints from all uses. |
| 77 | for (MachineOperand &MO : MRI.reg_nodbg_operands(Reg)) { |
| 78 | // Apply the effect of the given operand to NewRC. |
| 79 | MachineInstr *MI = MO.getParent(); |
| 80 | if (MI == ExceptMI) |
| 81 | continue; |
| 82 | |
| 83 | unsigned OpNo = &MO - &MI->getOperand(i: 0); |
| 84 | NewRC = MI->getRegClassConstraintEffect(OpIdx: OpNo, CurRC: NewRC, TII: &TII, TRI: &TRI); |
| 85 | if (!NewRC || NewRC == OldRC) |
| 86 | return nullptr; |
| 87 | } |
| 88 | |
| 89 | return NewRC; |
| 90 | } |
| 91 | |
| 92 | bool AMDGPURewriteAGPRCopyMFMAImpl::run(MachineFunction &MF) const { |
| 93 | // This only applies on subtargets that have a configurable AGPR vs. VGPR |
| 94 | // allocation. |
| 95 | if (!ST.hasGFX90AInsts()) |
| 96 | return false; |
| 97 | |
| 98 | // Early exit if no AGPRs were assigned. |
| 99 | if (!LRM.isPhysRegUsed(PhysReg: AMDGPU::AGPR0)) |
| 100 | return false; |
| 101 | |
| 102 | bool MadeChange = false; |
| 103 | |
| 104 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
| 105 | Register VReg = Register::index2VirtReg(Index: I); |
| 106 | Register PhysReg = VRM.getPhys(virtReg: VReg); |
| 107 | if (!PhysReg) |
| 108 | continue; |
| 109 | |
| 110 | // Find AV_* registers assigned to AGPRs. |
| 111 | const TargetRegisterClass *VirtRegRC = MRI.getRegClass(Reg: VReg); |
| 112 | if (!TRI.isVectorSuperClass(RC: VirtRegRC)) |
| 113 | continue; |
| 114 | |
| 115 | const TargetRegisterClass *AssignedRC = TRI.getPhysRegBaseClass(Reg: PhysReg); |
| 116 | if (!TRI.isAGPRClass(RC: AssignedRC)) |
| 117 | continue; |
| 118 | |
| 119 | LiveInterval &LI = LIS.getInterval(Reg: VReg); |
| 120 | |
| 121 | // TODO: Test multiple uses |
| 122 | for (VNInfo *VNI : LI.vnis()) { |
| 123 | MachineInstr *DefMI = LIS.getInstructionFromIndex(index: VNI->def); |
| 124 | |
| 125 | // TODO: Handle SplitKit produced copy bundles for partially defined |
| 126 | // registers. |
| 127 | if (!DefMI || !DefMI->isFullCopy()) |
| 128 | continue; |
| 129 | |
| 130 | Register CopySrcReg = DefMI->getOperand(i: 1).getReg(); |
| 131 | if (!CopySrcReg.isVirtual()) |
| 132 | continue; |
| 133 | |
| 134 | LiveInterval &CopySrcLI = LIS.getInterval(Reg: CopySrcReg); |
| 135 | LiveQueryResult LRQ = CopySrcLI.Query(Idx: VNI->def.getRegSlot()); |
| 136 | MachineInstr *CopySrcMI = LIS.getInstructionFromIndex(index: LRQ.valueIn()->def); |
| 137 | if (!CopySrcMI) |
| 138 | continue; |
| 139 | |
| 140 | int AGPROp = AMDGPU::getMFMASrcCVDstAGPROp(Opcode: CopySrcMI->getOpcode()); |
| 141 | if (AGPROp == -1) |
| 142 | continue; |
| 143 | |
| 144 | MachineOperand *Src2 = |
| 145 | TII.getNamedOperand(MI&: *CopySrcMI, OperandName: AMDGPU::OpName::src2); |
| 146 | |
| 147 | // FIXME: getMinimalPhysRegClass returns a nonsense AV_* subclass instead |
| 148 | // of an AGPR or VGPR subclass, so we can't simply use the result on the |
| 149 | // assignment. |
| 150 | |
| 151 | LLVM_DEBUG({ |
| 152 | Register Src2PhysReg = VRM.getPhys(Src2->getReg()); |
| 153 | dbgs() << "Attempting to replace VGPR MFMA with AGPR version:" |
| 154 | << " Dst=[" << printReg(VReg) << " => " |
| 155 | << printReg(PhysReg, &TRI) << "], Src2=[" |
| 156 | << printReg(Src2->getReg(), &TRI) << " => " |
| 157 | << printReg(Src2PhysReg, &TRI) << "]: " << *CopySrcMI; |
| 158 | }); |
| 159 | |
| 160 | // If the inputs are tied and the same register, we can shortcut and |
| 161 | // directly replace the register. |
| 162 | if (Src2->getReg() != CopySrcReg) { |
| 163 | LLVM_DEBUG( |
| 164 | dbgs() |
| 165 | << "Replacing untied VGPR MFMAs with AGPR form not yet handled\n" ); |
| 166 | // TODO: Only handles the tied case for now. If the input operand is a |
| 167 | // different register, we need to also reassign it (either by looking |
| 168 | // for a compatible copy-from-AGPR, or by seeing if an available AGPR is |
| 169 | // compatible with all other uses. |
| 170 | |
| 171 | // If we can't reassign it, we'd need to introduce a different copy |
| 172 | // which is likely worse than the copy we'd be saving. |
| 173 | continue; |
| 174 | } |
| 175 | |
| 176 | const TargetRegisterClass *Src2VirtRegRC = |
| 177 | MRI.getRegClass(Reg: Src2->getReg()); |
| 178 | |
| 179 | // We've found av = COPY (MFMA), and need to verify that we can trivially |
| 180 | // rewrite src2 to use the new AGPR. If we can't trivially replace it, |
| 181 | // we're going to induce as many copies as we would have emitted in the |
| 182 | // first place, as well as need to assign another register, and need to |
| 183 | // figure out where to put them. The live range splitting is smarter than |
| 184 | // anything we're doing here, so trust it did something reasonable. |
| 185 | const TargetRegisterClass *Src2ExceptRC = recomputeRegClassExcept( |
| 186 | Reg: Src2->getReg(), OldRC: Src2VirtRegRC, NewRC: VirtRegRC, ExceptMI: CopySrcMI); |
| 187 | if (!Src2ExceptRC) |
| 188 | continue; |
| 189 | |
| 190 | const TargetRegisterClass *NewSrc2ConstraintRC = |
| 191 | TII.getRegClass(TID: TII.get(Opcode: AGPROp), OpNum: Src2->getOperandNo(), TRI: &TRI, MF); |
| 192 | |
| 193 | // Try to constrain src2 to the replacement instruction candidate's |
| 194 | // register class. |
| 195 | const TargetRegisterClass *NewSrc2RC = |
| 196 | TRI.getCommonSubClass(A: Src2ExceptRC, B: NewSrc2ConstraintRC); |
| 197 | if (!NewSrc2RC) { |
| 198 | // TODO: This is ignoring ther rewritable uses. e.g. a rewritable MFMA |
| 199 | // using a rewritable MFMA can be rewritten as a pair. |
| 200 | LLVM_DEBUG(dbgs() << "Other uses of " << printReg(Src2->getReg(), &TRI) |
| 201 | << " are incompatible with replacement class\n" ); |
| 202 | continue; |
| 203 | } |
| 204 | |
| 205 | MRI.setRegClass(Reg: VReg, RC: AssignedRC); |
| 206 | MRI.setRegClass(Reg: Src2->getReg(), RC: NewSrc2RC); |
| 207 | |
| 208 | CopySrcMI->setDesc(TII.get(Opcode: AGPROp)); |
| 209 | |
| 210 | // TODO: Is replacing too aggressive, fixup these instructions only? |
| 211 | MRI.replaceRegWith(FromReg: CopySrcReg, ToReg: VReg); |
| 212 | |
| 213 | LLVM_DEBUG(dbgs() << "Replaced VGPR MFMA with AGPR: " << *CopySrcMI); |
| 214 | |
| 215 | // We left behind an identity copy, so delete it. |
| 216 | LIS.RemoveMachineInstrFromMaps(MI&: *DefMI); |
| 217 | DefMI->eraseFromParent(); |
| 218 | |
| 219 | LRM.unassign(VirtReg: CopySrcLI); |
| 220 | |
| 221 | // We don't need the liveness information anymore, so don't bother |
| 222 | // updating the intervals. Just delete the stale information. |
| 223 | // TODO: Is it worth preserving these? |
| 224 | LIS.removeInterval(Reg: CopySrcReg); |
| 225 | LIS.removeInterval(Reg: VReg); |
| 226 | LIS.createAndComputeVirtRegInterval(Reg: VReg); |
| 227 | |
| 228 | MadeChange = true; |
| 229 | } |
| 230 | } |
| 231 | |
| 232 | return MadeChange; |
| 233 | } |
| 234 | |
| 235 | class AMDGPURewriteAGPRCopyMFMALegacy : public MachineFunctionPass { |
| 236 | public: |
| 237 | static char ID; |
| 238 | |
| 239 | AMDGPURewriteAGPRCopyMFMALegacy() : MachineFunctionPass(ID) { |
| 240 | initializeAMDGPURewriteAGPRCopyMFMALegacyPass( |
| 241 | *PassRegistry::getPassRegistry()); |
| 242 | } |
| 243 | |
| 244 | bool runOnMachineFunction(MachineFunction &MF) override; |
| 245 | |
| 246 | StringRef getPassName() const override { |
| 247 | return "AMDGPU Rewrite AGPR-Copy-MFMA" ; |
| 248 | } |
| 249 | |
| 250 | void getAnalysisUsage(AnalysisUsage &AU) const override { |
| 251 | AU.addRequired<LiveIntervalsWrapperPass>(); |
| 252 | AU.addRequired<VirtRegMapWrapperLegacy>(); |
| 253 | AU.addRequired<LiveRegMatrixWrapperLegacy>(); |
| 254 | |
| 255 | AU.addPreserved<LiveIntervalsWrapperPass>(); |
| 256 | AU.addPreserved<VirtRegMapWrapperLegacy>(); |
| 257 | AU.addPreserved<LiveRegMatrixWrapperLegacy>(); |
| 258 | AU.setPreservesAll(); |
| 259 | MachineFunctionPass::getAnalysisUsage(AU); |
| 260 | } |
| 261 | }; |
| 262 | |
| 263 | } // End anonymous namespace. |
| 264 | |
| 265 | INITIALIZE_PASS_BEGIN(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE, |
| 266 | "AMDGPU Rewrite AGPR-Copy-MFMA" , false, false) |
| 267 | INITIALIZE_PASS_DEPENDENCY(LiveIntervalsWrapperPass) |
| 268 | INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy) |
| 269 | INITIALIZE_PASS_DEPENDENCY(LiveRegMatrixWrapperLegacy) |
| 270 | INITIALIZE_PASS_END(AMDGPURewriteAGPRCopyMFMALegacy, DEBUG_TYPE, |
| 271 | "AMDGPU Rewrite AGPR-Copy-MFMA" , false, false) |
| 272 | |
| 273 | char AMDGPURewriteAGPRCopyMFMALegacy::ID = 0; |
| 274 | |
| 275 | char &llvm::AMDGPURewriteAGPRCopyMFMALegacyID = |
| 276 | AMDGPURewriteAGPRCopyMFMALegacy::ID; |
| 277 | |
| 278 | bool AMDGPURewriteAGPRCopyMFMALegacy::runOnMachineFunction( |
| 279 | MachineFunction &MF) { |
| 280 | if (skipFunction(F: MF.getFunction())) |
| 281 | return false; |
| 282 | |
| 283 | auto &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM(); |
| 284 | auto &LRM = getAnalysis<LiveRegMatrixWrapperLegacy>().getLRM(); |
| 285 | auto &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); |
| 286 | |
| 287 | AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS); |
| 288 | return Impl.run(MF); |
| 289 | } |
| 290 | |
| 291 | PreservedAnalyses |
| 292 | AMDGPURewriteAGPRCopyMFMAPass::run(MachineFunction &MF, |
| 293 | MachineFunctionAnalysisManager &MFAM) { |
| 294 | VirtRegMap &VRM = MFAM.getResult<VirtRegMapAnalysis>(IR&: MF); |
| 295 | LiveRegMatrix &LRM = MFAM.getResult<LiveRegMatrixAnalysis>(IR&: MF); |
| 296 | LiveIntervals &LIS = MFAM.getResult<LiveIntervalsAnalysis>(IR&: MF); |
| 297 | |
| 298 | AMDGPURewriteAGPRCopyMFMAImpl Impl(MF, VRM, LRM, LIS); |
| 299 | if (!Impl.run(MF)) |
| 300 | return PreservedAnalyses::all(); |
| 301 | auto PA = getMachineFunctionPassPreservedAnalyses(); |
| 302 | PA.preserveSet<CFGAnalyses>(); |
| 303 | return PA; |
| 304 | } |
| 305 | |