| 1 | //===- GCNRegPressure.cpp -------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | /// |
| 9 | /// \file |
| 10 | /// This file implements the GCNRegPressure class. |
| 11 | /// |
| 12 | //===----------------------------------------------------------------------===// |
| 13 | |
| 14 | #include "GCNRegPressure.h" |
| 15 | #include "AMDGPU.h" |
| 16 | #include "SIMachineFunctionInfo.h" |
| 17 | #include "llvm/CodeGen/RegisterPressure.h" |
| 18 | |
| 19 | using namespace llvm; |
| 20 | |
| 21 | #define DEBUG_TYPE "machine-scheduler" |
| 22 | |
| 23 | bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, |
| 24 | const GCNRPTracker::LiveRegSet &S2) { |
| 25 | if (S1.size() != S2.size()) |
| 26 | return false; |
| 27 | |
| 28 | for (const auto &P : S1) { |
| 29 | auto I = S2.find(Val: P.first); |
| 30 | if (I == S2.end() || I->second != P.second) |
| 31 | return false; |
| 32 | } |
| 33 | return true; |
| 34 | } |
| 35 | |
| 36 | /////////////////////////////////////////////////////////////////////////////// |
| 37 | // GCNRegPressure |
| 38 | |
| 39 | unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC, |
| 40 | const SIRegisterInfo *STI) { |
| 41 | return STI->isSGPRClass(RC) ? SGPR : (STI->isAGPRClass(RC) ? AGPR : VGPR); |
| 42 | } |
| 43 | |
| 44 | void GCNRegPressure::inc(unsigned Reg, |
| 45 | LaneBitmask PrevMask, |
| 46 | LaneBitmask NewMask, |
| 47 | const MachineRegisterInfo &MRI) { |
| 48 | unsigned NewNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(LM: NewMask); |
| 49 | unsigned PrevNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(LM: PrevMask); |
| 50 | if (NewNumCoveredRegs == PrevNumCoveredRegs) |
| 51 | return; |
| 52 | |
| 53 | int Sign = 1; |
| 54 | if (NewMask < PrevMask) { |
| 55 | std::swap(a&: NewMask, b&: PrevMask); |
| 56 | std::swap(a&: NewNumCoveredRegs, b&: PrevNumCoveredRegs); |
| 57 | Sign = -1; |
| 58 | } |
| 59 | assert(PrevMask < NewMask && PrevNumCoveredRegs < NewNumCoveredRegs && |
| 60 | "prev mask should always be lesser than new" ); |
| 61 | |
| 62 | const TargetRegisterClass *RC = MRI.getRegClass(Reg); |
| 63 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 64 | const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI); |
| 65 | unsigned RegKind = getRegKind(RC, STI); |
| 66 | if (TRI->getRegSizeInBits(RC: *RC) != 32) { |
| 67 | // Reg is from a tuple register class. |
| 68 | if (PrevMask.none()) { |
| 69 | unsigned TupleIdx = TOTAL_KINDS + RegKind; |
| 70 | Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight; |
| 71 | } |
| 72 | // Pressure scales with number of new registers covered by the new mask. |
| 73 | // Note when true16 is enabled, we can no longer safely use the following |
| 74 | // approach to calculate the difference in the number of 32-bit registers |
| 75 | // between two masks: |
| 76 | // |
| 77 | // Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask); |
| 78 | // |
| 79 | // The issue is that the mask calculation `~PrevMask & NewMask` doesn't |
| 80 | // properly account for partial usage of a 32-bit register when dealing with |
| 81 | // 16-bit registers. |
| 82 | // |
| 83 | // Consider this example: |
| 84 | // Assume PrevMask = 0b0010 and NewMask = 0b1111. Here, the correct register |
| 85 | // usage difference should be 1, because even though PrevMask uses only half |
| 86 | // of a 32-bit register, it should still be counted as a full register use. |
| 87 | // However, the mask calculation yields `~PrevMask & NewMask = 0b1101`, and |
| 88 | // calling `getNumCoveredRegs` returns 2 instead of 1. This incorrect |
| 89 | // calculation can lead to integer overflow when Sign = -1. |
| 90 | Sign *= NewNumCoveredRegs - PrevNumCoveredRegs; |
| 91 | } |
| 92 | Value[RegKind] += Sign; |
| 93 | } |
| 94 | |
| 95 | bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O, |
| 96 | unsigned MaxOccupancy) const { |
| 97 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 98 | unsigned DynamicVGPRBlockSize = |
| 99 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 100 | |
| 101 | const auto SGPROcc = std::min(a: MaxOccupancy, |
| 102 | b: ST.getOccupancyWithNumSGPRs(SGPRs: getSGPRNum())); |
| 103 | const auto VGPROcc = std::min( |
| 104 | a: MaxOccupancy, b: ST.getOccupancyWithNumVGPRs(VGPRs: getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()), |
| 105 | DynamicVGPRBlockSize)); |
| 106 | const auto OtherSGPROcc = std::min(a: MaxOccupancy, |
| 107 | b: ST.getOccupancyWithNumSGPRs(SGPRs: O.getSGPRNum())); |
| 108 | const auto OtherVGPROcc = |
| 109 | std::min(a: MaxOccupancy, |
| 110 | b: ST.getOccupancyWithNumVGPRs(VGPRs: O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()), |
| 111 | DynamicVGPRBlockSize)); |
| 112 | |
| 113 | const auto Occ = std::min(a: SGPROcc, b: VGPROcc); |
| 114 | const auto OtherOcc = std::min(a: OtherSGPROcc, b: OtherVGPROcc); |
| 115 | |
| 116 | // Give first precedence to the better occupancy. |
| 117 | if (Occ != OtherOcc) |
| 118 | return Occ > OtherOcc; |
| 119 | |
| 120 | unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF); |
| 121 | unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF); |
| 122 | |
| 123 | // SGPR excess pressure conditions |
| 124 | unsigned ExcessSGPR = std::max(a: static_cast<int>(getSGPRNum() - MaxSGPRs), b: 0); |
| 125 | unsigned OtherExcessSGPR = |
| 126 | std::max(a: static_cast<int>(O.getSGPRNum() - MaxSGPRs), b: 0); |
| 127 | |
| 128 | auto WaveSize = ST.getWavefrontSize(); |
| 129 | // The number of virtual VGPRs required to handle excess SGPR |
| 130 | unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize; |
| 131 | unsigned OtherVGPRForSGPRSpills = |
| 132 | (OtherExcessSGPR + (WaveSize - 1)) / WaveSize; |
| 133 | |
| 134 | unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); |
| 135 | |
| 136 | // Unified excess pressure conditions, accounting for VGPRs used for SGPR |
| 137 | // spills |
| 138 | unsigned ExcessVGPR = |
| 139 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) + |
| 140 | VGPRForSGPRSpills - MaxVGPRs), |
| 141 | b: 0); |
| 142 | unsigned OtherExcessVGPR = |
| 143 | std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) + |
| 144 | OtherVGPRForSGPRSpills - MaxVGPRs), |
| 145 | b: 0); |
| 146 | // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR |
| 147 | // spills |
| 148 | unsigned ExcessArchVGPR = std::max( |
| 149 | a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: false) + VGPRForSGPRSpills - MaxArchVGPRs), |
| 150 | b: 0); |
| 151 | unsigned OtherExcessArchVGPR = |
| 152 | std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: false) + OtherVGPRForSGPRSpills - |
| 153 | MaxArchVGPRs), |
| 154 | b: 0); |
| 155 | // AGPR excess pressure conditions |
| 156 | unsigned ExcessAGPR = std::max( |
| 157 | a: static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs) |
| 158 | : (getAGPRNum() - MaxVGPRs)), |
| 159 | b: 0); |
| 160 | unsigned OtherExcessAGPR = std::max( |
| 161 | a: static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs) |
| 162 | : (O.getAGPRNum() - MaxVGPRs)), |
| 163 | b: 0); |
| 164 | |
| 165 | bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR; |
| 166 | bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR || |
| 167 | OtherExcessArchVGPR || OtherExcessAGPR; |
| 168 | |
| 169 | // Give second precedence to the reduced number of spills to hold the register |
| 170 | // pressure. |
| 171 | if (ExcessRP || OtherExcessRP) { |
| 172 | // The difference in excess VGPR pressure, after including VGPRs used for |
| 173 | // SGPR spills |
| 174 | int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) - |
| 175 | (ExcessVGPR + ExcessArchVGPR + ExcessAGPR)); |
| 176 | |
| 177 | int SGPRDiff = OtherExcessSGPR - ExcessSGPR; |
| 178 | |
| 179 | if (VGPRDiff != 0) |
| 180 | return VGPRDiff > 0; |
| 181 | if (SGPRDiff != 0) { |
| 182 | unsigned PureExcessVGPR = |
| 183 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs), |
| 184 | b: 0) + |
| 185 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0); |
| 186 | unsigned OtherPureExcessVGPR = |
| 187 | std::max( |
| 188 | a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs), |
| 189 | b: 0) + |
| 190 | std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0); |
| 191 | |
| 192 | // If we have a special case where there is a tie in excess VGPR, but one |
| 193 | // of the pressures has VGPR usage from SGPR spills, prefer the pressure |
| 194 | // with SGPR spills. |
| 195 | if (PureExcessVGPR != OtherPureExcessVGPR) |
| 196 | return SGPRDiff < 0; |
| 197 | // If both pressures have the same excess pressure before and after |
| 198 | // accounting for SGPR spills, prefer fewer SGPR spills. |
| 199 | return SGPRDiff > 0; |
| 200 | } |
| 201 | } |
| 202 | |
| 203 | bool SGPRImportant = SGPROcc < VGPROcc; |
| 204 | const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; |
| 205 | |
| 206 | // If both pressures disagree on what is more important compare vgprs. |
| 207 | if (SGPRImportant != OtherSGPRImportant) { |
| 208 | SGPRImportant = false; |
| 209 | } |
| 210 | |
| 211 | // Give third precedence to lower register tuple pressure. |
| 212 | bool SGPRFirst = SGPRImportant; |
| 213 | for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { |
| 214 | if (SGPRFirst) { |
| 215 | auto SW = getSGPRTuplesWeight(); |
| 216 | auto OtherSW = O.getSGPRTuplesWeight(); |
| 217 | if (SW != OtherSW) |
| 218 | return SW < OtherSW; |
| 219 | } else { |
| 220 | auto VW = getVGPRTuplesWeight(); |
| 221 | auto OtherVW = O.getVGPRTuplesWeight(); |
| 222 | if (VW != OtherVW) |
| 223 | return VW < OtherVW; |
| 224 | } |
| 225 | } |
| 226 | |
| 227 | // Give final precedence to lower general RP. |
| 228 | return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): |
| 229 | (getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) < |
| 230 | O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts())); |
| 231 | } |
| 232 | |
| 233 | Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST, |
| 234 | unsigned DynamicVGPRBlockSize) { |
| 235 | return Printable([&RP, ST, DynamicVGPRBlockSize](raw_ostream &OS) { |
| 236 | OS << "VGPRs: " << RP.getArchVGPRNum() << ' ' |
| 237 | << "AGPRs: " << RP.getAGPRNum(); |
| 238 | if (ST) |
| 239 | OS << "(O" |
| 240 | << ST->getOccupancyWithNumVGPRs(VGPRs: RP.getVGPRNum(UnifiedVGPRFile: ST->hasGFX90AInsts()), |
| 241 | DynamicVGPRBlockSize) |
| 242 | << ')'; |
| 243 | OS << ", SGPRs: " << RP.getSGPRNum(); |
| 244 | if (ST) |
| 245 | OS << "(O" << ST->getOccupancyWithNumSGPRs(SGPRs: RP.getSGPRNum()) << ')'; |
| 246 | OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() |
| 247 | << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); |
| 248 | if (ST) |
| 249 | OS << " -> Occ: " << RP.getOccupancy(ST: *ST, DynamicVGPRBlockSize); |
| 250 | OS << '\n'; |
| 251 | }); |
| 252 | } |
| 253 | |
| 254 | static LaneBitmask getDefRegMask(const MachineOperand &MO, |
| 255 | const MachineRegisterInfo &MRI) { |
| 256 | assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); |
| 257 | |
| 258 | // We don't rely on read-undef flag because in case of tentative schedule |
| 259 | // tracking it isn't set correctly yet. This works correctly however since |
| 260 | // use mask has been tracked before using LIS. |
| 261 | return MO.getSubReg() == 0 ? |
| 262 | MRI.getMaxLaneMaskForVReg(Reg: MO.getReg()) : |
| 263 | MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubIdx: MO.getSubReg()); |
| 264 | } |
| 265 | |
| 266 | static void |
| 267 | collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits, |
| 268 | const MachineInstr &MI, const LiveIntervals &LIS, |
| 269 | const MachineRegisterInfo &MRI) { |
| 270 | |
| 271 | auto &TRI = *MRI.getTargetRegisterInfo(); |
| 272 | for (const auto &MO : MI.operands()) { |
| 273 | if (!MO.isReg() || !MO.getReg().isVirtual()) |
| 274 | continue; |
| 275 | if (!MO.isUse() || !MO.readsReg()) |
| 276 | continue; |
| 277 | |
| 278 | Register Reg = MO.getReg(); |
| 279 | auto I = llvm::find_if(Range&: VRegMaskOrUnits, P: [Reg](const VRegMaskOrUnit &RM) { |
| 280 | return RM.RegUnit == Reg; |
| 281 | }); |
| 282 | |
| 283 | auto &P = I == VRegMaskOrUnits.end() |
| 284 | ? VRegMaskOrUnits.emplace_back(Args&: Reg, Args: LaneBitmask::getNone()) |
| 285 | : *I; |
| 286 | |
| 287 | P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(SubIdx: MO.getSubReg()) |
| 288 | : MRI.getMaxLaneMaskForVReg(Reg); |
| 289 | } |
| 290 | |
| 291 | SlotIndex InstrSI; |
| 292 | for (auto &P : VRegMaskOrUnits) { |
| 293 | auto &LI = LIS.getInterval(Reg: P.RegUnit); |
| 294 | if (!LI.hasSubRanges()) |
| 295 | continue; |
| 296 | |
| 297 | // For a tentative schedule LIS isn't updated yet but livemask should |
| 298 | // remain the same on any schedule. Subreg defs can be reordered but they |
| 299 | // all must dominate uses anyway. |
| 300 | if (!InstrSI) |
| 301 | InstrSI = LIS.getInstructionIndex(Instr: MI).getBaseIndex(); |
| 302 | |
| 303 | P.LaneMask = getLiveLaneMask(LI, SI: InstrSI, MRI, LaneMaskFilter: P.LaneMask); |
| 304 | } |
| 305 | } |
| 306 | |
| 307 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 308 | static LaneBitmask getLanesWithProperty( |
| 309 | const LiveIntervals &LIS, const MachineRegisterInfo &MRI, |
| 310 | bool TrackLaneMasks, Register RegUnit, SlotIndex Pos, |
| 311 | LaneBitmask SafeDefault, |
| 312 | function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) { |
| 313 | if (RegUnit.isVirtual()) { |
| 314 | const LiveInterval &LI = LIS.getInterval(Reg: RegUnit); |
| 315 | LaneBitmask Result; |
| 316 | if (TrackLaneMasks && LI.hasSubRanges()) { |
| 317 | for (const LiveInterval::SubRange &SR : LI.subranges()) { |
| 318 | if (Property(SR, Pos)) |
| 319 | Result |= SR.LaneMask; |
| 320 | } |
| 321 | } else if (Property(LI, Pos)) { |
| 322 | Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(Reg: RegUnit) |
| 323 | : LaneBitmask::getAll(); |
| 324 | } |
| 325 | |
| 326 | return Result; |
| 327 | } |
| 328 | |
| 329 | const LiveRange *LR = LIS.getCachedRegUnit(Unit: RegUnit); |
| 330 | if (LR == nullptr) |
| 331 | return SafeDefault; |
| 332 | return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone(); |
| 333 | } |
| 334 | |
| 335 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 336 | /// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}. |
| 337 | /// The query starts with a lane bitmask which gets lanes/bits removed for every |
| 338 | /// use we find. |
| 339 | static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask, |
| 340 | SlotIndex PriorUseIdx, SlotIndex NextUseIdx, |
| 341 | const MachineRegisterInfo &MRI, |
| 342 | const SIRegisterInfo *TRI, |
| 343 | const LiveIntervals *LIS, |
| 344 | bool Upward = false) { |
| 345 | for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) { |
| 346 | if (MO.isUndef()) |
| 347 | continue; |
| 348 | const MachineInstr *MI = MO.getParent(); |
| 349 | SlotIndex InstSlot = LIS->getInstructionIndex(Instr: *MI).getRegSlot(); |
| 350 | bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx) |
| 351 | : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx); |
| 352 | if (!InRange) |
| 353 | continue; |
| 354 | |
| 355 | unsigned SubRegIdx = MO.getSubReg(); |
| 356 | LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubIdx: SubRegIdx); |
| 357 | LastUseMask &= ~UseMask; |
| 358 | if (LastUseMask.none()) |
| 359 | return LaneBitmask::getNone(); |
| 360 | } |
| 361 | return LastUseMask; |
| 362 | } |
| 363 | |
| 364 | //////////////////////////////////////////////////////////////////////////////// |
| 365 | // GCNRPTarget |
| 366 | |
| 367 | GCNRPTarget::GCNRPTarget(const MachineFunction &MF, const GCNRegPressure &RP, |
| 368 | bool CombineVGPRSavings) |
| 369 | : RP(RP), CombineVGPRSavings(CombineVGPRSavings) { |
| 370 | const Function &F = MF.getFunction(); |
| 371 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 372 | setRegLimits(MaxSGPRs: ST.getMaxNumSGPRs(F), MaxVGPRs: ST.getMaxNumVGPRs(F), MF); |
| 373 | } |
| 374 | |
| 375 | GCNRPTarget::GCNRPTarget(unsigned NumSGPRs, unsigned NumVGPRs, |
| 376 | const MachineFunction &MF, const GCNRegPressure &RP, |
| 377 | bool CombineVGPRSavings) |
| 378 | : RP(RP), CombineVGPRSavings(CombineVGPRSavings) { |
| 379 | setRegLimits(MaxSGPRs: NumSGPRs, MaxVGPRs: NumVGPRs, MF); |
| 380 | } |
| 381 | |
| 382 | GCNRPTarget::GCNRPTarget(unsigned Occupancy, const MachineFunction &MF, |
| 383 | const GCNRegPressure &RP, bool CombineVGPRSavings) |
| 384 | : RP(RP), CombineVGPRSavings(CombineVGPRSavings) { |
| 385 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 386 | unsigned DynamicVGPRBlockSize = |
| 387 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 388 | setRegLimits(MaxSGPRs: ST.getMaxNumSGPRs(WavesPerEU: Occupancy, /*Addressable=*/false), |
| 389 | MaxVGPRs: ST.getMaxNumVGPRs(WavesPerEU: Occupancy, DynamicVGPRBlockSize), MF); |
| 390 | } |
| 391 | |
| 392 | void GCNRPTarget::setRegLimits(unsigned NumSGPRs, unsigned NumVGPRs, |
| 393 | const MachineFunction &MF) { |
| 394 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
| 395 | unsigned DynamicVGPRBlockSize = |
| 396 | MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize(); |
| 397 | MaxSGPRs = std::min(a: ST.getAddressableNumSGPRs(), b: NumSGPRs); |
| 398 | MaxVGPRs = std::min(a: ST.getAddressableNumArchVGPRs(), b: NumVGPRs); |
| 399 | MaxUnifiedVGPRs = |
| 400 | ST.hasGFX90AInsts() |
| 401 | ? std::min(a: ST.getAddressableNumVGPRs(DynamicVGPRBlockSize), b: NumVGPRs) |
| 402 | : 0; |
| 403 | } |
| 404 | |
| 405 | bool GCNRPTarget::isSaveBeneficial(Register Reg, |
| 406 | const MachineRegisterInfo &MRI) const { |
| 407 | const TargetRegisterClass *RC = MRI.getRegClass(Reg); |
| 408 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 409 | const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI); |
| 410 | |
| 411 | if (SRI->isSGPRClass(RC)) |
| 412 | return RP.getSGPRNum() > MaxSGPRs; |
| 413 | unsigned NumVGPRs = |
| 414 | SRI->isAGPRClass(RC) ? RP.getAGPRNum() : RP.getArchVGPRNum(); |
| 415 | return isVGPRBankSaveBeneficial(NumVGPRs); |
| 416 | } |
| 417 | |
| 418 | bool GCNRPTarget::satisfied() const { |
| 419 | if (RP.getSGPRNum() > MaxSGPRs) |
| 420 | return false; |
| 421 | if (RP.getVGPRNum(UnifiedVGPRFile: false) > MaxVGPRs && |
| 422 | (!CombineVGPRSavings || !satisifiesVGPRBanksTarget())) |
| 423 | return false; |
| 424 | return satisfiesUnifiedTarget(); |
| 425 | } |
| 426 | |
| 427 | /////////////////////////////////////////////////////////////////////////////// |
| 428 | // GCNRPTracker |
| 429 | |
| 430 | LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, |
| 431 | const LiveIntervals &LIS, |
| 432 | const MachineRegisterInfo &MRI, |
| 433 | LaneBitmask LaneMaskFilter) { |
| 434 | return getLiveLaneMask(LI: LIS.getInterval(Reg), SI, MRI, LaneMaskFilter); |
| 435 | } |
| 436 | |
| 437 | LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, |
| 438 | const MachineRegisterInfo &MRI, |
| 439 | LaneBitmask LaneMaskFilter) { |
| 440 | LaneBitmask LiveMask; |
| 441 | if (LI.hasSubRanges()) { |
| 442 | for (const auto &S : LI.subranges()) |
| 443 | if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(index: SI)) { |
| 444 | LiveMask |= S.LaneMask; |
| 445 | assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); |
| 446 | } |
| 447 | } else if (LI.liveAt(index: SI)) { |
| 448 | LiveMask = MRI.getMaxLaneMaskForVReg(Reg: LI.reg()); |
| 449 | } |
| 450 | LiveMask &= LaneMaskFilter; |
| 451 | return LiveMask; |
| 452 | } |
| 453 | |
| 454 | GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, |
| 455 | const LiveIntervals &LIS, |
| 456 | const MachineRegisterInfo &MRI) { |
| 457 | GCNRPTracker::LiveRegSet LiveRegs; |
| 458 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
| 459 | auto Reg = Register::index2VirtReg(Index: I); |
| 460 | if (!LIS.hasInterval(Reg)) |
| 461 | continue; |
| 462 | auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); |
| 463 | if (LiveMask.any()) |
| 464 | LiveRegs[Reg] = LiveMask; |
| 465 | } |
| 466 | return LiveRegs; |
| 467 | } |
| 468 | |
| 469 | void GCNRPTracker::reset(const MachineInstr &MI, |
| 470 | const LiveRegSet *LiveRegsCopy, |
| 471 | bool After) { |
| 472 | const MachineFunction &MF = *MI.getMF(); |
| 473 | MRI = &MF.getRegInfo(); |
| 474 | if (LiveRegsCopy) { |
| 475 | if (&LiveRegs != LiveRegsCopy) |
| 476 | LiveRegs = *LiveRegsCopy; |
| 477 | } else { |
| 478 | LiveRegs = After ? getLiveRegsAfter(MI, LIS) |
| 479 | : getLiveRegsBefore(MI, LIS); |
| 480 | } |
| 481 | |
| 482 | MaxPressure = CurPressure = getRegPressure(MRI: *MRI, LiveRegs); |
| 483 | } |
| 484 | |
| 485 | void GCNRPTracker::reset(const MachineRegisterInfo &MRI_, |
| 486 | const LiveRegSet &LiveRegs_) { |
| 487 | MRI = &MRI_; |
| 488 | LiveRegs = LiveRegs_; |
| 489 | LastTrackedMI = nullptr; |
| 490 | MaxPressure = CurPressure = getRegPressure(MRI: MRI_, LiveRegs: LiveRegs_); |
| 491 | } |
| 492 | |
| 493 | /// Mostly copy/paste from CodeGen/RegisterPressure.cpp |
| 494 | LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit, |
| 495 | SlotIndex Pos) const { |
| 496 | return getLanesWithProperty( |
| 497 | LIS, MRI: *MRI, TrackLaneMasks: true, RegUnit, Pos: Pos.getBaseIndex(), SafeDefault: LaneBitmask::getNone(), |
| 498 | Property: [](const LiveRange &LR, SlotIndex Pos) { |
| 499 | const LiveRange::Segment *S = LR.getSegmentContaining(Idx: Pos); |
| 500 | return S != nullptr && S->end == Pos.getRegSlot(); |
| 501 | }); |
| 502 | } |
| 503 | |
| 504 | //////////////////////////////////////////////////////////////////////////////// |
| 505 | // GCNUpwardRPTracker |
| 506 | |
| 507 | void GCNUpwardRPTracker::recede(const MachineInstr &MI) { |
| 508 | assert(MRI && "call reset first" ); |
| 509 | |
| 510 | LastTrackedMI = &MI; |
| 511 | |
| 512 | if (MI.isDebugInstr()) |
| 513 | return; |
| 514 | |
| 515 | // Kill all defs. |
| 516 | GCNRegPressure DefPressure, ECDefPressure; |
| 517 | bool HasECDefs = false; |
| 518 | for (const MachineOperand &MO : MI.all_defs()) { |
| 519 | if (!MO.getReg().isVirtual()) |
| 520 | continue; |
| 521 | |
| 522 | Register Reg = MO.getReg(); |
| 523 | LaneBitmask DefMask = getDefRegMask(MO, MRI: *MRI); |
| 524 | |
| 525 | // Treat a def as fully live at the moment of definition: keep a record. |
| 526 | if (MO.isEarlyClobber()) { |
| 527 | ECDefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI); |
| 528 | HasECDefs = true; |
| 529 | } else |
| 530 | DefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI); |
| 531 | |
| 532 | auto I = LiveRegs.find(Val: Reg); |
| 533 | if (I == LiveRegs.end()) |
| 534 | continue; |
| 535 | |
| 536 | LaneBitmask &LiveMask = I->second; |
| 537 | LaneBitmask PrevMask = LiveMask; |
| 538 | LiveMask &= ~DefMask; |
| 539 | CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI); |
| 540 | if (LiveMask.none()) |
| 541 | LiveRegs.erase(I); |
| 542 | } |
| 543 | |
| 544 | // Update MaxPressure with defs pressure. |
| 545 | DefPressure += CurPressure; |
| 546 | if (HasECDefs) |
| 547 | DefPressure += ECDefPressure; |
| 548 | MaxPressure = max(P1: DefPressure, P2: MaxPressure); |
| 549 | |
| 550 | // Make uses alive. |
| 551 | SmallVector<VRegMaskOrUnit, 8> RegUses; |
| 552 | collectVirtualRegUses(VRegMaskOrUnits&: RegUses, MI, LIS, MRI: *MRI); |
| 553 | for (const VRegMaskOrUnit &U : RegUses) { |
| 554 | LaneBitmask &LiveMask = LiveRegs[U.RegUnit]; |
| 555 | LaneBitmask PrevMask = LiveMask; |
| 556 | LiveMask |= U.LaneMask; |
| 557 | CurPressure.inc(Reg: U.RegUnit, PrevMask, NewMask: LiveMask, MRI: *MRI); |
| 558 | } |
| 559 | |
| 560 | // Update MaxPressure with uses plus early-clobber defs pressure. |
| 561 | MaxPressure = HasECDefs ? max(P1: CurPressure + ECDefPressure, P2: MaxPressure) |
| 562 | : max(P1: CurPressure, P2: MaxPressure); |
| 563 | |
| 564 | assert(CurPressure == getRegPressure(*MRI, LiveRegs)); |
| 565 | } |
| 566 | |
| 567 | //////////////////////////////////////////////////////////////////////////////// |
| 568 | // GCNDownwardRPTracker |
| 569 | |
| 570 | bool GCNDownwardRPTracker::reset(const MachineInstr &MI, |
| 571 | const LiveRegSet *LiveRegsCopy) { |
| 572 | MRI = &MI.getParent()->getParent()->getRegInfo(); |
| 573 | LastTrackedMI = nullptr; |
| 574 | MBBEnd = MI.getParent()->end(); |
| 575 | NextMI = &MI; |
| 576 | NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd); |
| 577 | if (NextMI == MBBEnd) |
| 578 | return false; |
| 579 | GCNRPTracker::reset(MI: *NextMI, LiveRegsCopy, After: false); |
| 580 | return true; |
| 581 | } |
| 582 | |
| 583 | bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI, |
| 584 | bool UseInternalIterator) { |
| 585 | assert(MRI && "call reset first" ); |
| 586 | SlotIndex SI; |
| 587 | const MachineInstr *CurrMI; |
| 588 | if (UseInternalIterator) { |
| 589 | if (!LastTrackedMI) |
| 590 | return NextMI == MBBEnd; |
| 591 | |
| 592 | assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); |
| 593 | CurrMI = LastTrackedMI; |
| 594 | |
| 595 | SI = NextMI == MBBEnd |
| 596 | ? LIS.getInstructionIndex(Instr: *LastTrackedMI).getDeadSlot() |
| 597 | : LIS.getInstructionIndex(Instr: *NextMI).getBaseIndex(); |
| 598 | } else { //! UseInternalIterator |
| 599 | SI = LIS.getInstructionIndex(Instr: *MI).getBaseIndex(); |
| 600 | CurrMI = MI; |
| 601 | } |
| 602 | |
| 603 | assert(SI.isValid()); |
| 604 | |
| 605 | // Remove dead registers or mask bits. |
| 606 | SmallSet<Register, 8> SeenRegs; |
| 607 | for (auto &MO : CurrMI->operands()) { |
| 608 | if (!MO.isReg() || !MO.getReg().isVirtual()) |
| 609 | continue; |
| 610 | if (MO.isUse() && !MO.readsReg()) |
| 611 | continue; |
| 612 | if (!UseInternalIterator && MO.isDef()) |
| 613 | continue; |
| 614 | if (!SeenRegs.insert(V: MO.getReg()).second) |
| 615 | continue; |
| 616 | const LiveInterval &LI = LIS.getInterval(Reg: MO.getReg()); |
| 617 | if (LI.hasSubRanges()) { |
| 618 | auto It = LiveRegs.end(); |
| 619 | for (const auto &S : LI.subranges()) { |
| 620 | if (!S.liveAt(index: SI)) { |
| 621 | if (It == LiveRegs.end()) { |
| 622 | It = LiveRegs.find(Val: MO.getReg()); |
| 623 | if (It == LiveRegs.end()) |
| 624 | llvm_unreachable("register isn't live" ); |
| 625 | } |
| 626 | auto PrevMask = It->second; |
| 627 | It->second &= ~S.LaneMask; |
| 628 | CurPressure.inc(Reg: MO.getReg(), PrevMask, NewMask: It->second, MRI: *MRI); |
| 629 | } |
| 630 | } |
| 631 | if (It != LiveRegs.end() && It->second.none()) |
| 632 | LiveRegs.erase(I: It); |
| 633 | } else if (!LI.liveAt(index: SI)) { |
| 634 | auto It = LiveRegs.find(Val: MO.getReg()); |
| 635 | if (It == LiveRegs.end()) |
| 636 | llvm_unreachable("register isn't live" ); |
| 637 | CurPressure.inc(Reg: MO.getReg(), PrevMask: It->second, NewMask: LaneBitmask::getNone(), MRI: *MRI); |
| 638 | LiveRegs.erase(I: It); |
| 639 | } |
| 640 | } |
| 641 | |
| 642 | MaxPressure = max(P1: MaxPressure, P2: CurPressure); |
| 643 | |
| 644 | LastTrackedMI = nullptr; |
| 645 | |
| 646 | return UseInternalIterator && (NextMI == MBBEnd); |
| 647 | } |
| 648 | |
| 649 | void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI, |
| 650 | bool UseInternalIterator) { |
| 651 | if (UseInternalIterator) { |
| 652 | LastTrackedMI = &*NextMI++; |
| 653 | NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd); |
| 654 | } else { |
| 655 | LastTrackedMI = MI; |
| 656 | } |
| 657 | |
| 658 | const MachineInstr *CurrMI = LastTrackedMI; |
| 659 | |
| 660 | // Add new registers or mask bits. |
| 661 | for (const auto &MO : CurrMI->all_defs()) { |
| 662 | Register Reg = MO.getReg(); |
| 663 | if (!Reg.isVirtual()) |
| 664 | continue; |
| 665 | auto &LiveMask = LiveRegs[Reg]; |
| 666 | auto PrevMask = LiveMask; |
| 667 | LiveMask |= getDefRegMask(MO, MRI: *MRI); |
| 668 | CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI); |
| 669 | } |
| 670 | |
| 671 | MaxPressure = max(P1: MaxPressure, P2: CurPressure); |
| 672 | } |
| 673 | |
| 674 | bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) { |
| 675 | if (UseInternalIterator && NextMI == MBBEnd) |
| 676 | return false; |
| 677 | |
| 678 | advanceBeforeNext(MI, UseInternalIterator); |
| 679 | advanceToNext(MI, UseInternalIterator); |
| 680 | if (!UseInternalIterator) { |
| 681 | // We must remove any dead def lanes from the current RP |
| 682 | advanceBeforeNext(MI, UseInternalIterator: true); |
| 683 | } |
| 684 | return true; |
| 685 | } |
| 686 | |
| 687 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { |
| 688 | while (NextMI != End) |
| 689 | if (!advance()) return false; |
| 690 | return true; |
| 691 | } |
| 692 | |
| 693 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, |
| 694 | MachineBasicBlock::const_iterator End, |
| 695 | const LiveRegSet *LiveRegsCopy) { |
| 696 | reset(MI: *Begin, LiveRegsCopy); |
| 697 | return advance(End); |
| 698 | } |
| 699 | |
| 700 | Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, |
| 701 | const GCNRPTracker::LiveRegSet &TrackedLR, |
| 702 | const TargetRegisterInfo *TRI, StringRef Pfx) { |
| 703 | return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { |
| 704 | for (auto const &P : TrackedLR) { |
| 705 | auto I = LISLR.find(Val: P.first); |
| 706 | if (I == LISLR.end()) { |
| 707 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
| 708 | << " isn't found in LIS reported set\n" ; |
| 709 | } else if (I->second != P.second) { |
| 710 | OS << Pfx << printReg(Reg: P.first, TRI) |
| 711 | << " masks doesn't match: LIS reported " << PrintLaneMask(LaneMask: I->second) |
| 712 | << ", tracked " << PrintLaneMask(LaneMask: P.second) << '\n'; |
| 713 | } |
| 714 | } |
| 715 | for (auto const &P : LISLR) { |
| 716 | auto I = TrackedLR.find(Val: P.first); |
| 717 | if (I == TrackedLR.end()) { |
| 718 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
| 719 | << " isn't found in tracked set\n" ; |
| 720 | } |
| 721 | } |
| 722 | }); |
| 723 | } |
| 724 | |
| 725 | GCNRegPressure |
| 726 | GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI, |
| 727 | const SIRegisterInfo *TRI) const { |
| 728 | assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction." ); |
| 729 | |
| 730 | SlotIndex SlotIdx; |
| 731 | SlotIdx = LIS.getInstructionIndex(Instr: *MI).getRegSlot(); |
| 732 | |
| 733 | // Account for register pressure similar to RegPressureTracker::recede(). |
| 734 | RegisterOperands RegOpers; |
| 735 | RegOpers.collect(MI: *MI, TRI: *TRI, MRI: *MRI, TrackLaneMasks: true, /*IgnoreDead=*/false); |
| 736 | RegOpers.adjustLaneLiveness(LIS, MRI: *MRI, Pos: SlotIdx); |
| 737 | GCNRegPressure TempPressure = CurPressure; |
| 738 | |
| 739 | for (const VRegMaskOrUnit &Use : RegOpers.Uses) { |
| 740 | Register Reg = Use.RegUnit; |
| 741 | if (!Reg.isVirtual()) |
| 742 | continue; |
| 743 | LaneBitmask LastUseMask = getLastUsedLanes(RegUnit: Reg, Pos: SlotIdx); |
| 744 | if (LastUseMask.none()) |
| 745 | continue; |
| 746 | // The LastUseMask is queried from the liveness information of instruction |
| 747 | // which may be further down the schedule. Some lanes may actually not be |
| 748 | // last uses for the current position. |
| 749 | // FIXME: allow the caller to pass in the list of vreg uses that remain |
| 750 | // to be bottom-scheduled to avoid searching uses at each query. |
| 751 | SlotIndex CurrIdx; |
| 752 | const MachineBasicBlock *MBB = MI->getParent(); |
| 753 | MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward( |
| 754 | It: LastTrackedMI ? LastTrackedMI : MBB->begin(), End: MBB->end()); |
| 755 | if (IdxPos == MBB->end()) { |
| 756 | CurrIdx = LIS.getMBBEndIdx(mbb: MBB); |
| 757 | } else { |
| 758 | CurrIdx = LIS.getInstructionIndex(Instr: *IdxPos).getRegSlot(); |
| 759 | } |
| 760 | |
| 761 | LastUseMask = |
| 762 | findUseBetween(Reg, LastUseMask, PriorUseIdx: CurrIdx, NextUseIdx: SlotIdx, MRI: *MRI, TRI, LIS: &LIS); |
| 763 | if (LastUseMask.none()) |
| 764 | continue; |
| 765 | |
| 766 | auto It = LiveRegs.find(Val: Reg); |
| 767 | LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); |
| 768 | LaneBitmask NewMask = LiveMask & ~LastUseMask; |
| 769 | TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI); |
| 770 | } |
| 771 | |
| 772 | // Generate liveness for defs. |
| 773 | for (const VRegMaskOrUnit &Def : RegOpers.Defs) { |
| 774 | Register Reg = Def.RegUnit; |
| 775 | if (!Reg.isVirtual()) |
| 776 | continue; |
| 777 | auto It = LiveRegs.find(Val: Reg); |
| 778 | LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0); |
| 779 | LaneBitmask NewMask = LiveMask | Def.LaneMask; |
| 780 | TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI); |
| 781 | } |
| 782 | |
| 783 | return TempPressure; |
| 784 | } |
| 785 | |
| 786 | bool GCNUpwardRPTracker::isValid() const { |
| 787 | const auto &SI = LIS.getInstructionIndex(Instr: *LastTrackedMI).getBaseIndex(); |
| 788 | const auto LISLR = llvm::getLiveRegs(SI, LIS, MRI: *MRI); |
| 789 | const auto &TrackedLR = LiveRegs; |
| 790 | |
| 791 | if (!isEqual(S1: LISLR, S2: TrackedLR)) { |
| 792 | dbgs() << "\nGCNUpwardRPTracker error: Tracked and" |
| 793 | " LIS reported livesets mismatch:\n" |
| 794 | << print(LiveRegs: LISLR, MRI: *MRI); |
| 795 | reportMismatch(LISLR, TrackedLR, TRI: MRI->getTargetRegisterInfo()); |
| 796 | return false; |
| 797 | } |
| 798 | |
| 799 | auto LISPressure = getRegPressure(MRI: *MRI, LiveRegs: LISLR); |
| 800 | if (LISPressure != CurPressure) { |
| 801 | dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " |
| 802 | << print(RP: CurPressure) << "LIS rpt: " << print(RP: LISPressure); |
| 803 | return false; |
| 804 | } |
| 805 | return true; |
| 806 | } |
| 807 | |
| 808 | Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, |
| 809 | const MachineRegisterInfo &MRI) { |
| 810 | return Printable([&LiveRegs, &MRI](raw_ostream &OS) { |
| 811 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 812 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
| 813 | Register Reg = Register::index2VirtReg(Index: I); |
| 814 | auto It = LiveRegs.find(Val: Reg); |
| 815 | if (It != LiveRegs.end() && It->second.any()) |
| 816 | OS << ' ' << printVRegOrUnit(VRegOrUnit: Reg, TRI) << ':' |
| 817 | << PrintLaneMask(LaneMask: It->second); |
| 818 | } |
| 819 | OS << '\n'; |
| 820 | }); |
| 821 | } |
| 822 | |
| 823 | void GCNRegPressure::dump() const { dbgs() << print(RP: *this); } |
| 824 | |
| 825 | static cl::opt<bool> UseDownwardTracker( |
| 826 | "amdgpu-print-rp-downward" , |
| 827 | cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass" ), |
| 828 | cl::init(Val: false), cl::Hidden); |
| 829 | |
| 830 | char llvm::GCNRegPressurePrinter::ID = 0; |
| 831 | char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID; |
| 832 | |
| 833 | INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp" , "" , true, true) |
| 834 | |
| 835 | // Return lanemask of Reg's subregs that are live-through at [Begin, End] and |
| 836 | // are fully covered by Mask. |
| 837 | static LaneBitmask |
| 838 | getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, |
| 839 | Register Reg, SlotIndex Begin, SlotIndex End, |
| 840 | LaneBitmask Mask = LaneBitmask::getAll()) { |
| 841 | |
| 842 | auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { |
| 843 | auto *Segment = LR.getSegmentContaining(Idx: Begin); |
| 844 | return Segment && Segment->contains(I: End); |
| 845 | }; |
| 846 | |
| 847 | LaneBitmask LiveThroughMask; |
| 848 | const LiveInterval &LI = LIS.getInterval(Reg); |
| 849 | if (LI.hasSubRanges()) { |
| 850 | for (auto &SR : LI.subranges()) { |
| 851 | if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) |
| 852 | LiveThroughMask |= SR.LaneMask; |
| 853 | } |
| 854 | } else { |
| 855 | LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); |
| 856 | if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) |
| 857 | LiveThroughMask = RegMask; |
| 858 | } |
| 859 | |
| 860 | return LiveThroughMask; |
| 861 | } |
| 862 | |
| 863 | bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) { |
| 864 | const MachineRegisterInfo &MRI = MF.getRegInfo(); |
| 865 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
| 866 | const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); |
| 867 | |
| 868 | auto &OS = dbgs(); |
| 869 | |
| 870 | // Leading spaces are important for YAML syntax. |
| 871 | #define PFX " " |
| 872 | |
| 873 | OS << "---\nname: " << MF.getName() << "\nbody: |\n" ; |
| 874 | |
| 875 | auto printRP = [](const GCNRegPressure &RP) { |
| 876 | return Printable([&RP](raw_ostream &OS) { |
| 877 | OS << format(PFX " %-5d" , Vals: RP.getSGPRNum()) |
| 878 | << format(Fmt: " %-5d" , Vals: RP.getVGPRNum(UnifiedVGPRFile: false)); |
| 879 | }); |
| 880 | }; |
| 881 | |
| 882 | auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR, |
| 883 | const GCNRPTracker::LiveRegSet &LISLR) { |
| 884 | if (LISLR != TrackedLR) { |
| 885 | OS << PFX " mis LIS: " << llvm::print(LiveRegs: LISLR, MRI) |
| 886 | << reportMismatch(LISLR, TrackedLR, TRI, PFX " " ); |
| 887 | } |
| 888 | }; |
| 889 | |
| 890 | // Register pressure before and at an instruction (in program order). |
| 891 | SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP; |
| 892 | |
| 893 | for (auto &MBB : MF) { |
| 894 | RP.clear(); |
| 895 | RP.reserve(N: MBB.size()); |
| 896 | |
| 897 | OS << PFX; |
| 898 | MBB.printName(os&: OS); |
| 899 | OS << ":\n" ; |
| 900 | |
| 901 | SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(mbb: &MBB); |
| 902 | SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(mbb: &MBB); |
| 903 | |
| 904 | GCNRPTracker::LiveRegSet LiveIn, LiveOut; |
| 905 | GCNRegPressure RPAtMBBEnd; |
| 906 | |
| 907 | if (UseDownwardTracker) { |
| 908 | if (MBB.empty()) { |
| 909 | LiveIn = LiveOut = getLiveRegs(SI: MBBStartSlot, LIS, MRI); |
| 910 | RPAtMBBEnd = getRegPressure(MRI, LiveRegs&: LiveIn); |
| 911 | } else { |
| 912 | GCNDownwardRPTracker RPT(LIS); |
| 913 | RPT.reset(MI: MBB.front()); |
| 914 | |
| 915 | LiveIn = RPT.getLiveRegs(); |
| 916 | |
| 917 | while (!RPT.advanceBeforeNext()) { |
| 918 | GCNRegPressure RPBeforeMI = RPT.getPressure(); |
| 919 | RPT.advanceToNext(); |
| 920 | RP.emplace_back(Args&: RPBeforeMI, Args: RPT.getPressure()); |
| 921 | } |
| 922 | |
| 923 | LiveOut = RPT.getLiveRegs(); |
| 924 | RPAtMBBEnd = RPT.getPressure(); |
| 925 | } |
| 926 | } else { |
| 927 | GCNUpwardRPTracker RPT(LIS); |
| 928 | RPT.reset(MRI, SI: MBBEndSlot); |
| 929 | |
| 930 | LiveOut = RPT.getLiveRegs(); |
| 931 | RPAtMBBEnd = RPT.getPressure(); |
| 932 | |
| 933 | for (auto &MI : reverse(C&: MBB)) { |
| 934 | RPT.resetMaxPressure(); |
| 935 | RPT.recede(MI); |
| 936 | if (!MI.isDebugInstr()) |
| 937 | RP.emplace_back(Args: RPT.getPressure(), Args: RPT.getMaxPressure()); |
| 938 | } |
| 939 | |
| 940 | LiveIn = RPT.getLiveRegs(); |
| 941 | } |
| 942 | |
| 943 | OS << PFX " Live-in: " << llvm::print(LiveRegs: LiveIn, MRI); |
| 944 | if (!UseDownwardTracker) |
| 945 | ReportLISMismatchIfAny(LiveIn, getLiveRegs(SI: MBBStartSlot, LIS, MRI)); |
| 946 | |
| 947 | OS << PFX " SGPR VGPR\n" ; |
| 948 | int I = 0; |
| 949 | for (auto &MI : MBB) { |
| 950 | if (!MI.isDebugInstr()) { |
| 951 | auto &[RPBeforeInstr, RPAtInstr] = |
| 952 | RP[UseDownwardTracker ? I : (RP.size() - 1 - I)]; |
| 953 | ++I; |
| 954 | OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " " ; |
| 955 | } else |
| 956 | OS << PFX " " ; |
| 957 | MI.print(OS); |
| 958 | } |
| 959 | OS << printRP(RPAtMBBEnd) << '\n'; |
| 960 | |
| 961 | OS << PFX " Live-out:" << llvm::print(LiveRegs: LiveOut, MRI); |
| 962 | if (UseDownwardTracker) |
| 963 | ReportLISMismatchIfAny(LiveOut, getLiveRegs(SI: MBBEndSlot, LIS, MRI)); |
| 964 | |
| 965 | GCNRPTracker::LiveRegSet LiveThrough; |
| 966 | for (auto [Reg, Mask] : LiveIn) { |
| 967 | LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Val: Reg); |
| 968 | if (MaskIntersection.any()) { |
| 969 | LaneBitmask LTMask = getRegLiveThroughMask( |
| 970 | MRI, LIS, Reg, Begin: MBBStartSlot, End: MBBEndSlot, Mask: MaskIntersection); |
| 971 | if (LTMask.any()) |
| 972 | LiveThrough[Reg] = LTMask; |
| 973 | } |
| 974 | } |
| 975 | OS << PFX " Live-thr:" << llvm::print(LiveRegs: LiveThrough, MRI); |
| 976 | OS << printRP(getRegPressure(MRI, LiveRegs&: LiveThrough)) << '\n'; |
| 977 | } |
| 978 | OS << "...\n" ; |
| 979 | return false; |
| 980 | |
| 981 | #undef PFX |
| 982 | } |
| 983 | |