| 1 | //===- GCNRegPressure.cpp -------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | /// |
| 9 | /// \file |
| 10 | /// This file implements the GCNRegPressure class. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "GCNRegPressure.h" |
| 15 | #include "AMDGPU.h" |
| 16 | #include "SIMachineFunctionInfo.h" |
| 17 | #include "llvm/CodeGen/MachineBasicBlock.h" |
| 18 | #include "llvm/CodeGen/MachineLoopInfo.h" |
| 19 | #include "llvm/CodeGen/RegisterPressure.h" |
| 20 | |
| 21 | using namespace llvm; |
| 22 | |
| 23 | #define DEBUG_TYPE "machine-scheduler" |
| 24 | |
| 25 | bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, |
| 26 | const GCNRPTracker::LiveRegSet &S2) { |
| 27 | if (S1.size() != S2.size()) |
| 28 | return false; |
| 29 | |
| 30 | for (const auto &P : S1) { |
| 31 | auto I = S2.find(Val: P.first); |
| 32 | if (I == S2.end() || I->second != P.second) |
| 33 | return false; |
| 34 | } |
| 35 | return true; |
| 36 | } |
| 37 | |
| 38 | /////////////////////////////////////////////////////////////////////////////// |
| 39 | // GCNRegPressure |
| 40 | |
| 41 | unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC, |
| 42 | const SIRegisterInfo *STI) { |
| 43 | return STI->isSGPRClass(RC) |
| 44 | ? SGPR |
| 45 | : (STI->isAGPRClass(RC) |
| 46 | ? AGPR |
| 47 | : (STI->isVectorSuperClass(RC) ? AVGPR : VGPR)); |
| 48 | } |
| 49 | |
| 50 | void GCNRegPressure::inc(unsigned Reg, |
| 51 | LaneBitmask PrevMask, |
| 52 | LaneBitmask NewMask, |
| 53 | const MachineRegisterInfo &MRI) { |
| 54 | unsigned NewNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(LM: NewMask); |
| 55 | unsigned PrevNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(LM: PrevMask); |
| 56 | if (NewNumCoveredRegs == PrevNumCoveredRegs) |
| 57 | return; |
| 58 | |
| 59 | int Sign = 1; |
| 60 | if (NewMask < PrevMask) { |
| 61 | std::swap(a&: NewMask, b&: PrevMask); |
| 62 | std::swap(a&: NewNumCoveredRegs, b&: PrevNumCoveredRegs); |
| 63 | Sign = -1; |
| 64 | } |
| 65 | assert(PrevMask < NewMask && PrevNumCoveredRegs < NewNumCoveredRegs && |
| 66 | "prev mask should always be lesser than new" ); |
| 67 | |
| 68 | const TargetRegisterClass *RC = MRI.getRegClass(Reg); |
| 69 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 70 | const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI); |
| 71 | unsigned RegKind = getRegKind(RC, STI); |
| 72 | if (TRI->getRegSizeInBits(RC: *RC) != 32) { |
| 73 | // Reg is from a tuple register class. |
| 74 | if (PrevMask.none()) { |
| 75 | unsigned TupleIdx = TOTAL_KINDS + RegKind; |
| 76 | Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight; |
| 77 | } |
| 78 | // Pressure scales with number of new registers covered by the new mask. |
| 79 | // Note when true16 is enabled, we can no longer safely use the following |
| 80 | // approach to calculate the difference in the number of 32-bit registers |
| 81 | // between two masks: |
| 82 | // |
| 83 | // Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); |
| 84 | // |
| 85 | // The issue is that the mask calculation `~PrevMask & NewMask` doesn't |
| 86 | // properly account for partial usage of a 32-bit register when dealing with |
| 87 | // 16-bit registers. |
| 88 | // |
| 89 | // Consider this example: |
| 90 | // Assume PrevMask = 0b0010 and NewMask = 0b1111. Here, the correct register |
| 91 | // usage difference should be 1, because even though PrevMask uses only half |
| 92 | // of a 32-bit register, it should still be counted as a full register use. |
| 93 | // However, the mask calculation yields `~PrevMask & NewMask = 0b1101`, and |
| 94 | // calling `getNumCoveredRegs` returns 2 instead of 1. This incorrect |
| 95 | // calculation can lead to integer overflow when Sign = -1. |
| 96 | Sign *= NewNumCoveredRegs - PrevNumCoveredRegs; |
| 97 | } |
| 98 | Value[RegKind] += Sign; |
| 99 | } |
| 100 | |
| 101 | namespace { |
| 102 | struct RegExcess { |
| 103 | unsigned SGPR = 0; |
| 104 | unsigned VGPR = 0; |
| 105 | unsigned ArchVGPR = 0; |
| 106 | unsigned AGPR = 0; |
| 107 | |
| 108 | bool anyExcess() const { return SGPR || VGPR || ArchVGPR || AGPR; } |
| 109 | bool hasVectorRegisterExcess() const { return VGPR || ArchVGPR || AGPR; } |
| 110 | |
| 111 | RegExcess(const MachineFunction &MF, const GCNRegPressure &RP) |
| 112 | : RegExcess(MF, RP, GCNRPTarget(MF, RP)) {} |
| 113 | RegExcess(const MachineFunction &MF, const GCNRegPressure &RP, |
| 114 | const GCNRPTarget &Target) { |
| 115 | unsigned MaxSGPRs = Target.getMaxSGPRs(); |
| 116 | unsigned MaxVGPRs = Target.getMaxVGPRs(); |
| 117 | |
| 118 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 119 | SGPR = std::max(a: static_cast<int>(RP.getSGPRNum() - MaxSGPRs), b: 0); |
| 120 | |
| 121 | // The number of virtual VGPRs required to handle excess SGPR |
| 122 | unsigned WaveSize = ST.getWavefrontSize(); |
| 123 | unsigned VGPRForSGPRSpills = divideCeil(Numerator: SGPR, Denominator: WaveSize); |
| 124 | |
| 125 | unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); |
| 126 | |
| 127 | // Unified excess pressure conditions, accounting for VGPRs used for SGPR |
| 128 | // spills |
| 129 | VGPR = std::max(a: static_cast<int>(RP.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) + |
| 130 | VGPRForSGPRSpills - MaxVGPRs), |
| 131 | b: 0); |
| 132 | |
| 133 | unsigned ArchVGPRLimit = ST.hasGFX90AInsts() ? MaxArchVGPRs : MaxVGPRs; |
| 134 | // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR |
| 135 | // spills |
| 136 | ArchVGPR = std::max(a: static_cast<int>(RP.getArchVGPRNum() + |
| 137 | VGPRForSGPRSpills - ArchVGPRLimit), |
| 138 | b: 0); |
| 139 | |
| 140 | // AGPR excess pressure conditions |
| 141 | AGPR = std::max(a: static_cast<int>(RP.getAGPRNum() - ArchVGPRLimit), b: 0); |
| 142 | } |
| 143 | }; |
| 144 | } // namespace |
| 145 | |
| 146 | bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O, |
| 147 | unsigned MaxOccupancy) const { |
| 148 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 149 | unsigned DynamicVGPRBlockSize = |
| 150 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 151 | |
| 152 | const auto SGPROcc = std::min(a: MaxOccupancy, |
| 153 | b: ST.getOccupancyWithNumSGPRs(SGPRs: getSGPRNum())); |
| 154 | const auto VGPROcc = std::min( |
| 155 | a: MaxOccupancy, b: ST.getOccupancyWithNumVGPRs(VGPRs: getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()), |
| 156 | DynamicVGPRBlockSize)); |
| 157 | const auto OtherSGPROcc = std::min(a: MaxOccupancy, |
| 158 | b: ST.getOccupancyWithNumSGPRs(SGPRs: O.getSGPRNum())); |
| 159 | const auto OtherVGPROcc = |
| 160 | std::min(a: MaxOccupancy, |
| 161 | b: ST.getOccupancyWithNumVGPRs(VGPRs: O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()), |
| 162 | DynamicVGPRBlockSize)); |
| 163 | |
| 164 | const auto Occ = std::min(a: SGPROcc, b: VGPROcc); |
| 165 | const auto OtherOcc = std::min(a: OtherSGPROcc, b: OtherVGPROcc); |
| 166 | |
| 167 | // Give first precedence to the better occupancy. |
| 168 | if (Occ != OtherOcc) |
| 169 | return Occ > OtherOcc; |
| 170 | |
| 171 | unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF); |
| 172 | |
| 173 | RegExcess Excess(MF, *this); |
| 174 | RegExcess OtherExcess(MF, O); |
| 175 | |
| 176 | unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); |
| 177 | |
| 178 | bool ExcessRP = Excess.anyExcess(); |
| 179 | bool OtherExcessRP = OtherExcess.anyExcess(); |
| 180 | |
| 181 | // Give second precedence to the reduced number of spills to hold the register |
| 182 | // pressure. |
| 183 | if (ExcessRP || OtherExcessRP) { |
| 184 | // The difference in excess VGPR pressure, after including VGPRs used for |
| 185 | // SGPR spills |
| 186 | int VGPRDiff = |
| 187 | ((OtherExcess.VGPR + OtherExcess.ArchVGPR + OtherExcess.AGPR) - |
| 188 | (Excess.VGPR + Excess.ArchVGPR + Excess.AGPR)); |
| 189 | |
| 190 | int SGPRDiff = OtherExcess.SGPR - Excess.SGPR; |
| 191 | |
| 192 | if (VGPRDiff != 0) |
| 193 | return VGPRDiff > 0; |
| 194 | if (SGPRDiff != 0) { |
| 195 | unsigned PureExcessVGPR = |
| 196 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs), |
| 197 | b: 0) + |
| 198 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0); |
| 199 | unsigned OtherPureExcessVGPR = |
| 200 | std::max( |
| 201 | a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs), |
| 202 | b: 0) + |
| 203 | std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0); |
| 204 | |
| 205 | // If we have a special case where there is a tie in excess VGPR, but one |
| 206 | // of the pressures has VGPR usage from SGPR spills, prefer the pressure |
| 207 | // with SGPR spills. |
| 208 | if (PureExcessVGPR != OtherPureExcessVGPR) |
| 209 | return SGPRDiff < 0; |
| 210 | // If both pressures have the same excess pressure before and after |
| 211 | // accounting for SGPR spills, prefer fewer SGPR spills. |
| 212 | return SGPRDiff > 0; |
| 213 | } |
| 214 | } |
| 215 | |
| 216 | bool SGPRImportant = SGPROcc < VGPROcc; |
| 217 | const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; |
| 218 | |
| 219 | // If both pressures disagree on what is more important compare vgprs. |
| 220 | if (SGPRImportant != OtherSGPRImportant) { |
| 221 | SGPRImportant = false; |
| 222 | } |
| 223 | |
| 224 | // Give third precedence to lower register tuple pressure. |
| 225 | bool SGPRFirst = SGPRImportant; |
| 226 | for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { |
| 227 | if (SGPRFirst) { |
| 228 | auto SW = getSGPRTuplesWeight(); |
| 229 | auto OtherSW = O.getSGPRTuplesWeight(); |
| 230 | if (SW != OtherSW) |
| 231 | return SW < OtherSW; |
| 232 | } else { |
| 233 | auto VW = getVGPRTuplesWeight(); |
| 234 | auto OtherVW = O.getVGPRTuplesWeight(); |
| 235 | if (VW != OtherVW) |
| 236 | return VW < OtherVW; |
| 237 | } |
| 238 | } |
| 239 | |
| 240 | // Give final precedence to lower general RP. |
| 241 | return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): |
| 242 | (getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) < |
| 243 | O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts())); |
| 244 | } |
| 245 | |
| 246 | Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST, |
| 247 | unsigned DynamicVGPRBlockSize) { |
| 248 | return Printable([&RP, ST, DynamicVGPRBlockSize](raw_ostream &OS) { |
| 249 | OS << "VGPRs: " << RP.getArchVGPRNum() << ' ' |
| 250 | << "AGPRs: " << RP.getAGPRNum(); |
| 251 | if (ST) |
| 252 | OS << "(O" |
| 253 | << ST->getOccupancyWithNumVGPRs(VGPRs: RP.getVGPRNum(UnifiedVGPRFile: ST->hasGFX90AInsts()), |
| 254 | DynamicVGPRBlockSize) |
| 255 | << ')'; |
| 256 | OS << ", SGPRs: " << RP.getSGPRNum(); |
| 257 | if (ST) |
| 258 | OS << "(O" << ST->getOccupancyWithNumSGPRs(SGPRs: RP.getSGPRNum()) << ')'; |
| 259 | OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() |
| 260 | << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); |
| 261 | if (ST) |
| 262 | OS << " -> Occ: " << RP.getOccupancy(ST: *ST, DynamicVGPRBlockSize); |
| 263 | OS << '\n'; |
| 264 | }); |
| 265 | } |
| 266 | |
| 267 | static LaneBitmask getDefRegMask(const MachineOperand &MO, |
| 268 | const MachineRegisterInfo &MRI) { |
| 269 | assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); |
| 270 | |
| 271 | // We don't rely on read-undef flag because in case of tentative schedule |
| 272 | // tracking it isn't set correctly yet. This works correctly however since |
| 273 | // use mask has been tracked before using LIS. |
| 274 | return MO.getSubReg() == 0 ? |
| 275 | MRI.getMaxLaneMaskForVReg(Reg: MO.getReg()) : |
| 276 | MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubIdx: MO.getSubReg()); |
| 277 | } |
| 278 | |
| 279 | static void |
| 280 | collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits, |
| 281 | const MachineInstr &MI, const LiveIntervals &LIS, |
| 282 | const MachineRegisterInfo &MRI) { |
| 283 | |
| 284 | auto &TRI = *MRI.getTargetRegisterInfo(); |
| 285 | for (const auto &MO : MI.operands()) { |
| 286 | if (!MO.isReg() || !MO.getReg().isVirtual()) |
| 287 | continue; |
| 288 | if (!MO.isUse() || !MO.readsReg()) |
| 289 | continue; |
| 290 | |
| 291 | Register Reg = MO.getReg(); |
| 292 | auto I = llvm::find_if(Range&: VRegMaskOrUnits, P: [Reg](const VRegMaskOrUnit &RM) { |
| 293 | return RM.VRegOrUnit.asVirtualReg() == Reg; |
| 294 | }); |
| 295 | |
| 296 | auto &P = I == VRegMaskOrUnits.end() |
| 297 | ? VRegMaskOrUnits.emplace_back(Args: VirtRegOrUnit(Reg), |
| 298 | Args: LaneBitmask::getNone()) |
| 299 | : *I; |
| 300 | |
| 301 | P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(SubIdx: MO.getSubReg()) |
| 302 | : MRI.getMaxLaneMaskForVReg(Reg); |
| 303 | } |
| 304 | |
| 305 | SlotIndex InstrSI; |
| 306 | for (auto &P : VRegMaskOrUnits) { |
| 307 | auto &LI = LIS.getInterval(Reg: P.VRegOrUnit.asVirtualReg()); |
| 308 | if (!LI.hasSubRanges()) |
| 309 | continue; |
| 310 | |
| 311 | // For a tentative schedule LIS isn't updated yet but livemask should |
| 312 | // remain the same on any schedule. Subreg defs can be reordered but they |
| 313 | // all must dominate uses anyway. |
| 314 | if (!InstrSI) |
| 315 | InstrSI = LIS.getInstructionIndex(Instr: MI).getBaseIndex(); |
| 316 | |
| 317 | P.LaneMask = getLiveLaneMask(LI, SI: InstrSI, MRI, LaneMaskFilter: P.LaneMask); |
| 318 | } |
| 319 | } |
| 320 | |
| 321 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 322 | static LaneBitmask getLanesWithProperty( |
| 323 | const LiveIntervals &LIS, const MachineRegisterInfo &MRI, |
| 324 | bool TrackLaneMasks, Register Reg, SlotIndex Pos, |
| 325 | function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) { |
| 326 | assert(Reg.isVirtual()); |
| 327 | const LiveInterval &LI = LIS.getInterval(Reg); |
| 328 | LaneBitmask Result; |
| 329 | if (TrackLaneMasks && LI.hasSubRanges()) { |
| 330 | for (const LiveInterval::SubRange &SR : LI.subranges()) { |
| 331 | if (Property(SR, Pos)) |
| 332 | Result |= SR.LaneMask; |
| 333 | } |
| 334 | } else if (Property(LI, Pos)) { |
| 335 | Result = |
| 336 | TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(Reg) : LaneBitmask::getAll(); |
| 337 | } |
| 338 | |
| 339 | return Result; |
| 340 | } |
| 341 | |
| 342 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 343 | /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}. |
| 344 | /// The query starts with a lane bitmask which gets lanes/bits removed for every |
| 345 | /// use we find. |
| 346 | static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, |
| 347 | SlotIndex PriorUseIdx, SlotIndex NextUseIdx, |
| 348 | const MachineRegisterInfo &MRI, |
| 349 | const SIRegisterInfo *TRI, |
| 350 | const LiveIntervals *LIS, |
| 351 | bool Upward = false) { |
| 352 | for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) { |
| 353 | if (MO.isUndef()) |
| 354 | continue; |
| 355 | const MachineInstr *MI = MO.getParent(); |
| 356 | SlotIndex InstSlot = LIS->getInstructionIndex(Instr: *MI).getRegSlot(); |
| 357 | bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx) |
| 358 | : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx); |
| 359 | if (!InRange) |
| 360 | continue; |
| 361 | |
| 362 | unsigned SubRegIdx = MO.getSubReg(); |
| 363 | LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubIdx: SubRegIdx); |
| 364 | LastUseMask &= ~UseMask; |
| 365 | if (LastUseMask.none()) |
| 366 | return LaneBitmask::getNone(); |
| 367 | } |
| 368 | return LastUseMask; |
| 369 | } |
| 370 | |
| 371 | //////////////////////////////////////////////////////////////////////////////// |
| 372 | // GCNRPTarget |
| 373 | |
| 374 | GCNRPTarget::GCNRPTarget(const MachineFunction &MF, const GCNRegPressure &RP) |
| 375 | : GCNRPTarget(RP, MF) { |
| 376 | const Function &F = MF.getFunction(); |
| 377 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 378 | setTarget(NumSGPRs: ST.getMaxNumSGPRs(F), NumVGPRs: ST.getMaxNumVGPRs(F)); |
| 379 | } |
| 380 | |
| 381 | GCNRPTarget::GCNRPTarget(unsigned NumSGPRs, unsigned NumVGPRs, |
| 382 | const MachineFunction &MF, const GCNRegPressure &RP) |
| 383 | : GCNRPTarget(RP, MF) { |
| 384 | setTarget(NumSGPRs, NumVGPRs); |
| 385 | } |
| 386 | |
| 387 | GCNRPTarget::GCNRPTarget(unsigned Occupancy, const MachineFunction &MF, |
| 388 | const GCNRegPressure &RP) |
| 389 | : GCNRPTarget(RP, MF) { |
| 390 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 391 | unsigned DynamicVGPRBlockSize = |
| 392 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 393 | setTarget(NumSGPRs: ST.getMaxNumSGPRs(WavesPerEU: Occupancy, /*Addressable=*/false), |
| 394 | NumVGPRs: ST.getMaxNumVGPRs(WavesPerEU: Occupancy, DynamicVGPRBlockSize)); |
| 395 | } |
| 396 | |
| 397 | void GCNRPTarget::setTarget(unsigned NumSGPRs, unsigned NumVGPRs) { |
| 398 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 399 | MaxSGPRs = std::min(a: ST.getAddressableNumSGPRs(), b: NumSGPRs); |
| 400 | MaxVGPRs = std::min(a: ST.getAddressableNumArchVGPRs(), b: NumVGPRs); |
| 401 | if (UnifiedRF) { |
| 402 | unsigned DynamicVGPRBlockSize = |
| 403 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 404 | MaxUnifiedVGPRs = |
| 405 | std::min(a: ST.getAddressableNumVGPRs(DynamicVGPRBlockSize), b: NumVGPRs); |
| 406 | } else { |
| 407 | MaxUnifiedVGPRs = 0; |
| 408 | } |
| 409 | } |
| 410 | |
| 411 | bool GCNRPTarget::isSaveBeneficial(Register Reg) const { |
| 412 | const MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 413 | const TargetRegisterClass *RC = MRI.getRegClass(Reg); |
| 414 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 415 | const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI); |
| 416 | |
| 417 | RegExcess Excess(MF, RP, *this); |
| 418 | |
| 419 | if (SRI->isSGPRClass(RC)) |
| 420 | return Excess.SGPR; |
| 421 | |
| 422 | if (SRI->isAGPRClass(RC)) |
| 423 | return (UnifiedRF && Excess.VGPR) || Excess.AGPR; |
| 424 | |
| 425 | return (UnifiedRF && Excess.VGPR) || Excess.ArchVGPR; |
| 426 | } |
| 427 | |
| 428 | bool GCNRPTarget::isSaveBeneficial(const GCNRegPressure &SaveRP) const { |
| 429 | RegExcess Excess(MF, RP, *this); |
| 430 | if (SaveRP.getSGPRNum() != 0 && Excess.SGPR != 0) |
| 431 | return true; |
| 432 | if (SaveRP.getArchVGPRNum() != 0 && Excess.ArchVGPR != 0) |
| 433 | return true; |
| 434 | if (SaveRP.getAGPRNum() != 0 && Excess.AGPR != 0) |
| 435 | return true; |
| 436 | if (UnifiedRF && Excess.VGPR != 0) |
| 437 | return SaveRP.getArchVGPRNum() != 0 || SaveRP.getAGPRNum() != 0; |
| 438 | return false; |
| 439 | } |
| 440 | |
| 441 | unsigned GCNRPTarget::getNumRegsBenefit(const GCNRegPressure &SaveRP) const { |
| 442 | RegExcess Excess(MF, RP, *this); |
| 443 | const unsigned NumVGPRAboveAddrLimit = |
| 444 | std::min(a: Excess.ArchVGPR, b: SaveRP.getArchVGPRNum()) + |
| 445 | std::min(a: Excess.AGPR, b: SaveRP.getAGPRNum()); |
| 446 | unsigned NumRegsSaved = |
| 447 | std::min(a: Excess.SGPR, b: SaveRP.getSGPRNum()) + NumVGPRAboveAddrLimit; |
| 448 | |
| 449 | if (UnifiedRF && Excess.VGPR) { |
| 450 | // We have already accounted for excess pressure above addressive limits for |
| 451 | // the individual VGPR classes. However for targets with unified RFs there |
| 452 | // is also a unified VGPR pressure (ArchVGPR + AGPR combination) limit to |
| 453 | // honor that may be more restrictive that the per-VGPR-class limits. We |
| 454 | // must also be careful not to double-count VGPR saves that may contribute |
| 455 | // to lowering pressure both above the addressable limit in their respective |
| 456 | // class as well as in the unified VGPR limit. |
| 457 | const unsigned VGPRSave = SaveRP.getArchVGPRNum() + SaveRP.getAGPRNum(); |
| 458 | if (NumVGPRAboveAddrLimit < VGPRSave) |
| 459 | NumRegsSaved += std::min(a: Excess.VGPR, b: VGPRSave - NumVGPRAboveAddrLimit); |
| 460 | } |
| 461 | |
| 462 | return NumRegsSaved; |
| 463 | } |
| 464 | |
| 465 | bool GCNRPTarget::satisfied(const GCNRegPressure &TestRP) const { |
| 466 | if (TestRP.getSGPRNum() > MaxSGPRs || TestRP.getVGPRNum(UnifiedVGPRFile: false) > MaxVGPRs) |
| 467 | return false; |
| 468 | if (UnifiedRF && TestRP.getVGPRNum(UnifiedVGPRFile: true) > MaxUnifiedVGPRs) |
| 469 | return false; |
| 470 | return true; |
| 471 | } |
| 472 | |
| 473 | bool GCNRPTarget::hasVectorRegisterExcess() const { |
| 474 | RegExcess Excess(MF, RP, *this); |
| 475 | return Excess.hasVectorRegisterExcess(); |
| 476 | } |
| 477 | |
| 478 | /////////////////////////////////////////////////////////////////////////////// |
| 479 | // GCNRPTracker |
| 480 | |
| 481 | LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, |
| 482 | const LiveIntervals &LIS, |
| 483 | const MachineRegisterInfo &MRI, |
| 484 | LaneBitmask LaneMaskFilter) { |
| 485 | return getLiveLaneMask(LI: LIS.getInterval(Reg), SI, MRI, LaneMaskFilter); |
| 486 | } |
| 487 | |
| 488 | LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, |
| 489 | const MachineRegisterInfo &MRI, |
| 490 | LaneBitmask LaneMaskFilter) { |
| 491 | LaneBitmask LiveMask; |
| 492 | if (LI.hasSubRanges()) { |
| 493 | for (const auto &S : LI.subranges()) |
| 494 | if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(index: SI)) { |
| 495 | LiveMask |= S.LaneMask; |
| 496 | assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); |
| 497 | } |
| 498 | } else if (LI.liveAt(index: SI)) { |
| 499 | LiveMask = MRI.getMaxLaneMaskForVReg(Reg: LI.reg()); |
| 500 | } |
| 501 | LiveMask &= LaneMaskFilter; |
| 502 | return LiveMask; |
| 503 | } |
| 504 | |
| 505 | GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, |
| 506 | const LiveIntervals &LIS, |
| 507 | const MachineRegisterInfo &MRI, |
| 508 | GCNRegPressure::RegKind RegKind) { |
| 509 | GCNRPTracker::LiveRegSet LiveRegs; |
| 510 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
| 511 | auto Reg = Register::index2VirtReg(Index: I); |
| 512 | if (RegKind != GCNRegPressure::TOTAL_KINDS && |
| 513 | GCNRegPressure::getRegKind(Reg, MRI) != RegKind) |
| 514 | continue; |
| 515 | if (!LIS.hasInterval(Reg)) |
| 516 | continue; |
| 517 | auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); |
| 518 | if (LiveMask.any()) |
| 519 | LiveRegs[Reg] = LiveMask; |
| 520 | } |
| 521 | return LiveRegs; |
| 522 | } |
| 523 | |
| 524 | void GCNRPTracker::reset(const MachineInstr &MI, bool After) { |
| 525 | const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo(); |
| 526 | if (!MI.isDebugInstr()) { |
| 527 | SlotIndex SI = LIS.getInstructionIndex(Instr: MI); |
| 528 | if (After) |
| 529 | SI = SI.getDeadSlot(); |
| 530 | reset(MRI, SI); |
| 531 | return; |
| 532 | } |
| 533 | |
| 534 | // Look for the first valid index after the provided debug MI. |
| 535 | MachineBasicBlock::const_iterator It = MI.getIterator(), |
| 536 | MBBEnd = MI.getParent()->end(); |
| 537 | MachineBasicBlock::const_iterator NonDbgMI = |
| 538 | skipDebugInstructionsForward(It, End: MBBEnd); |
| 539 | if (NonDbgMI == MBBEnd) { |
| 540 | // There are no non-debug instruction between MI and the end of the |
| 541 | // block, so we reset the tracker at the end of the block. |
| 542 | reset(MBB: *MI.getParent(), /*End=*/true); |
| 543 | return; |
| 544 | } |
| 545 | // MI is a debug instruction so register pressure before or after it is |
| 546 | // identical. Since we moved forward to finding a non-debug instruction |
| 547 | // in the block, we reset the tracker before that instruction i.e., at its |
| 548 | // base index. |
| 549 | reset(MRI, SI: LIS.getInstructionIndex(Instr: *NonDbgMI)); |
| 550 | } |
| 551 | |
| 552 | void GCNRPTracker::reset(const MachineBasicBlock &MBB, bool End) { |
| 553 | SlotIndex SI = End ? LIS.getSlotIndexes()->getMBBLastIdx(MBB: &MBB) |
| 554 | : LIS.getMBBStartIdx(mbb: &MBB); |
| 555 | reset(MRI: MBB.getParent()->getRegInfo(), SI); |
| 556 | } |
| 557 | |
| 558 | void GCNRPTracker::reset(const MachineRegisterInfo &MRI, SlotIndex SI) { |
| 559 | this->MRI = &MRI; |
| 560 | LastTrackedMI = nullptr; |
| 561 | LiveRegs = llvm::getLiveRegs(SI, LIS, MRI); |
| 562 | MaxPressure = CurPressure = getRegPressure(MRI, LiveRegs); |
| 563 | } |
| 564 | |
| 565 | void GCNRPTracker::reset(const MachineRegisterInfo &MRI, |
| 566 | const LiveRegSet &LiveRegs) { |
| 567 | this->MRI = &MRI; |
| 568 | LastTrackedMI = nullptr; |
| 569 | if (&this->LiveRegs != &LiveRegs) |
| 570 | this->LiveRegs = LiveRegs; |
| 571 | MaxPressure = CurPressure = getRegPressure(MRI, LiveRegs); |
| 572 | } |
| 573 | |
| 574 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 575 | LaneBitmask GCNRPTracker::getLastUsedLanes(Register Reg, SlotIndex Pos) const { |
| 576 | return getLanesWithProperty( |
| 577 | LIS, MRI: *MRI, TrackLaneMasks: true, Reg, Pos: Pos.getBaseIndex(), |
| 578 | Property: [](const LiveRange &LR, SlotIndex Pos) { |
| 579 | const LiveRange::Segment *S = LR.getSegmentContaining(Idx: Pos); |
| 580 | return S != nullptr && S->end == Pos.getRegSlot(); |
| 581 | }); |
| 582 | } |
| 583 | |
| 584 | //////////////////////////////////////////////////////////////////////////////// |
| 585 | // GCNUpwardRPTracker |
| 586 | |
| 587 | void GCNUpwardRPTracker::recede(const MachineInstr &MI) { |
| 588 | assert(MRI && "call reset first" ); |
| 589 | |
| 590 | LastTrackedMI = &MI; |
| 591 | |
| 592 | if (MI.isDebugInstr()) |
| 593 | return; |
| 594 | |
| 595 | // Kill all defs. |
| 596 | GCNRegPressure DefPressure, ECDefPressure; |
| 597 | bool HasECDefs = false; |
| 598 | for (const MachineOperand &MO : MI.all_defs()) { |
| 599 | if (!MO.getReg().isVirtual()) |
| 600 | continue; |
| 601 | |
| 602 | Register Reg = MO.getReg(); |
| 603 | LaneBitmask DefMask = getDefRegMask(MO, MRI: *MRI); |
| 604 | |
| 605 | // Treat a def as fully live at the moment of definition: keep a record. |
| 606 | if (MO.isEarlyClobber()) { |
| 607 | ECDefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI); |
| 608 | HasECDefs = true; |
| 609 | } else |
| 610 | DefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI); |
| 611 | |
| 612 | auto I = LiveRegs.find(Val: Reg); |
| 613 | if (I == LiveRegs.end()) |
| 614 | continue; |
| 615 | |
| 616 | LaneBitmask &LiveMask = I->second; |
| 617 | LaneBitmask PrevMask = LiveMask; |
| 618 | LiveMask &= ~DefMask; |
| 619 | CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI); |
| 620 | if (LiveMask.none()) |
| 621 | LiveRegs.erase(I); |
| 622 | } |
| 623 | |
| 624 | // Update MaxPressure with defs pressure. |
| 625 | DefPressure += CurPressure; |
| 626 | if (HasECDefs) |
| 627 | DefPressure += ECDefPressure; |
| 628 | MaxPressure = max(P1: DefPressure, P2: MaxPressure); |
| 629 | |
| 630 | // Make uses alive. |
| 631 | SmallVector<VRegMaskOrUnit, 8> RegUses; |
| 632 | collectVirtualRegUses(VRegMaskOrUnits&: RegUses, MI, LIS, MRI: *MRI); |
| 633 | for (const VRegMaskOrUnit &U : RegUses) { |
| 634 | LaneBitmask &LiveMask = LiveRegs[U.VRegOrUnit.asVirtualReg()]; |
| 635 | LaneBitmask PrevMask = LiveMask; |
| 636 | LiveMask |= U.LaneMask; |
| 637 | CurPressure.inc(Reg: U.VRegOrUnit.asVirtualReg(), PrevMask, NewMask: LiveMask, MRI: *MRI); |
| 638 | } |
| 639 | |
| 640 | // Update MaxPressure with uses plus early-clobber defs pressure. |
| 641 | MaxPressure = HasECDefs ? max(P1: CurPressure + ECDefPressure, P2: MaxPressure) |
| 642 | : max(P1: CurPressure, P2: MaxPressure); |
| 643 | |
| 644 | assert(CurPressure == getRegPressure(*MRI, LiveRegs)); |
| 645 | } |
| 646 | |
| 647 | //////////////////////////////////////////////////////////////////////////////// |
| 648 | // GCNDownwardRPTracker |
| 649 | |
| 650 | bool GCNDownwardRPTracker::reset(const MachineInstr &MI, |
| 651 | MachineBasicBlock::const_iterator End, |
| 652 | const LiveRegSet *LiveRegsCopy) { |
| 653 | MBBEnd = MI.getParent()->end(); |
| 654 | assert(End == MBBEnd || |
| 655 | End->getParent()->end() == MBBEnd && "end unrelated to MI block" ); |
| 656 | NextMI = &MI; |
| 657 | NextMI = skipDebugInstructionsForward(It: NextMI, End); |
| 658 | |
| 659 | // Do not use the MI to compute live registers when a set is provided. |
| 660 | // Otherwise the first non-debug instruction after the provided one (or the |
| 661 | // end of the block, if no such instruction exists) serves as the basis to |
| 662 | // compute a live register set. |
| 663 | if (LiveRegsCopy) |
| 664 | GCNRPTracker::reset(MRI: MI.getMF()->getRegInfo(), LiveRegs: *LiveRegsCopy); |
| 665 | else if (NextMI != MBBEnd) |
| 666 | GCNRPTracker::reset(MI: *NextMI, /*After=*/false); |
| 667 | else |
| 668 | GCNRPTracker::reset(MBB: *MI.getParent(), /*End=*/true); |
| 669 | return NextMI != End; |
| 670 | } |
| 671 | |
| 672 | bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI, |
| 673 | bool UseInternalIterator) { |
| 674 | assert(MRI && "call reset first" ); |
| 675 | SlotIndex SI; |
| 676 | const MachineInstr *CurrMI; |
| 677 | if (UseInternalIterator) { |
| 678 | if (!LastTrackedMI) |
| 679 | return NextMI == MBBEnd; |
| 680 | |
| 681 | assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); |
| 682 | CurrMI = LastTrackedMI; |
| 683 | |
| 684 | SI = NextMI == MBBEnd |
| 685 | ? LIS.getInstructionIndex(Instr: *LastTrackedMI).getDeadSlot() |
| 686 | : LIS.getInstructionIndex(Instr: *NextMI).getBaseIndex(); |
| 687 | } else { //! UseInternalIterator |
| 688 | SI = LIS.getInstructionIndex(Instr: *MI).getBaseIndex(); |
| 689 | CurrMI = MI; |
| 690 | } |
| 691 | |
| 692 | assert(SI.isValid()); |
| 693 | |
| 694 | // Remove dead registers or mask bits. |
| 695 | SmallSet<Register, 8> SeenRegs; |
| 696 | for (auto &MO : CurrMI->operands()) { |
| 697 | if (!MO.isReg() || !MO.getReg().isVirtual()) |
| 698 | continue; |
| 699 | if (MO.isUse() && !MO.readsReg()) |
| 700 | continue; |
| 701 | if (!UseInternalIterator && MO.isDef()) |
| 702 | continue; |
| 703 | if (!SeenRegs.insert(V: MO.getReg()).second) |
| 704 | continue; |
| 705 | const LiveInterval &LI = LIS.getInterval(Reg: MO.getReg()); |
| 706 | if (LI.hasSubRanges()) { |
| 707 | auto It = LiveRegs.end(); |
| 708 | for (const auto &S : LI.subranges()) { |
| 709 | if (!S.liveAt(index: SI)) { |
| 710 | if (It == LiveRegs.end()) { |
| 711 | It = LiveRegs.find(Val: MO.getReg()); |
| 712 | if (It == LiveRegs.end()) |
| 713 | llvm_unreachable("register isn't live" ); |
| 714 | } |
| 715 | auto PrevMask = It->second; |
| 716 | It->second &= ~S.LaneMask; |
| 717 | CurPressure.inc(Reg: MO.getReg(), PrevMask, NewMask: It->second, MRI: *MRI); |
| 718 | } |
| 719 | } |
| 720 | if (It != LiveRegs.end() && It->second.none()) |
| 721 | LiveRegs.erase(I: It); |
| 722 | } else if (!LI.liveAt(index: SI)) { |
| 723 | auto It = LiveRegs.find(Val: MO.getReg()); |
| 724 | if (It == LiveRegs.end()) |
| 725 | llvm_unreachable("register isn't live" ); |
| 726 | CurPressure.inc(Reg: MO.getReg(), PrevMask: It->second, NewMask: LaneBitmask::getNone(), MRI: *MRI); |
| 727 | LiveRegs.erase(I: It); |
| 728 | } |
| 729 | } |
| 730 | |
| 731 | MaxPressure = max(P1: MaxPressure, P2: CurPressure); |
| 732 | |
| 733 | LastTrackedMI = nullptr; |
| 734 | |
| 735 | return UseInternalIterator && (NextMI == MBBEnd); |
| 736 | } |
| 737 | |
| 738 | void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI, |
| 739 | bool UseInternalIterator) { |
| 740 | if (UseInternalIterator) { |
| 741 | LastTrackedMI = &*NextMI++; |
| 742 | NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd); |
| 743 | } else { |
| 744 | LastTrackedMI = MI; |
| 745 | } |
| 746 | |
| 747 | const MachineInstr *CurrMI = LastTrackedMI; |
| 748 | |
| 749 | // Add new registers or mask bits. |
| 750 | for (const auto &MO : CurrMI->all_defs()) { |
| 751 | Register Reg = MO.getReg(); |
| 752 | if (!Reg.isVirtual()) |
| 753 | continue; |
| 754 | auto &LiveMask = LiveRegs[Reg]; |
| 755 | auto PrevMask = LiveMask; |
| 756 | LiveMask |= getDefRegMask(MO, MRI: *MRI); |
| 757 | CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI); |
| 758 | } |
| 759 | |
| 760 | MaxPressure = max(P1: MaxPressure, P2: CurPressure); |
| 761 | } |
| 762 | |
| 763 | bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) { |
| 764 | if (UseInternalIterator && NextMI == MBBEnd) |
| 765 | return false; |
| 766 | |
| 767 | advanceBeforeNext(MI, UseInternalIterator); |
| 768 | advanceToNext(MI, UseInternalIterator); |
| 769 | if (!UseInternalIterator) { |
| 770 | // We must remove any dead def lanes from the current RP |
| 771 | advanceBeforeNext(MI, UseInternalIterator: true); |
| 772 | } |
| 773 | return true; |
| 774 | } |
| 775 | |
| 776 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { |
| 777 | bool AnyAdvance = false; |
| 778 | while (NextMI != End && advance()) |
| 779 | AnyAdvance = true; |
| 780 | return AnyAdvance; |
| 781 | } |
| 782 | |
| 783 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, |
| 784 | MachineBasicBlock::const_iterator End, |
| 785 | const LiveRegSet *LiveRegsCopy) { |
| 786 | if (!reset(MI: *Begin, End, LiveRegsCopy)) |
| 787 | return false; |
| 788 | return advance(End); |
| 789 | } |
| 790 | |
| 791 | Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, |
| 792 | const GCNRPTracker::LiveRegSet &TrackedLR, |
| 793 | const TargetRegisterInfo *TRI, StringRef Pfx) { |
| 794 | return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { |
| 795 | for (auto const &P : TrackedLR) { |
| 796 | auto I = LISLR.find(Val: P.first); |
| 797 | if (I == LISLR.end()) { |
| 798 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
| 799 | << " isn't found in LIS reported set\n" ; |
| 800 | } else if (I->second != P.second) { |
| 801 | OS << Pfx << printReg(Reg: P.first, TRI) |
| 802 | << " masks doesn't match: LIS reported " << PrintLaneMask(LaneMask: I->second) |
| 803 | << ", tracked " << PrintLaneMask(LaneMask: P.second) << '\n'; |
| 804 | } |
| 805 | } |
| 806 | for (auto const &P : LISLR) { |
| 807 | auto I = TrackedLR.find(Val: P.first); |
| 808 | if (I == TrackedLR.end()) { |
| 809 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
| 810 | << " isn't found in tracked set\n" ; |
| 811 | } |
| 812 | } |
| 813 | }); |
| 814 | } |
| 815 | |
| 816 | GCNRegPressure |
| 817 | GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI, |
| 818 | const SIRegisterInfo *TRI) const { |
| 819 | assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction." ); |
| 820 | |
| 821 | SlotIndex SlotIdx; |
| 822 | SlotIdx = LIS.getInstructionIndex(Instr: *MI).getRegSlot(); |
| 823 | |
| 824 | // Account for register pressure similar to RegPressureTracker::recede(). |
| 825 | RegisterOperands RegOpers; |
| 826 | RegOpers.collect(MI: *MI, TRI: *TRI, MRI: *MRI, TrackLaneMasks: true, /*IgnoreDead=*/false); |
| 827 | RegOpers.adjustLaneLiveness(LIS, MRI: *MRI, Pos: SlotIdx); |
| 828 | GCNRegPressure TempPressure = CurPressure; |
| 829 | |
| 830 | for (const VRegMaskOrUnit &Use : RegOpers.Uses) { |
| 831 | if (!Use.VRegOrUnit.isVirtualReg()) |
| 832 | continue; |
| 833 | Register Reg = Use.VRegOrUnit.asVirtualReg(); |
| 834 | LaneBitmask LastUseMask = getLastUsedLanes(Reg, Pos: SlotIdx); |
| 835 | if (LastUseMask.none()) |
| 836 | continue; |
| 837 | // The LastUseMask is queried from the liveness information of instruction |
| 838 | // which may be further down the schedule. Some lanes may actually not be |
| 839 | // last uses for the current position. |
| 840 | // FIXME: allow the caller to pass in the list of vreg uses that remain |
| 841 | // to be bottom-scheduled to avoid searching uses at each query. |
| 842 | SlotIndex CurrIdx; |
| 843 | const MachineBasicBlock *MBB = MI->getParent(); |
| 844 | MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward( |
| 845 | It: LastTrackedMI ? LastTrackedMI : MBB->begin(), End: MBB->end()); |
| 846 | if (IdxPos == MBB->end()) { |
| 847 | CurrIdx = LIS.getMBBEndIdx(mbb: MBB); |
| 848 | } else { |
| 849 | CurrIdx = LIS.getInstructionIndex(Instr: *IdxPos).getRegSlot(); |
| 850 | } |
| 851 | |
| 852 | LastUseMask = |
| 853 | findUseBetween(Reg, LastUseMask, PriorUseIdx: CurrIdx, NextUseIdx: SlotIdx, MRI: *MRI, TRI, LIS: &LIS); |
| 854 | if (LastUseMask.none()) |
| 855 | continue; |
| 856 | |
| 857 | auto It = LiveRegs.find(Val: Reg); |
| 858 | LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); |
| 859 | LaneBitmask NewMask = LiveMask & ~LastUseMask; |
| 860 | TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI); |
| 861 | } |
| 862 | |
| 863 | // Generate liveness for defs. |
| 864 | for (const VRegMaskOrUnit &Def : RegOpers.Defs) { |
| 865 | if (!Def.VRegOrUnit.isVirtualReg()) |
| 866 | continue; |
| 867 | Register Reg = Def.VRegOrUnit.asVirtualReg(); |
| 868 | auto It = LiveRegs.find(Val: Reg); |
| 869 | LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); |
| 870 | LaneBitmask NewMask = LiveMask | Def.LaneMask; |
| 871 | TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI); |
| 872 | } |
| 873 | |
| 874 | return TempPressure; |
| 875 | } |
| 876 | |
| 877 | bool GCNUpwardRPTracker::isValid() const { |
| 878 | const auto &SI = LIS.getInstructionIndex(Instr: *LastTrackedMI).getBaseIndex(); |
| 879 | const auto LISLR = llvm::getLiveRegs(SI, LIS, MRI: *MRI); |
| 880 | const auto &TrackedLR = LiveRegs; |
| 881 | |
| 882 | if (!isEqual(S1: LISLR, S2: TrackedLR)) { |
| 883 | dbgs() << "\nGCNUpwardRPTracker error: Tracked and" |
| 884 | " LIS reported livesets mismatch:\n" |
| 885 | << print(LiveRegs: LISLR, MRI: *MRI); |
| 886 | reportMismatch(LISLR, TrackedLR, TRI: MRI->getTargetRegisterInfo()); |
| 887 | return false; |
| 888 | } |
| 889 | |
| 890 | auto LISPressure = getRegPressure(MRI: *MRI, LiveRegs: LISLR); |
| 891 | if (LISPressure != CurPressure) { |
| 892 | dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " |
| 893 | << print(RP: CurPressure) << "LIS rpt: " << print(RP: LISPressure); |
| 894 | return false; |
| 895 | } |
| 896 | return true; |
| 897 | } |
| 898 | |
| 899 | Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, |
| 900 | const MachineRegisterInfo &MRI) { |
| 901 | return Printable([&LiveRegs, &MRI](raw_ostream &OS) { |
| 902 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 903 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
| 904 | Register Reg = Register::index2VirtReg(Index: I); |
| 905 | auto It = LiveRegs.find(Val: Reg); |
| 906 | if (It != LiveRegs.end() && It->second.any()) |
| 907 | OS << ' ' << printReg(Reg, TRI) << ':' << PrintLaneMask(LaneMask: It->second); |
| 908 | } |
| 909 | OS << '\n'; |
| 910 | }); |
| 911 | } |
| 912 | |
| 913 | void GCNRegPressure::dump() const { dbgs() << print(RP: *this); } |
| 914 | |
| 915 | static cl::opt<bool> UseDownwardTracker( |
| 916 | "amdgpu-print-rp-downward" , |
| 917 | cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass" ), |
| 918 | cl::init(Val: false), cl::Hidden); |
| 919 | |
// Pass ID storage, plus the GCNRegPressurePrinterID reference alias to it
// (presumably used by code outside this file to refer to the pass — confirm
// against the header).
char llvm::GCNRegPressurePrinter::ID = 0;
char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;

// Register the printer under the command-line name "amdgpu-print-rp" with an
// empty description. NOTE(review): the two trailing booleans are presumably
// INITIALIZE_PASS's CFG-only / is-analysis flags — confirm in PassSupport.h.
INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
| 924 | |
| 925 | // Return lanemask of Reg's subregs that are live-through at [Begin, End] and |
| 926 | // are fully covered by Mask. |
| 927 | static LaneBitmask |
| 928 | getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, |
| 929 | Register Reg, SlotIndex Begin, SlotIndex End, |
| 930 | LaneBitmask Mask = LaneBitmask::getAll()) { |
| 931 | |
| 932 | auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { |
| 933 | auto *Segment = LR.getSegmentContaining(Idx: Begin); |
| 934 | return Segment && Segment->contains(I: End); |
| 935 | }; |
| 936 | |
| 937 | LaneBitmask LiveThroughMask; |
| 938 | const LiveInterval &LI = LIS.getInterval(Reg); |
| 939 | if (LI.hasSubRanges()) { |
| 940 | for (auto &SR : LI.subranges()) { |
| 941 | if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) |
| 942 | LiveThroughMask |= SR.LaneMask; |
| 943 | } |
| 944 | } else { |
| 945 | LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); |
| 946 | if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) |
| 947 | LiveThroughMask = RegMask; |
| 948 | } |
| 949 | |
| 950 | return LiveThroughMask; |
| 951 | } |
| 952 | |
/// Print per-block register pressure for \p MF to dbgs() as a YAML-like
/// document ("---" ... "...").
///
/// For every basic block the output contains the live-in set, the SGPR/VGPR
/// pressure before and at each non-debug instruction, the pressure at the
/// block end, the live-out set and the live-through set with its pressure.
/// The pressures come from GCNDownwardRPTracker when -amdgpu-print-rp-downward
/// is set, and from GCNUpwardRPTracker otherwise. Always returns false.
bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
  const MachineRegisterInfo &MRI = MF.getRegInfo();
  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
  const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();

  auto &OS = dbgs();

  // Leading spaces are important for YAML syntax.
#define PFX " "

  OS << "---\nname: " << MF.getName() << "\nbody: |\n" ;

  // Format a pressure as SGPR and VGPR columns. The VGPR count is queried
  // with UnifiedVGPRFile = false.
  auto printRP = [](const GCNRegPressure &RP) {
    return Printable([&RP](raw_ostream &OS) {
      OS << format(PFX " %-5d" , Vals: RP.getSGPRNum())
         << format(Fmt: " %-5d" , Vals: RP.getVGPRNum(UnifiedVGPRFile: false));
    });
  };

  // When the tracker's live set differs from the LIS-computed one, print the
  // LIS set followed by the per-register differences.
  auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
                                    const GCNRPTracker::LiveRegSet &LISLR) {
    if (LISLR != TrackedLR) {
      OS << PFX " mis LIS: " << llvm::print(LiveRegs: LISLR, MRI)
         << reportMismatch(LISLR, TrackedLR, TRI, PFX " " );
    }
  };

  // Register pressure before and at an instruction (in program order).
  SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;

  for (auto &MBB : MF) {
    RP.clear();
    RP.reserve(N: MBB.size());

    OS << PFX;
    MBB.printName(os&: OS);
    OS << ":\n" ;

    SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(mbb: &MBB);
    SlotIndex MBBLastSlot = LIS.getSlotIndexes()->getMBBLastIdx(MBB: &MBB);

    GCNRPTracker::LiveRegSet LiveIn, LiveOut;
    GCNRegPressure RPAtMBBEnd;

    if (UseDownwardTracker) {
      if (MBB.empty()) {
        // No instructions: live-in and live-out are both the LIS live set at
        // the block start.
        LiveIn = LiveOut = getLiveRegs(SI: MBBStartSlot, LIS, MRI);
        RPAtMBBEnd = getRegPressure(MRI, LiveRegs&: LiveIn);
      } else {
        GCNDownwardRPTracker RPT(LIS);
        RPT.reset(MI: MBB.front(), End: MBB.end());

        LiveIn = RPT.getLiveRegs();

        // Walk top-down. For each instruction, record the pressure after
        // advanceBeforeNext() and after advanceToNext(). Entries are pushed
        // in program order.
        while (!RPT.advanceBeforeNext()) {
          GCNRegPressure RPBeforeMI = RPT.getPressure();
          RPT.advanceToNext();
          RP.emplace_back(Args&: RPBeforeMI, Args: RPT.getPressure());
        }

        LiveOut = RPT.getLiveRegs();
        RPAtMBBEnd = RPT.getPressure();
      }
    } else {
      GCNUpwardRPTracker RPT(LIS);
      RPT.reset(MRI, SI: MBBLastSlot);

      LiveOut = RPT.getLiveRegs();
      RPAtMBBEnd = RPT.getPressure();

      // Walk bottom-up, so RP ends up in reverse program order (it is indexed
      // from the back when printing). The max pressure is reset before each
      // recede so that it only covers the current instruction.
      for (auto &MI : reverse(C&: MBB)) {
        RPT.resetMaxPressure();
        RPT.recede(MI);
        if (!MI.isDebugInstr())
          RP.emplace_back(Args: RPT.getPressure(), Args: RPT.getMaxPressure());
      }

      LiveIn = RPT.getLiveRegs();
    }

    OS << PFX " Live-in: " << llvm::print(LiveRegs: LiveIn, MRI);
    // The live-in set is cross-checked against LIS only when the upward
    // tracker produced it.
    if (!UseDownwardTracker)
      ReportLISMismatchIfAny(LiveIn, getLiveRegs(SI: MBBStartSlot, LIS, MRI));

    OS << PFX " SGPR VGPR\n" ;
    int I = 0;
    for (auto &MI : MBB) {
      if (!MI.isDebugInstr()) {
        // RP is in program order for the downward tracker and in reverse
        // program order for the upward tracker.
        auto &[RPBeforeInstr, RPAtInstr] =
            RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
        ++I;
        OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " " ;
      } else
        OS << PFX " " ;
      MI.print(OS);
    }
    OS << printRP(RPAtMBBEnd) << '\n';

    OS << PFX " Live-out:" << llvm::print(LiveRegs: LiveOut, MRI);
    // The live-out set is cross-checked against LIS only when the downward
    // tracker produced it.
    if (UseDownwardTracker)
      ReportLISMismatchIfAny(LiveOut, getLiveRegs(SI: MBBLastSlot, LIS, MRI));

    // Live-through: lanes live on both entry and exit whose live segment
    // spans from the block start slot to the block's last slot.
    GCNRPTracker::LiveRegSet LiveThrough;
    for (auto [Reg, Mask] : LiveIn) {
      LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Val: Reg);
      if (MaskIntersection.any()) {
        LaneBitmask LTMask = getRegLiveThroughMask(
            MRI, LIS, Reg, Begin: MBBStartSlot, End: MBBLastSlot, Mask: MaskIntersection);
        if (LTMask.any())
          LiveThrough[Reg] = LTMask;
      }
    }
    OS << PFX " Live-thr:" << llvm::print(LiveRegs: LiveThrough, MRI);
    OS << printRP(getRegPressure(MRI, LiveRegs&: LiveThrough)) << '\n';
  }
  OS << "...\n" ;
  return false;

#undef PFX
}
| 1073 | |
| 1074 | #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| 1075 | LLVM_DUMP_METHOD void llvm::dumpMaxRegPressure(MachineFunction &MF, |
| 1076 | GCNRegPressure::RegKind Kind, |
| 1077 | LiveIntervals &LIS, |
| 1078 | const MachineLoopInfo *MLI) { |
| 1079 | |
| 1080 | const MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 1081 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 1082 | auto &OS = dbgs(); |
| 1083 | const char *RegName = GCNRegPressure::getName(Kind); |
| 1084 | |
| 1085 | unsigned MaxNumRegs = 0; |
| 1086 | const MachineInstr *MaxPressureMI = nullptr; |
| 1087 | GCNUpwardRPTracker RPT(LIS); |
| 1088 | for (const MachineBasicBlock &MBB : MF) { |
| 1089 | RPT.reset(MRI, LIS.getSlotIndexes()->getMBBEndIdx(&MBB).getPrevSlot()); |
| 1090 | for (const MachineInstr &MI : reverse(MBB)) { |
| 1091 | RPT.recede(MI); |
| 1092 | unsigned NumRegs = RPT.getMaxPressure().getNumRegs(Kind); |
| 1093 | if (NumRegs > MaxNumRegs) { |
| 1094 | MaxNumRegs = NumRegs; |
| 1095 | MaxPressureMI = &MI; |
| 1096 | } |
| 1097 | } |
| 1098 | } |
| 1099 | |
| 1100 | SlotIndex MISlot = LIS.getInstructionIndex(*MaxPressureMI); |
| 1101 | |
| 1102 | // Max pressure can occur at either the early-clobber or register slot. |
| 1103 | // Choose the maximum liveset between both slots. This is ugly but this is |
| 1104 | // diagnostic code. |
| 1105 | SlotIndex ECSlot = MISlot.getRegSlot(true); |
| 1106 | SlotIndex RSlot = MISlot.getRegSlot(false); |
| 1107 | GCNRPTracker::LiveRegSet ECLiveSet = getLiveRegs(ECSlot, LIS, MRI, Kind); |
| 1108 | GCNRPTracker::LiveRegSet RLiveSet = getLiveRegs(RSlot, LIS, MRI, Kind); |
| 1109 | unsigned ECNumRegs = getRegPressure(MRI, ECLiveSet).getNumRegs(Kind); |
| 1110 | unsigned RNumRegs = getRegPressure(MRI, RLiveSet).getNumRegs(Kind); |
| 1111 | GCNRPTracker::LiveRegSet *LiveSet = |
| 1112 | ECNumRegs > RNumRegs ? &ECLiveSet : &RLiveSet; |
| 1113 | SlotIndex MaxPressureSlot = ECNumRegs > RNumRegs ? ECSlot : RSlot; |
| 1114 | assert(getRegPressure(MRI, *LiveSet).getNumRegs(Kind) == MaxNumRegs); |
| 1115 | |
| 1116 | // Split live registers into single-def and multi-def sets. |
| 1117 | GCNRegPressure SDefPressure, MDefPressure; |
| 1118 | SmallVector<Register, 16> SDefRegs, MDefRegs; |
| 1119 | for (auto [Reg, LaneMask] : *LiveSet) { |
| 1120 | assert(GCNRegPressure::getRegKind(Reg, MRI) == Kind); |
| 1121 | LiveInterval &LI = LIS.getInterval(Reg); |
| 1122 | if (LI.getNumValNums() == 1 || |
| 1123 | (LI.hasSubRanges() && |
| 1124 | llvm::all_of(LI.subranges(), [](const LiveInterval::SubRange &SR) { |
| 1125 | return SR.getNumValNums() == 1; |
| 1126 | }))) { |
| 1127 | SDefPressure.inc(Reg, LaneBitmask::getNone(), LaneMask, MRI); |
| 1128 | SDefRegs.push_back(Reg); |
| 1129 | } else { |
| 1130 | MDefPressure.inc(Reg, LaneBitmask::getNone(), LaneMask, MRI); |
| 1131 | MDefRegs.push_back(Reg); |
| 1132 | } |
| 1133 | } |
| 1134 | unsigned SDefNumRegs = SDefPressure.getNumRegs(Kind); |
| 1135 | unsigned MDefNumRegs = MDefPressure.getNumRegs(Kind); |
| 1136 | assert(SDefNumRegs + MDefNumRegs == MaxNumRegs); |
| 1137 | |
| 1138 | auto printLoc = [&](const MachineBasicBlock *MBB, SlotIndex SI) { |
| 1139 | return Printable([&, MBB, SI](raw_ostream &OS) { |
| 1140 | OS << SI << ':' << printMBBReference(*MBB); |
| 1141 | if (MLI) |
| 1142 | if (const MachineLoop *ML = MLI->getLoopFor(MBB)) |
| 1143 | OS << " (LoopHdr " << printMBBReference(*ML->getHeader()) |
| 1144 | << ", Depth " << ML->getLoopDepth() << ")" ; |
| 1145 | }); |
| 1146 | }; |
| 1147 | |
| 1148 | auto PrintRegInfo = [&](Register Reg, LaneBitmask LiveMask) { |
| 1149 | GCNRegPressure RegPressure; |
| 1150 | RegPressure.inc(Reg, LaneBitmask::getNone(), LiveMask, MRI); |
| 1151 | OS << " " << printReg(Reg, TRI) << ':' |
| 1152 | << TRI->getRegClassName(MRI.getRegClass(Reg)) << ", LiveMask " |
| 1153 | << PrintLaneMask(LiveMask) << " (" << RegPressure.getNumRegs(Kind) << ' ' |
| 1154 | << RegName << "s)\n" ; |
| 1155 | |
| 1156 | // Use std::map to sort def/uses by SlotIndex. |
| 1157 | std::map<SlotIndex, const MachineInstr *> Instrs; |
| 1158 | for (const MachineInstr &MI : MRI.reg_nodbg_instructions(Reg)) { |
| 1159 | Instrs[LIS.getInstructionIndex(MI).getRegSlot()] = &MI; |
| 1160 | } |
| 1161 | |
| 1162 | for (const auto &[SI, MI] : Instrs) { |
| 1163 | OS << " " ; |
| 1164 | if (MI->definesRegister(Reg, TRI)) |
| 1165 | OS << "def " ; |
| 1166 | if (MI->readsRegister(Reg, TRI)) |
| 1167 | OS << "use " ; |
| 1168 | OS << printLoc(MI->getParent(), SI) << ": " << *MI; |
| 1169 | } |
| 1170 | }; |
| 1171 | |
| 1172 | OS << "\n*** Register pressure info (" << RegName << "s) for " << MF.getName() |
| 1173 | << " ***\n" ; |
| 1174 | OS << "Max pressure is " << MaxNumRegs << ' ' << RegName << "s at " |
| 1175 | << printLoc(MaxPressureMI->getParent(), MaxPressureSlot) << ": " |
| 1176 | << *MaxPressureMI; |
| 1177 | |
| 1178 | OS << "\nLive registers with single definition (" << SDefNumRegs << ' ' |
| 1179 | << RegName << "s):\n" ; |
| 1180 | |
| 1181 | // Sort SDefRegs by number of uses (smallest first) |
| 1182 | llvm::sort(SDefRegs, [&](Register A, Register B) { |
| 1183 | return std::distance(MRI.use_nodbg_begin(A), MRI.use_nodbg_end()) < |
| 1184 | std::distance(MRI.use_nodbg_begin(B), MRI.use_nodbg_end()); |
| 1185 | }); |
| 1186 | |
| 1187 | for (const Register Reg : SDefRegs) { |
| 1188 | PrintRegInfo(Reg, LiveSet->lookup(Reg)); |
| 1189 | } |
| 1190 | |
| 1191 | OS << "\nLive registers with multiple definitions (" << MDefNumRegs << ' ' |
| 1192 | << RegName << "s):\n" ; |
| 1193 | for (const Register Reg : MDefRegs) { |
| 1194 | PrintRegInfo(Reg, LiveSet->lookup(Reg)); |
| 1195 | } |
| 1196 | } |
| 1197 | #endif |
| 1198 | |