| 1 | //===- GCNRegPressure.cpp -------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | /// |
| 9 | /// \file |
| 10 | /// This file implements the GCNRegPressure class. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "GCNRegPressure.h" |
| 15 | #include "AMDGPU.h" |
| 16 | #include "SIMachineFunctionInfo.h" |
| 17 | #include "llvm/CodeGen/MachineLoopInfo.h" |
| 18 | #include "llvm/CodeGen/RegisterPressure.h" |
| 19 | |
| 20 | using namespace llvm; |
| 21 | |
| 22 | #define DEBUG_TYPE "machine-scheduler" |
| 23 | |
| 24 | bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, |
| 25 | const GCNRPTracker::LiveRegSet &S2) { |
| 26 | if (S1.size() != S2.size()) |
| 27 | return false; |
| 28 | |
| 29 | for (const auto &P : S1) { |
| 30 | auto I = S2.find(Val: P.first); |
| 31 | if (I == S2.end() || I->second != P.second) |
| 32 | return false; |
| 33 | } |
| 34 | return true; |
| 35 | } |
| 36 | |
| 37 | /////////////////////////////////////////////////////////////////////////////// |
| 38 | // GCNRegPressure |
| 39 | |
| 40 | unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC, |
| 41 | const SIRegisterInfo *STI) { |
| 42 | return STI->isSGPRClass(RC) |
| 43 | ? SGPR |
| 44 | : (STI->isAGPRClass(RC) |
| 45 | ? AGPR |
| 46 | : (STI->isVectorSuperClass(RC) ? AVGPR : VGPR)); |
| 47 | } |
| 48 | |
/// Update this pressure for register \p Reg whose live lane mask changes from
/// \p PrevMask to \p NewMask. A growing mask increases the counters and a
/// shrinking mask decreases them. Nothing changes when both masks cover the
/// same number of 32-bit registers.
void GCNRegPressure::inc(unsigned Reg,
                         LaneBitmask PrevMask,
                         LaneBitmask NewMask,
                         const MachineRegisterInfo &MRI) {
  unsigned NewNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(NewMask);
  unsigned PrevNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(PrevMask);
  // Same number of covered 32-bit registers: pressure is unchanged.
  if (NewNumCoveredRegs == PrevNumCoveredRegs)
    return;

  // Normalize so that PrevMask is the smaller mask. Sign records whether this
  // is an increase (+1) or a decrease (-1) of pressure.
  int Sign = 1;
  if (NewMask < PrevMask) {
    std::swap(NewMask, PrevMask);
    std::swap(NewNumCoveredRegs, PrevNumCoveredRegs);
    Sign = -1;
  }
  assert(PrevMask < NewMask && PrevNumCoveredRegs < NewNumCoveredRegs &&
         "prev mask should always be lesser than new");

  const TargetRegisterClass *RC = MRI.getRegClass(Reg);
  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
  const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI);
  unsigned RegKind = getRegKind(RC, STI);
  if (TRI->getRegSizeInBits(*RC) != 32) {
    // Reg is from a tuple register class.
    // The tuple-weight slot (index TOTAL_KINDS + RegKind) only changes when
    // the smaller of the two masks is empty. That means the register goes
    // from fully dead to live (Sign == +1) or from live to fully dead
    // (Sign == -1).
    if (PrevMask.none()) {
      unsigned TupleIdx = TOTAL_KINDS + RegKind;
      Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight;
    }
    // Pressure scales with number of new registers covered by the new mask.
    // Note when true16 is enabled, we can no longer safely use the following
    // approach to calculate the difference in the number of 32-bit registers
    // between two masks:
    //
    // Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
    //
    // The issue is that the mask calculation `~PrevMask & NewMask` doesn't
    // properly account for partial usage of a 32-bit register when dealing with
    // 16-bit registers.
    //
    // Consider this example:
    // Assume PrevMask = 0b0010 and NewMask = 0b1111. Here, the correct register
    // usage difference should be 1, because even though PrevMask uses only half
    // of a 32-bit register, it should still be counted as a full register use.
    // However, the mask calculation yields `~PrevMask & NewMask = 0b1101`, and
    // calling `getNumCoveredRegs` returns 2 instead of 1. This incorrect
    // calculation can lead to integer overflow when Sign = -1.
    //
    // NOTE: the right-hand side is unsigned (positive after normalization).
    // The multiplication therefore happens in unsigned arithmetic, and the
    // wrapped result converts back to the intended negative value when
    // Sign == -1.
    Sign *= NewNumCoveredRegs - PrevNumCoveredRegs;
  }
  // For 32-bit (non-tuple) classes Sign is still +/-1, i.e. one register.
  Value[RegKind] += Sign;
}
| 99 | |
| 100 | namespace { |
| 101 | struct RegExcess { |
| 102 | unsigned SGPR = 0; |
| 103 | unsigned VGPR = 0; |
| 104 | unsigned ArchVGPR = 0; |
| 105 | unsigned AGPR = 0; |
| 106 | |
| 107 | bool anyExcess() const { return SGPR || VGPR || ArchVGPR || AGPR; } |
| 108 | bool hasVectorRegisterExcess() const { return VGPR || ArchVGPR || AGPR; } |
| 109 | |
| 110 | RegExcess(const MachineFunction &MF, const GCNRegPressure &RP) |
| 111 | : RegExcess(MF, RP, GCNRPTarget(MF, RP)) {} |
| 112 | RegExcess(const MachineFunction &MF, const GCNRegPressure &RP, |
| 113 | const GCNRPTarget &Target) { |
| 114 | unsigned MaxSGPRs = Target.getMaxSGPRs(); |
| 115 | unsigned MaxVGPRs = Target.getMaxVGPRs(); |
| 116 | |
| 117 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 118 | SGPR = std::max(a: static_cast<int>(RP.getSGPRNum() - MaxSGPRs), b: 0); |
| 119 | |
| 120 | // The number of virtual VGPRs required to handle excess SGPR |
| 121 | unsigned WaveSize = ST.getWavefrontSize(); |
| 122 | unsigned VGPRForSGPRSpills = divideCeil(Numerator: SGPR, Denominator: WaveSize); |
| 123 | |
| 124 | unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); |
| 125 | |
| 126 | // Unified excess pressure conditions, accounting for VGPRs used for SGPR |
| 127 | // spills |
| 128 | VGPR = std::max(a: static_cast<int>(RP.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) + |
| 129 | VGPRForSGPRSpills - MaxVGPRs), |
| 130 | b: 0); |
| 131 | |
| 132 | unsigned ArchVGPRLimit = ST.hasGFX90AInsts() ? MaxArchVGPRs : MaxVGPRs; |
| 133 | // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR |
| 134 | // spills |
| 135 | ArchVGPR = std::max(a: static_cast<int>(RP.getArchVGPRNum() + |
| 136 | VGPRForSGPRSpills - ArchVGPRLimit), |
| 137 | b: 0); |
| 138 | |
| 139 | // AGPR excess pressure conditions |
| 140 | AGPR = std::max(a: static_cast<int>(RP.getAGPRNum() - ArchVGPRLimit), b: 0); |
| 141 | } |
| 142 | }; |
| 143 | } // namespace |
| 144 | |
/// Return true if this register pressure is preferable to ("less than") \p O
/// for function \p MF. Criteria, in decreasing precedence:
///   1. higher occupancy, with each side capped at \p MaxOccupancy;
///   2. less excess over the function's register limits (see RegExcess),
///      comparing total vector excess before SGPR excess;
///   3. lower register tuple weight, starting with SGPRs only when SGPR
///      occupancy is the tighter limit for both pressures, otherwise VGPRs;
///   4. lower plain SGPR or VGPR count, using the file chosen in step 3.
bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
                          unsigned MaxOccupancy) const {
  const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
  unsigned DynamicVGPRBlockSize =
      MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();

  // Occupancy permitted by each register file, for both pressures, capped at
  // MaxOccupancy.
  const auto SGPROcc = std::min(MaxOccupancy,
                                ST.getOccupancyWithNumSGPRs(getSGPRNum()));
  const auto VGPROcc = std::min(
      MaxOccupancy, ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts()),
                                                DynamicVGPRBlockSize));
  const auto OtherSGPROcc = std::min(MaxOccupancy,
                                     ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
  const auto OtherVGPROcc =
      std::min(MaxOccupancy,
               ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts()),
                                           DynamicVGPRBlockSize));

  // The achievable occupancy is bounded by the more restrictive file.
  const auto Occ = std::min(SGPROcc, VGPROcc);
  const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);

  // Give first precedence to the better occupancy.
  if (Occ != OtherOcc)
    return Occ > OtherOcc;

  unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);

  // Excess over the default GCNRPTarget limits for MF.
  RegExcess Excess(MF, *this);
  RegExcess OtherExcess(MF, O);

  unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();

  bool ExcessRP = Excess.anyExcess();
  bool OtherExcessRP = OtherExcess.anyExcess();

  // Give second precedence to the reduced number of spills to hold the register
  // pressure.
  if (ExcessRP || OtherExcessRP) {
    // The difference in excess VGPR pressure, after including VGPRs used for
    // SGPR spills
    // NOTE(review): computed in unsigned arithmetic and then converted to
    // int. A positive value means O has more vector excess than *this.
    int VGPRDiff =
        ((OtherExcess.VGPR + OtherExcess.ArchVGPR + OtherExcess.AGPR) -
         (Excess.VGPR + Excess.ArchVGPR + Excess.AGPR));

    int SGPRDiff = OtherExcess.SGPR - Excess.SGPR;

    if (VGPRDiff != 0)
      return VGPRDiff > 0;
    if (SGPRDiff != 0) {
      // Vector excess ignoring VGPRs needed for SGPR spills. It is the count
      // from getVGPRNum(hasGFX90AInsts) over MaxVGPRs plus the count from
      // getVGPRNum(false) over the addressable arch VGPR limit.
      unsigned PureExcessVGPR =
          std::max(static_cast<int>(getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
                   0) +
          std::max(static_cast<int>(getVGPRNum(false) - MaxArchVGPRs), 0);
      unsigned OtherPureExcessVGPR =
          std::max(
              static_cast<int>(O.getVGPRNum(ST.hasGFX90AInsts()) - MaxVGPRs),
              0) +
          std::max(static_cast<int>(O.getVGPRNum(false) - MaxArchVGPRs), 0);

      // If we have a special case where there is a tie in excess VGPR, but one
      // of the pressures has VGPR usage from SGPR spills, prefer the pressure
      // with SGPR spills.
      if (PureExcessVGPR != OtherPureExcessVGPR)
        return SGPRDiff < 0;
      // If both pressures have the same excess pressure before and after
      // accounting for SGPR spills, prefer fewer SGPR spills.
      return SGPRDiff > 0;
    }
  }

  // SGPRs are "important" when they, rather than VGPRs, limit occupancy.
  bool SGPRImportant = SGPROcc < VGPROcc;
  const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;

  // If both pressures disagree on what is more important compare vgprs.
  if (SGPRImportant != OtherSGPRImportant) {
    SGPRImportant = false;
  }

  // Give third precedence to lower register tuple pressure.
  // Two iterations: the important file's tuple weight first, then the other.
  bool SGPRFirst = SGPRImportant;
  for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
    if (SGPRFirst) {
      auto SW = getSGPRTuplesWeight();
      auto OtherSW = O.getSGPRTuplesWeight();
      if (SW != OtherSW)
        return SW < OtherSW;
    } else {
      auto VW = getVGPRTuplesWeight();
      auto OtherVW = O.getVGPRTuplesWeight();
      if (VW != OtherVW)
        return VW < OtherVW;
    }
  }

  // Give final precedence to lower general RP.
  return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
                         (getVGPRNum(ST.hasGFX90AInsts()) <
                          O.getVGPRNum(ST.hasGFX90AInsts()));
}
| 244 | |
| 245 | Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST, |
| 246 | unsigned DynamicVGPRBlockSize) { |
| 247 | return Printable([&RP, ST, DynamicVGPRBlockSize](raw_ostream &OS) { |
| 248 | OS << "VGPRs: " << RP.getArchVGPRNum() << ' ' |
| 249 | << "AGPRs: " << RP.getAGPRNum(); |
| 250 | if (ST) |
| 251 | OS << "(O" |
| 252 | << ST->getOccupancyWithNumVGPRs(VGPRs: RP.getVGPRNum(UnifiedVGPRFile: ST->hasGFX90AInsts()), |
| 253 | DynamicVGPRBlockSize) |
| 254 | << ')'; |
| 255 | OS << ", SGPRs: " << RP.getSGPRNum(); |
| 256 | if (ST) |
| 257 | OS << "(O" << ST->getOccupancyWithNumSGPRs(SGPRs: RP.getSGPRNum()) << ')'; |
| 258 | OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() |
| 259 | << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); |
| 260 | if (ST) |
| 261 | OS << " -> Occ: " << RP.getOccupancy(ST: *ST, DynamicVGPRBlockSize); |
| 262 | OS << '\n'; |
| 263 | }); |
| 264 | } |
| 265 | |
| 266 | static LaneBitmask getDefRegMask(const MachineOperand &MO, |
| 267 | const MachineRegisterInfo &MRI) { |
| 268 | assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); |
| 269 | |
| 270 | // We don't rely on read-undef flag because in case of tentative schedule |
| 271 | // tracking it isn't set correctly yet. This works correctly however since |
| 272 | // use mask has been tracked before using LIS. |
| 273 | return MO.getSubReg() == 0 ? |
| 274 | MRI.getMaxLaneMaskForVReg(Reg: MO.getReg()) : |
| 275 | MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubIdx: MO.getSubReg()); |
| 276 | } |
| 277 | |
| 278 | static void |
| 279 | collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits, |
| 280 | const MachineInstr &MI, const LiveIntervals &LIS, |
| 281 | const MachineRegisterInfo &MRI) { |
| 282 | |
| 283 | auto &TRI = *MRI.getTargetRegisterInfo(); |
| 284 | for (const auto &MO : MI.operands()) { |
| 285 | if (!MO.isReg() || !MO.getReg().isVirtual()) |
| 286 | continue; |
| 287 | if (!MO.isUse() || !MO.readsReg()) |
| 288 | continue; |
| 289 | |
| 290 | Register Reg = MO.getReg(); |
| 291 | auto I = llvm::find_if(Range&: VRegMaskOrUnits, P: [Reg](const VRegMaskOrUnit &RM) { |
| 292 | return RM.VRegOrUnit.asVirtualReg() == Reg; |
| 293 | }); |
| 294 | |
| 295 | auto &P = I == VRegMaskOrUnits.end() |
| 296 | ? VRegMaskOrUnits.emplace_back(Args: VirtRegOrUnit(Reg), |
| 297 | Args: LaneBitmask::getNone()) |
| 298 | : *I; |
| 299 | |
| 300 | P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(SubIdx: MO.getSubReg()) |
| 301 | : MRI.getMaxLaneMaskForVReg(Reg); |
| 302 | } |
| 303 | |
| 304 | SlotIndex InstrSI; |
| 305 | for (auto &P : VRegMaskOrUnits) { |
| 306 | auto &LI = LIS.getInterval(Reg: P.VRegOrUnit.asVirtualReg()); |
| 307 | if (!LI.hasSubRanges()) |
| 308 | continue; |
| 309 | |
| 310 | // For a tentative schedule LIS isn't updated yet but livemask should |
| 311 | // remain the same on any schedule. Subreg defs can be reordered but they |
| 312 | // all must dominate uses anyway. |
| 313 | if (!InstrSI) |
| 314 | InstrSI = LIS.getInstructionIndex(Instr: MI).getBaseIndex(); |
| 315 | |
| 316 | P.LaneMask = getLiveLaneMask(LI, SI: InstrSI, MRI, LaneMaskFilter: P.LaneMask); |
| 317 | } |
| 318 | } |
| 319 | |
| 320 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 321 | static LaneBitmask getLanesWithProperty( |
| 322 | const LiveIntervals &LIS, const MachineRegisterInfo &MRI, |
| 323 | bool TrackLaneMasks, Register Reg, SlotIndex Pos, |
| 324 | function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) { |
| 325 | assert(Reg.isVirtual()); |
| 326 | const LiveInterval &LI = LIS.getInterval(Reg); |
| 327 | LaneBitmask Result; |
| 328 | if (TrackLaneMasks && LI.hasSubRanges()) { |
| 329 | for (const LiveInterval::SubRange &SR : LI.subranges()) { |
| 330 | if (Property(SR, Pos)) |
| 331 | Result |= SR.LaneMask; |
| 332 | } |
| 333 | } else if (Property(LI, Pos)) { |
| 334 | Result = |
| 335 | TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(Reg) : LaneBitmask::getAll(); |
| 336 | } |
| 337 | |
| 338 | return Result; |
| 339 | } |
| 340 | |
| 341 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 342 | /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}. |
| 343 | /// The query starts with a lane bitmask which gets lanes/bits removed for every |
| 344 | /// use we find. |
| 345 | static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, |
| 346 | SlotIndex PriorUseIdx, SlotIndex NextUseIdx, |
| 347 | const MachineRegisterInfo &MRI, |
| 348 | const SIRegisterInfo *TRI, |
| 349 | const LiveIntervals *LIS, |
| 350 | bool Upward = false) { |
| 351 | for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) { |
| 352 | if (MO.isUndef()) |
| 353 | continue; |
| 354 | const MachineInstr *MI = MO.getParent(); |
| 355 | SlotIndex InstSlot = LIS->getInstructionIndex(Instr: *MI).getRegSlot(); |
| 356 | bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx) |
| 357 | : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx); |
| 358 | if (!InRange) |
| 359 | continue; |
| 360 | |
| 361 | unsigned SubRegIdx = MO.getSubReg(); |
| 362 | LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubIdx: SubRegIdx); |
| 363 | LastUseMask &= ~UseMask; |
| 364 | if (LastUseMask.none()) |
| 365 | return LaneBitmask::getNone(); |
| 366 | } |
| 367 | return LastUseMask; |
| 368 | } |
| 369 | |
| 370 | //////////////////////////////////////////////////////////////////////////////// |
| 371 | // GCNRPTarget |
| 372 | |
| 373 | GCNRPTarget::GCNRPTarget(const MachineFunction &MF, const GCNRegPressure &RP) |
| 374 | : GCNRPTarget(RP, MF) { |
| 375 | const Function &F = MF.getFunction(); |
| 376 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 377 | setTarget(NumSGPRs: ST.getMaxNumSGPRs(F), NumVGPRs: ST.getMaxNumVGPRs(F)); |
| 378 | } |
| 379 | |
| 380 | GCNRPTarget::GCNRPTarget(unsigned NumSGPRs, unsigned NumVGPRs, |
| 381 | const MachineFunction &MF, const GCNRegPressure &RP) |
| 382 | : GCNRPTarget(RP, MF) { |
| 383 | setTarget(NumSGPRs, NumVGPRs); |
| 384 | } |
| 385 | |
| 386 | GCNRPTarget::GCNRPTarget(unsigned Occupancy, const MachineFunction &MF, |
| 387 | const GCNRegPressure &RP) |
| 388 | : GCNRPTarget(RP, MF) { |
| 389 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 390 | unsigned DynamicVGPRBlockSize = |
| 391 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 392 | setTarget(NumSGPRs: ST.getMaxNumSGPRs(WavesPerEU: Occupancy, /*Addressable=*/false), |
| 393 | NumVGPRs: ST.getMaxNumVGPRs(WavesPerEU: Occupancy, DynamicVGPRBlockSize)); |
| 394 | } |
| 395 | |
| 396 | void GCNRPTarget::setTarget(unsigned NumSGPRs, unsigned NumVGPRs) { |
| 397 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 398 | MaxSGPRs = std::min(a: ST.getAddressableNumSGPRs(), b: NumSGPRs); |
| 399 | MaxVGPRs = std::min(a: ST.getAddressableNumArchVGPRs(), b: NumVGPRs); |
| 400 | if (UnifiedRF) { |
| 401 | unsigned DynamicVGPRBlockSize = |
| 402 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 403 | MaxUnifiedVGPRs = |
| 404 | std::min(a: ST.getAddressableNumVGPRs(DynamicVGPRBlockSize), b: NumVGPRs); |
| 405 | } else { |
| 406 | MaxUnifiedVGPRs = 0; |
| 407 | } |
| 408 | } |
| 409 | |
| 410 | bool GCNRPTarget::isSaveBeneficial(Register Reg) const { |
| 411 | const MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 412 | const TargetRegisterClass *RC = MRI.getRegClass(Reg); |
| 413 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 414 | const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI); |
| 415 | |
| 416 | RegExcess Excess(MF, RP, *this); |
| 417 | |
| 418 | if (SRI->isSGPRClass(RC)) |
| 419 | return Excess.SGPR; |
| 420 | |
| 421 | if (SRI->isAGPRClass(RC)) |
| 422 | return (UnifiedRF && Excess.VGPR) || Excess.AGPR; |
| 423 | |
| 424 | return (UnifiedRF && Excess.VGPR) || Excess.ArchVGPR; |
| 425 | } |
| 426 | |
| 427 | bool GCNRPTarget::satisfied() const { |
| 428 | if (RP.getSGPRNum() > MaxSGPRs || RP.getVGPRNum(UnifiedVGPRFile: false) > MaxVGPRs) |
| 429 | return false; |
| 430 | if (UnifiedRF && RP.getVGPRNum(UnifiedVGPRFile: true) > MaxUnifiedVGPRs) |
| 431 | return false; |
| 432 | return true; |
| 433 | } |
| 434 | |
| 435 | bool GCNRPTarget::hasVectorRegisterExcess() const { |
| 436 | RegExcess Excess(MF, RP, *this); |
| 437 | return Excess.hasVectorRegisterExcess(); |
| 438 | } |
| 439 | |
| 440 | /////////////////////////////////////////////////////////////////////////////// |
| 441 | // GCNRPTracker |
| 442 | |
| 443 | LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, |
| 444 | const LiveIntervals &LIS, |
| 445 | const MachineRegisterInfo &MRI, |
| 446 | LaneBitmask LaneMaskFilter) { |
| 447 | return getLiveLaneMask(LI: LIS.getInterval(Reg), SI, MRI, LaneMaskFilter); |
| 448 | } |
| 449 | |
| 450 | LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, |
| 451 | const MachineRegisterInfo &MRI, |
| 452 | LaneBitmask LaneMaskFilter) { |
| 453 | LaneBitmask LiveMask; |
| 454 | if (LI.hasSubRanges()) { |
| 455 | for (const auto &S : LI.subranges()) |
| 456 | if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(index: SI)) { |
| 457 | LiveMask |= S.LaneMask; |
| 458 | assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); |
| 459 | } |
| 460 | } else if (LI.liveAt(index: SI)) { |
| 461 | LiveMask = MRI.getMaxLaneMaskForVReg(Reg: LI.reg()); |
| 462 | } |
| 463 | LiveMask &= LaneMaskFilter; |
| 464 | return LiveMask; |
| 465 | } |
| 466 | |
| 467 | GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, |
| 468 | const LiveIntervals &LIS, |
| 469 | const MachineRegisterInfo &MRI, |
| 470 | GCNRegPressure::RegKind RegKind) { |
| 471 | GCNRPTracker::LiveRegSet LiveRegs; |
| 472 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
| 473 | auto Reg = Register::index2VirtReg(Index: I); |
| 474 | if (RegKind != GCNRegPressure::TOTAL_KINDS && |
| 475 | GCNRegPressure::getRegKind(Reg, MRI) != RegKind) |
| 476 | continue; |
| 477 | if (!LIS.hasInterval(Reg)) |
| 478 | continue; |
| 479 | auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); |
| 480 | if (LiveMask.any()) |
| 481 | LiveRegs[Reg] = LiveMask; |
| 482 | } |
| 483 | return LiveRegs; |
| 484 | } |
| 485 | |
| 486 | void GCNRPTracker::reset(const MachineInstr &MI, |
| 487 | const LiveRegSet *LiveRegsCopy, |
| 488 | bool After) { |
| 489 | const MachineFunction &MF = *MI.getMF(); |
| 490 | MRI = &MF.getRegInfo(); |
| 491 | if (LiveRegsCopy) { |
| 492 | if (&LiveRegs != LiveRegsCopy) |
| 493 | LiveRegs = *LiveRegsCopy; |
| 494 | } else { |
| 495 | LiveRegs = After ? getLiveRegsAfter(MI, LIS) |
| 496 | : getLiveRegsBefore(MI, LIS); |
| 497 | } |
| 498 | |
| 499 | MaxPressure = CurPressure = getRegPressure(MRI: *MRI, LiveRegs); |
| 500 | } |
| 501 | |
| 502 | void GCNRPTracker::reset(const MachineRegisterInfo &MRI_, |
| 503 | const LiveRegSet &LiveRegs_) { |
| 504 | MRI = &MRI_; |
| 505 | LiveRegs = LiveRegs_; |
| 506 | LastTrackedMI = nullptr; |
| 507 | MaxPressure = CurPressure = getRegPressure(MRI: MRI_, LiveRegs: LiveRegs_); |
| 508 | } |
| 509 | |
| 510 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 511 | LaneBitmask GCNRPTracker::getLastUsedLanes(Register Reg, SlotIndex Pos) const { |
| 512 | return getLanesWithProperty( |
| 513 | LIS, MRI: *MRI, TrackLaneMasks: true, Reg, Pos: Pos.getBaseIndex(), |
| 514 | Property: [](const LiveRange &LR, SlotIndex Pos) { |
| 515 | const LiveRange::Segment *S = LR.getSegmentContaining(Idx: Pos); |
| 516 | return S != nullptr && S->end == Pos.getRegSlot(); |
| 517 | }); |
| 518 | } |
| 519 | |
| 520 | //////////////////////////////////////////////////////////////////////////////// |
| 521 | // GCNUpwardRPTracker |
| 522 | |
/// Move the tracker upward across \p MI. Lanes defined by MI stop being live
/// and lanes read by MI become live. MaxPressure is updated twice:
///  - with the pressure at the point of definition (all defs counted on top
///    of the pressure remaining after the defs are killed);
///  - with the pressure after the uses become live, plus any early-clobber
///    defs, which are counted together with the uses.
/// Debug instructions only update LastTrackedMI.
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
  assert(MRI && "call reset first");

  LastTrackedMI = &MI;

  if (MI.isDebugInstr())
    return;

  // Kill all defs.
  GCNRegPressure DefPressure, ECDefPressure;
  bool HasECDefs = false;
  for (const MachineOperand &MO : MI.all_defs()) {
    if (!MO.getReg().isVirtual())
      continue;

    Register Reg = MO.getReg();
    LaneBitmask DefMask = getDefRegMask(MO, *MRI);

    // Treat a def as fully live at the moment of definition: keep a record.
    // Early-clobber defs are recorded separately so they can also be added to
    // the use pressure below.
    if (MO.isEarlyClobber()) {
      ECDefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);
      HasECDefs = true;
    } else
      DefPressure.inc(Reg, LaneBitmask::getNone(), DefMask, *MRI);

    auto I = LiveRegs.find(Reg);
    if (I == LiveRegs.end())
      continue;

    // Above the def, the defined lanes are no longer live.
    LaneBitmask &LiveMask = I->second;
    LaneBitmask PrevMask = LiveMask;
    LiveMask &= ~DefMask;
    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
    if (LiveMask.none())
      LiveRegs.erase(I);
  }

  // Update MaxPressure with defs pressure.
  DefPressure += CurPressure;
  if (HasECDefs)
    DefPressure += ECDefPressure;
  MaxPressure = max(DefPressure, MaxPressure);

  // Make uses alive.
  SmallVector<VRegMaskOrUnit, 8> RegUses;
  collectVirtualRegUses(RegUses, MI, LIS, *MRI);
  for (const VRegMaskOrUnit &U : RegUses) {
    LaneBitmask &LiveMask = LiveRegs[U.VRegOrUnit.asVirtualReg()];
    LaneBitmask PrevMask = LiveMask;
    LiveMask |= U.LaneMask;
    CurPressure.inc(U.VRegOrUnit.asVirtualReg(), PrevMask, LiveMask, *MRI);
  }

  // Update MaxPressure with uses plus early-clobber defs pressure.
  MaxPressure = HasECDefs ? max(CurPressure + ECDefPressure, MaxPressure)
                          : max(CurPressure, MaxPressure);

  // Incremental tracking must agree with a full recomputation.
  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
}
| 582 | |
| 583 | //////////////////////////////////////////////////////////////////////////////// |
| 584 | // GCNDownwardRPTracker |
| 585 | |
| 586 | bool GCNDownwardRPTracker::reset(const MachineInstr &MI, |
| 587 | const LiveRegSet *LiveRegsCopy) { |
| 588 | MRI = &MI.getMF()->getRegInfo(); |
| 589 | LastTrackedMI = nullptr; |
| 590 | MBBEnd = MI.getParent()->end(); |
| 591 | NextMI = &MI; |
| 592 | NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd); |
| 593 | if (NextMI == MBBEnd) |
| 594 | return false; |
| 595 | GCNRPTracker::reset(MI: *NextMI, LiveRegsCopy, After: false); |
| 596 | return true; |
| 597 | } |
| 598 | |
/// Drop registers, or lanes of registers, that are no longer live.
///  - With \p UseInternalIterator, the operands of LastTrackedMI are examined.
///    Liveness is checked at the base index of the next non-debug
///    instruction, or at LastTrackedMI's dead slot when the block end has
///    been reached. If nothing has been tracked yet, the function returns
///    immediately.
///  - Otherwise the operands of \p MI are examined at MI's base index, and
///    its def operands are skipped.
/// MaxPressure is updated and LastTrackedMI is cleared. Returns true only
/// when using the internal iterator and the block end has been reached.
bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
                                             bool UseInternalIterator) {
  assert(MRI && "call reset first");
  SlotIndex SI;
  const MachineInstr *CurrMI;
  if (UseInternalIterator) {
    if (!LastTrackedMI)
      return NextMI == MBBEnd;

    assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
    CurrMI = LastTrackedMI;

    SI = NextMI == MBBEnd
             ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
             : LIS.getInstructionIndex(*NextMI).getBaseIndex();
  } else { //! UseInternalIterator
    SI = LIS.getInstructionIndex(*MI).getBaseIndex();
    CurrMI = MI;
  }

  assert(SI.isValid());

  // Remove dead registers or mask bits.
  // Each virtual register is processed at most once (SeenRegs). Non-reading
  // uses are skipped, and defs are skipped when examining an external MI.
  SmallSet<Register, 8> SeenRegs;
  for (auto &MO : CurrMI->operands()) {
    if (!MO.isReg() || !MO.getReg().isVirtual())
      continue;
    if (MO.isUse() && !MO.readsReg())
      continue;
    if (!UseInternalIterator && MO.isDef())
      continue;
    if (!SeenRegs.insert(MO.getReg()).second)
      continue;
    const LiveInterval &LI = LIS.getInterval(MO.getReg());
    if (LI.hasSubRanges()) {
      // Clear the lanes of every subrange that is not live at SI. The
      // register must already be tracked (looked up lazily on first need).
      auto It = LiveRegs.end();
      for (const auto &S : LI.subranges()) {
        if (!S.liveAt(SI)) {
          if (It == LiveRegs.end()) {
            It = LiveRegs.find(MO.getReg());
            if (It == LiveRegs.end())
              llvm_unreachable("register isn't live");
          }
          auto PrevMask = It->second;
          It->second &= ~S.LaneMask;
          CurPressure.inc(MO.getReg(), PrevMask, It->second, *MRI);
        }
      }
      if (It != LiveRegs.end() && It->second.none())
        LiveRegs.erase(It);
    } else if (!LI.liveAt(SI)) {
      // No subranges: the whole register dies at once.
      auto It = LiveRegs.find(MO.getReg());
      if (It == LiveRegs.end())
        llvm_unreachable("register isn't live");
      CurPressure.inc(MO.getReg(), It->second, LaneBitmask::getNone(), *MRI);
      LiveRegs.erase(It);
    }
  }

  MaxPressure = max(MaxPressure, CurPressure);

  LastTrackedMI = nullptr;

  return UseInternalIterator && (NextMI == MBBEnd);
}
| 664 | |
| 665 | void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI, |
| 666 | bool UseInternalIterator) { |
| 667 | if (UseInternalIterator) { |
| 668 | LastTrackedMI = &*NextMI++; |
| 669 | NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd); |
| 670 | } else { |
| 671 | LastTrackedMI = MI; |
| 672 | } |
| 673 | |
| 674 | const MachineInstr *CurrMI = LastTrackedMI; |
| 675 | |
| 676 | // Add new registers or mask bits. |
| 677 | for (const auto &MO : CurrMI->all_defs()) { |
| 678 | Register Reg = MO.getReg(); |
| 679 | if (!Reg.isVirtual()) |
| 680 | continue; |
| 681 | auto &LiveMask = LiveRegs[Reg]; |
| 682 | auto PrevMask = LiveMask; |
| 683 | LiveMask |= getDefRegMask(MO, MRI: *MRI); |
| 684 | CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI); |
| 685 | } |
| 686 | |
| 687 | MaxPressure = max(P1: MaxPressure, P2: CurPressure); |
| 688 | } |
| 689 | |
| 690 | bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) { |
| 691 | if (UseInternalIterator && NextMI == MBBEnd) |
| 692 | return false; |
| 693 | |
| 694 | advanceBeforeNext(MI, UseInternalIterator); |
| 695 | advanceToNext(MI, UseInternalIterator); |
| 696 | if (!UseInternalIterator) { |
| 697 | // We must remove any dead def lanes from the current RP |
| 698 | advanceBeforeNext(MI, UseInternalIterator: true); |
| 699 | } |
| 700 | return true; |
| 701 | } |
| 702 | |
| 703 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { |
| 704 | while (NextMI != End) |
| 705 | if (!advance()) return false; |
| 706 | return true; |
| 707 | } |
| 708 | |
| 709 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, |
| 710 | MachineBasicBlock::const_iterator End, |
| 711 | const LiveRegSet *LiveRegsCopy) { |
| 712 | reset(MI: *Begin, LiveRegsCopy); |
| 713 | return advance(End); |
| 714 | } |
| 715 | |
| 716 | Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, |
| 717 | const GCNRPTracker::LiveRegSet &TrackedLR, |
| 718 | const TargetRegisterInfo *TRI, StringRef Pfx) { |
| 719 | return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { |
| 720 | for (auto const &P : TrackedLR) { |
| 721 | auto I = LISLR.find(Val: P.first); |
| 722 | if (I == LISLR.end()) { |
| 723 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
| 724 | << " isn't found in LIS reported set\n" ; |
| 725 | } else if (I->second != P.second) { |
| 726 | OS << Pfx << printReg(Reg: P.first, TRI) |
| 727 | << " masks doesn't match: LIS reported " << PrintLaneMask(LaneMask: I->second) |
| 728 | << ", tracked " << PrintLaneMask(LaneMask: P.second) << '\n'; |
| 729 | } |
| 730 | } |
| 731 | for (auto const &P : LISLR) { |
| 732 | auto I = TrackedLR.find(Val: P.first); |
| 733 | if (I == TrackedLR.end()) { |
| 734 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
| 735 | << " isn't found in tracked set\n" ; |
| 736 | } |
| 737 | } |
| 738 | }); |
| 739 | } |
| 740 | |
| 741 | GCNRegPressure |
| 742 | GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI, |
| 743 | const SIRegisterInfo *TRI) const { |
| 744 | assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction." ); |
| 745 | |
| 746 | SlotIndex SlotIdx; |
| 747 | SlotIdx = LIS.getInstructionIndex(Instr: *MI).getRegSlot(); |
| 748 | |
| 749 | // Account for register pressure similar to RegPressureTracker::recede(). |
| 750 | RegisterOperands RegOpers; |
| 751 | RegOpers.collect(MI: *MI, TRI: *TRI, MRI: *MRI, TrackLaneMasks: true, /*IgnoreDead=*/false); |
| 752 | RegOpers.adjustLaneLiveness(LIS, MRI: *MRI, Pos: SlotIdx); |
| 753 | GCNRegPressure TempPressure = CurPressure; |
| 754 | |
| 755 | for (const VRegMaskOrUnit &Use : RegOpers.Uses) { |
| 756 | if (!Use.VRegOrUnit.isVirtualReg()) |
| 757 | continue; |
| 758 | Register Reg = Use.VRegOrUnit.asVirtualReg(); |
| 759 | LaneBitmask LastUseMask = getLastUsedLanes(Reg, Pos: SlotIdx); |
| 760 | if (LastUseMask.none()) |
| 761 | continue; |
| 762 | // The LastUseMask is queried from the liveness information of instruction |
| 763 | // which may be further down the schedule. Some lanes may actually not be |
| 764 | // last uses for the current position. |
| 765 | // FIXME: allow the caller to pass in the list of vreg uses that remain |
| 766 | // to be bottom-scheduled to avoid searching uses at each query. |
| 767 | SlotIndex CurrIdx; |
| 768 | const MachineBasicBlock *MBB = MI->getParent(); |
| 769 | MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward( |
| 770 | It: LastTrackedMI ? LastTrackedMI : MBB->begin(), End: MBB->end()); |
| 771 | if (IdxPos == MBB->end()) { |
| 772 | CurrIdx = LIS.getMBBEndIdx(mbb: MBB); |
| 773 | } else { |
| 774 | CurrIdx = LIS.getInstructionIndex(Instr: *IdxPos).getRegSlot(); |
| 775 | } |
| 776 | |
| 777 | LastUseMask = |
| 778 | findUseBetween(Reg, LastUseMask, PriorUseIdx: CurrIdx, NextUseIdx: SlotIdx, MRI: *MRI, TRI, LIS: &LIS); |
| 779 | if (LastUseMask.none()) |
| 780 | continue; |
| 781 | |
| 782 | auto It = LiveRegs.find(Val: Reg); |
| 783 | LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); |
| 784 | LaneBitmask NewMask = LiveMask & ~LastUseMask; |
| 785 | TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI); |
| 786 | } |
| 787 | |
| 788 | // Generate liveness for defs. |
| 789 | for (const VRegMaskOrUnit &Def : RegOpers.Defs) { |
| 790 | if (!Def.VRegOrUnit.isVirtualReg()) |
| 791 | continue; |
| 792 | Register Reg = Def.VRegOrUnit.asVirtualReg(); |
| 793 | auto It = LiveRegs.find(Val: Reg); |
| 794 | LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); |
| 795 | LaneBitmask NewMask = LiveMask | Def.LaneMask; |
| 796 | TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI); |
| 797 | } |
| 798 | |
| 799 | return TempPressure; |
| 800 | } |
| 801 | |
| 802 | bool GCNUpwardRPTracker::isValid() const { |
| 803 | const auto &SI = LIS.getInstructionIndex(Instr: *LastTrackedMI).getBaseIndex(); |
| 804 | const auto LISLR = llvm::getLiveRegs(SI, LIS, MRI: *MRI); |
| 805 | const auto &TrackedLR = LiveRegs; |
| 806 | |
| 807 | if (!isEqual(S1: LISLR, S2: TrackedLR)) { |
| 808 | dbgs() << "\nGCNUpwardRPTracker error: Tracked and" |
| 809 | " LIS reported livesets mismatch:\n" |
| 810 | << print(LiveRegs: LISLR, MRI: *MRI); |
| 811 | reportMismatch(LISLR, TrackedLR, TRI: MRI->getTargetRegisterInfo()); |
| 812 | return false; |
| 813 | } |
| 814 | |
| 815 | auto LISPressure = getRegPressure(MRI: *MRI, LiveRegs: LISLR); |
| 816 | if (LISPressure != CurPressure) { |
| 817 | dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " |
| 818 | << print(RP: CurPressure) << "LIS rpt: " << print(RP: LISPressure); |
| 819 | return false; |
| 820 | } |
| 821 | return true; |
| 822 | } |
| 823 | |
| 824 | Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, |
| 825 | const MachineRegisterInfo &MRI) { |
| 826 | return Printable([&LiveRegs, &MRI](raw_ostream &OS) { |
| 827 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 828 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
| 829 | Register Reg = Register::index2VirtReg(Index: I); |
| 830 | auto It = LiveRegs.find(Val: Reg); |
| 831 | if (It != LiveRegs.end() && It->second.any()) |
| 832 | OS << ' ' << printReg(Reg, TRI) << ':' << PrintLaneMask(LaneMask: It->second); |
| 833 | } |
| 834 | OS << '\n'; |
| 835 | }); |
| 836 | } |
| 837 | |
| 838 | void GCNRegPressure::dump() const { dbgs() << print(RP: *this); } |
| 839 | |
| 840 | static cl::opt<bool> UseDownwardTracker( |
| 841 | "amdgpu-print-rp-downward" , |
| 842 | cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass" ), |
| 843 | cl::init(Val: false), cl::Hidden); |
| 844 | |
| 845 | char llvm::GCNRegPressurePrinter::ID = 0; |
| 846 | char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID; |
| 847 | |
| 848 | INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp" , "" , true, true) |
| 849 | |
| 850 | // Return lanemask of Reg's subregs that are live-through at [Begin, End] and |
| 851 | // are fully covered by Mask. |
| 852 | static LaneBitmask |
| 853 | getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, |
| 854 | Register Reg, SlotIndex Begin, SlotIndex End, |
| 855 | LaneBitmask Mask = LaneBitmask::getAll()) { |
| 856 | |
| 857 | auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { |
| 858 | auto *Segment = LR.getSegmentContaining(Idx: Begin); |
| 859 | return Segment && Segment->contains(I: End); |
| 860 | }; |
| 861 | |
| 862 | LaneBitmask LiveThroughMask; |
| 863 | const LiveInterval &LI = LIS.getInterval(Reg); |
| 864 | if (LI.hasSubRanges()) { |
| 865 | for (auto &SR : LI.subranges()) { |
| 866 | if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) |
| 867 | LiveThroughMask |= SR.LaneMask; |
| 868 | } |
| 869 | } else { |
| 870 | LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); |
| 871 | if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) |
| 872 | LiveThroughMask = RegMask; |
| 873 | } |
| 874 | |
| 875 | return LiveThroughMask; |
| 876 | } |
| 877 | |
| 878 | bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) { |
| 879 | const MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 880 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 881 | const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); |
| 882 | |
| 883 | auto &OS = dbgs(); |
| 884 | |
| 885 | // Leading spaces are important for YAML syntax. |
| 886 | #define PFX " " |
| 887 | |
| 888 | OS << "---\nname: " << MF.getName() << "\nbody: |\n" ; |
| 889 | |
| 890 | auto printRP = [](const GCNRegPressure &RP) { |
| 891 | return Printable([&RP](raw_ostream &OS) { |
| 892 | OS << format(PFX " %-5d" , Vals: RP.getSGPRNum()) |
| 893 | << format(Fmt: " %-5d" , Vals: RP.getVGPRNum(UnifiedVGPRFile: false)); |
| 894 | }); |
| 895 | }; |
| 896 | |
| 897 | auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR, |
| 898 | const GCNRPTracker::LiveRegSet &LISLR) { |
| 899 | if (LISLR != TrackedLR) { |
| 900 | OS << PFX " mis LIS: " << llvm::print(LiveRegs: LISLR, MRI) |
| 901 | << reportMismatch(LISLR, TrackedLR, TRI, PFX " " ); |
| 902 | } |
| 903 | }; |
| 904 | |
| 905 | // Register pressure before and at an instruction (in program order). |
| 906 | SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP; |
| 907 | |
| 908 | for (auto &MBB : MF) { |
| 909 | RP.clear(); |
| 910 | RP.reserve(N: MBB.size()); |
| 911 | |
| 912 | OS << PFX; |
| 913 | MBB.printName(os&: OS); |
| 914 | OS << ":\n" ; |
| 915 | |
| 916 | SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(mbb: &MBB); |
| 917 | SlotIndex MBBLastSlot = LIS.getSlotIndexes()->getMBBLastIdx(MBB: &MBB); |
| 918 | |
| 919 | GCNRPTracker::LiveRegSet LiveIn, LiveOut; |
| 920 | GCNRegPressure RPAtMBBEnd; |
| 921 | |
| 922 | if (UseDownwardTracker) { |
| 923 | if (MBB.empty()) { |
| 924 | LiveIn = LiveOut = getLiveRegs(SI: MBBStartSlot, LIS, MRI); |
| 925 | RPAtMBBEnd = getRegPressure(MRI, LiveRegs&: LiveIn); |
| 926 | } else { |
| 927 | GCNDownwardRPTracker RPT(LIS); |
| 928 | RPT.reset(MI: MBB.front()); |
| 929 | |
| 930 | LiveIn = RPT.getLiveRegs(); |
| 931 | |
| 932 | while (!RPT.advanceBeforeNext()) { |
| 933 | GCNRegPressure RPBeforeMI = RPT.getPressure(); |
| 934 | RPT.advanceToNext(); |
| 935 | RP.emplace_back(Args&: RPBeforeMI, Args: RPT.getPressure()); |
| 936 | } |
| 937 | |
| 938 | LiveOut = RPT.getLiveRegs(); |
| 939 | RPAtMBBEnd = RPT.getPressure(); |
| 940 | } |
| 941 | } else { |
| 942 | GCNUpwardRPTracker RPT(LIS); |
| 943 | RPT.reset(MRI, SI: MBBLastSlot); |
| 944 | |
| 945 | LiveOut = RPT.getLiveRegs(); |
| 946 | RPAtMBBEnd = RPT.getPressure(); |
| 947 | |
| 948 | for (auto &MI : reverse(C&: MBB)) { |
| 949 | RPT.resetMaxPressure(); |
| 950 | RPT.recede(MI); |
| 951 | if (!MI.isDebugInstr()) |
| 952 | RP.emplace_back(Args: RPT.getPressure(), Args: RPT.getMaxPressure()); |
| 953 | } |
| 954 | |
| 955 | LiveIn = RPT.getLiveRegs(); |
| 956 | } |
| 957 | |
| 958 | OS << PFX " Live-in: " << llvm::print(LiveRegs: LiveIn, MRI); |
| 959 | if (!UseDownwardTracker) |
| 960 | ReportLISMismatchIfAny(LiveIn, getLiveRegs(SI: MBBStartSlot, LIS, MRI)); |
| 961 | |
| 962 | OS << PFX " SGPR VGPR\n" ; |
| 963 | int I = 0; |
| 964 | for (auto &MI : MBB) { |
| 965 | if (!MI.isDebugInstr()) { |
| 966 | auto &[RPBeforeInstr, RPAtInstr] = |
| 967 | RP[UseDownwardTracker ? I : (RP.size() - 1 - I)]; |
| 968 | ++I; |
| 969 | OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " " ; |
| 970 | } else |
| 971 | OS << PFX " " ; |
| 972 | MI.print(OS); |
| 973 | } |
| 974 | OS << printRP(RPAtMBBEnd) << '\n'; |
| 975 | |
| 976 | OS << PFX " Live-out:" << llvm::print(LiveRegs: LiveOut, MRI); |
| 977 | if (UseDownwardTracker) |
| 978 | ReportLISMismatchIfAny(LiveOut, getLiveRegs(SI: MBBLastSlot, LIS, MRI)); |
| 979 | |
| 980 | GCNRPTracker::LiveRegSet LiveThrough; |
| 981 | for (auto [Reg, Mask] : LiveIn) { |
| 982 | LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Val: Reg); |
| 983 | if (MaskIntersection.any()) { |
| 984 | LaneBitmask LTMask = getRegLiveThroughMask( |
| 985 | MRI, LIS, Reg, Begin: MBBStartSlot, End: MBBLastSlot, Mask: MaskIntersection); |
| 986 | if (LTMask.any()) |
| 987 | LiveThrough[Reg] = LTMask; |
| 988 | } |
| 989 | } |
| 990 | OS << PFX " Live-thr:" << llvm::print(LiveRegs: LiveThrough, MRI); |
| 991 | OS << printRP(getRegPressure(MRI, LiveRegs&: LiveThrough)) << '\n'; |
| 992 | } |
| 993 | OS << "...\n" ; |
| 994 | return false; |
| 995 | |
| 996 | #undef PFX |
| 997 | } |
| 998 | |
| 999 | #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| 1000 | LLVM_DUMP_METHOD void llvm::dumpMaxRegPressure(MachineFunction &MF, |
| 1001 | GCNRegPressure::RegKind Kind, |
| 1002 | LiveIntervals &LIS, |
| 1003 | const MachineLoopInfo *MLI) { |
| 1004 | |
| 1005 | const MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 1006 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 1007 | auto &OS = dbgs(); |
| 1008 | const char *RegName = GCNRegPressure::getName(Kind); |
| 1009 | |
| 1010 | unsigned MaxNumRegs = 0; |
| 1011 | const MachineInstr *MaxPressureMI = nullptr; |
| 1012 | GCNUpwardRPTracker RPT(LIS); |
| 1013 | for (const MachineBasicBlock &MBB : MF) { |
| 1014 | RPT.reset(MRI, LIS.getSlotIndexes()->getMBBEndIdx(&MBB).getPrevSlot()); |
| 1015 | for (const MachineInstr &MI : reverse(MBB)) { |
| 1016 | RPT.recede(MI); |
| 1017 | unsigned NumRegs = RPT.getMaxPressure().getNumRegs(Kind); |
| 1018 | if (NumRegs > MaxNumRegs) { |
| 1019 | MaxNumRegs = NumRegs; |
| 1020 | MaxPressureMI = &MI; |
| 1021 | } |
| 1022 | } |
| 1023 | } |
| 1024 | |
| 1025 | SlotIndex MISlot = LIS.getInstructionIndex(*MaxPressureMI); |
| 1026 | |
| 1027 | // Max pressure can occur at either the early-clobber or register slot. |
| 1028 | // Choose the maximum liveset between both slots. This is ugly but this is |
| 1029 | // diagnostic code. |
| 1030 | SlotIndex ECSlot = MISlot.getRegSlot(true); |
| 1031 | SlotIndex RSlot = MISlot.getRegSlot(false); |
| 1032 | GCNRPTracker::LiveRegSet ECLiveSet = getLiveRegs(ECSlot, LIS, MRI, Kind); |
| 1033 | GCNRPTracker::LiveRegSet RLiveSet = getLiveRegs(RSlot, LIS, MRI, Kind); |
| 1034 | unsigned ECNumRegs = getRegPressure(MRI, ECLiveSet).getNumRegs(Kind); |
| 1035 | unsigned RNumRegs = getRegPressure(MRI, RLiveSet).getNumRegs(Kind); |
| 1036 | GCNRPTracker::LiveRegSet *LiveSet = |
| 1037 | ECNumRegs > RNumRegs ? &ECLiveSet : &RLiveSet; |
| 1038 | SlotIndex MaxPressureSlot = ECNumRegs > RNumRegs ? ECSlot : RSlot; |
| 1039 | assert(getRegPressure(MRI, *LiveSet).getNumRegs(Kind) == MaxNumRegs); |
| 1040 | |
| 1041 | // Split live registers into single-def and multi-def sets. |
| 1042 | GCNRegPressure SDefPressure, MDefPressure; |
| 1043 | SmallVector<Register, 16> SDefRegs, MDefRegs; |
| 1044 | for (auto [Reg, LaneMask] : *LiveSet) { |
| 1045 | assert(GCNRegPressure::getRegKind(Reg, MRI) == Kind); |
| 1046 | LiveInterval &LI = LIS.getInterval(Reg); |
| 1047 | if (LI.getNumValNums() == 1 || |
| 1048 | (LI.hasSubRanges() && |
| 1049 | llvm::all_of(LI.subranges(), [](const LiveInterval::SubRange &SR) { |
| 1050 | return SR.getNumValNums() == 1; |
| 1051 | }))) { |
| 1052 | SDefPressure.inc(Reg, LaneBitmask::getNone(), LaneMask, MRI); |
| 1053 | SDefRegs.push_back(Reg); |
| 1054 | } else { |
| 1055 | MDefPressure.inc(Reg, LaneBitmask::getNone(), LaneMask, MRI); |
| 1056 | MDefRegs.push_back(Reg); |
| 1057 | } |
| 1058 | } |
| 1059 | unsigned SDefNumRegs = SDefPressure.getNumRegs(Kind); |
| 1060 | unsigned MDefNumRegs = MDefPressure.getNumRegs(Kind); |
| 1061 | assert(SDefNumRegs + MDefNumRegs == MaxNumRegs); |
| 1062 | |
| 1063 | auto printLoc = [&](const MachineBasicBlock *MBB, SlotIndex SI) { |
| 1064 | return Printable([&, MBB, SI](raw_ostream &OS) { |
| 1065 | OS << SI << ':' << printMBBReference(*MBB); |
| 1066 | if (MLI) |
| 1067 | if (const MachineLoop *ML = MLI->getLoopFor(MBB)) |
| 1068 | OS << " (LoopHdr " << printMBBReference(*ML->getHeader()) |
| 1069 | << ", Depth " << ML->getLoopDepth() << ")" ; |
| 1070 | }); |
| 1071 | }; |
| 1072 | |
| 1073 | auto PrintRegInfo = [&](Register Reg, LaneBitmask LiveMask) { |
| 1074 | GCNRegPressure RegPressure; |
| 1075 | RegPressure.inc(Reg, LaneBitmask::getNone(), LiveMask, MRI); |
| 1076 | OS << " " << printReg(Reg, TRI) << ':' |
| 1077 | << TRI->getRegClassName(MRI.getRegClass(Reg)) << ", LiveMask " |
| 1078 | << PrintLaneMask(LiveMask) << " (" << RegPressure.getNumRegs(Kind) << ' ' |
| 1079 | << RegName << "s)\n" ; |
| 1080 | |
| 1081 | // Use std::map to sort def/uses by SlotIndex. |
| 1082 | std::map<SlotIndex, const MachineInstr *> Instrs; |
| 1083 | for (const MachineInstr &MI : MRI.reg_nodbg_instructions(Reg)) { |
| 1084 | Instrs[LIS.getInstructionIndex(MI).getRegSlot()] = &MI; |
| 1085 | } |
| 1086 | |
| 1087 | for (const auto &[SI, MI] : Instrs) { |
| 1088 | OS << " " ; |
| 1089 | if (MI->definesRegister(Reg, TRI)) |
| 1090 | OS << "def " ; |
| 1091 | if (MI->readsRegister(Reg, TRI)) |
| 1092 | OS << "use " ; |
| 1093 | OS << printLoc(MI->getParent(), SI) << ": " << *MI; |
| 1094 | } |
| 1095 | }; |
| 1096 | |
| 1097 | OS << "\n*** Register pressure info (" << RegName << "s) for " << MF.getName() |
| 1098 | << " ***\n" ; |
| 1099 | OS << "Max pressure is " << MaxNumRegs << ' ' << RegName << "s at " |
| 1100 | << printLoc(MaxPressureMI->getParent(), MaxPressureSlot) << ": " |
| 1101 | << *MaxPressureMI; |
| 1102 | |
| 1103 | OS << "\nLive registers with single definition (" << SDefNumRegs << ' ' |
| 1104 | << RegName << "s):\n" ; |
| 1105 | |
| 1106 | // Sort SDefRegs by number of uses (smallest first) |
| 1107 | llvm::sort(SDefRegs, [&](Register A, Register B) { |
| 1108 | return std::distance(MRI.use_nodbg_begin(A), MRI.use_nodbg_end()) < |
| 1109 | std::distance(MRI.use_nodbg_begin(B), MRI.use_nodbg_end()); |
| 1110 | }); |
| 1111 | |
| 1112 | for (const Register Reg : SDefRegs) { |
| 1113 | PrintRegInfo(Reg, LiveSet->lookup(Reg)); |
| 1114 | } |
| 1115 | |
| 1116 | OS << "\nLive registers with multiple definitions (" << MDefNumRegs << ' ' |
| 1117 | << RegName << "s):\n" ; |
| 1118 | for (const Register Reg : MDefRegs) { |
| 1119 | PrintRegInfo(Reg, LiveSet->lookup(Reg)); |
| 1120 | } |
| 1121 | } |
| 1122 | #endif |
| 1123 | |