1//===- GCNRegPressure.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the GCNRegPressure class.
11///
12//===----------------------------------------------------------------------===//
13
14#include "GCNRegPressure.h"
15#include "AMDGPU.h"
16#include "SIMachineFunctionInfo.h"
17#include "llvm/CodeGen/RegisterPressure.h"
18
19using namespace llvm;
20
21#define DEBUG_TYPE "machine-scheduler"
22
23bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
24 const GCNRPTracker::LiveRegSet &S2) {
25 if (S1.size() != S2.size())
26 return false;
27
28 for (const auto &P : S1) {
29 auto I = S2.find(Val: P.first);
30 if (I == S2.end() || I->second != P.second)
31 return false;
32 }
33 return true;
34}
35
36///////////////////////////////////////////////////////////////////////////////
37// GCNRegPressure
38
39unsigned GCNRegPressure::getRegKind(const TargetRegisterClass *RC,
40 const SIRegisterInfo *STI) {
41 return STI->isSGPRClass(RC) ? SGPR : (STI->isAGPRClass(RC) ? AGPR : VGPR);
42}
43
44void GCNRegPressure::inc(unsigned Reg,
45 LaneBitmask PrevMask,
46 LaneBitmask NewMask,
47 const MachineRegisterInfo &MRI) {
48 unsigned NewNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(LM: NewMask);
49 unsigned PrevNumCoveredRegs = SIRegisterInfo::getNumCoveredRegs(LM: PrevMask);
50 if (NewNumCoveredRegs == PrevNumCoveredRegs)
51 return;
52
53 int Sign = 1;
54 if (NewMask < PrevMask) {
55 std::swap(a&: NewMask, b&: PrevMask);
56 std::swap(a&: NewNumCoveredRegs, b&: PrevNumCoveredRegs);
57 Sign = -1;
58 }
59 assert(PrevMask < NewMask && PrevNumCoveredRegs < NewNumCoveredRegs &&
60 "prev mask should always be lesser than new");
61
62 const TargetRegisterClass *RC = MRI.getRegClass(Reg);
63 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
64 const SIRegisterInfo *STI = static_cast<const SIRegisterInfo *>(TRI);
65 unsigned RegKind = getRegKind(RC, STI);
66 if (TRI->getRegSizeInBits(RC: *RC) != 32) {
67 // Reg is from a tuple register class.
68 if (PrevMask.none()) {
69 unsigned TupleIdx = TOTAL_KINDS + RegKind;
70 Value[TupleIdx] += Sign * TRI->getRegClassWeight(RC).RegWeight;
71 }
72 // Pressure scales with number of new registers covered by the new mask.
73 // Note when true16 is enabled, we can no longer safely use the following
74 // approach to calculate the difference in the number of 32-bit registers
75 // between two masks:
76 //
77 // Sign *= SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
78 //
79 // The issue is that the mask calculation `~PrevMask & NewMask` doesn't
80 // properly account for partial usage of a 32-bit register when dealing with
81 // 16-bit registers.
82 //
83 // Consider this example:
84 // Assume PrevMask = 0b0010 and NewMask = 0b1111. Here, the correct register
85 // usage difference should be 1, because even though PrevMask uses only half
86 // of a 32-bit register, it should still be counted as a full register use.
87 // However, the mask calculation yields `~PrevMask & NewMask = 0b1101`, and
88 // calling `getNumCoveredRegs` returns 2 instead of 1. This incorrect
89 // calculation can lead to integer overflow when Sign = -1.
90 Sign *= NewNumCoveredRegs - PrevNumCoveredRegs;
91 }
92 Value[RegKind] += Sign;
93}
94
95bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O,
96 unsigned MaxOccupancy) const {
97 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
98 unsigned DynamicVGPRBlockSize =
99 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
100
101 const auto SGPROcc = std::min(a: MaxOccupancy,
102 b: ST.getOccupancyWithNumSGPRs(SGPRs: getSGPRNum()));
103 const auto VGPROcc = std::min(
104 a: MaxOccupancy, b: ST.getOccupancyWithNumVGPRs(VGPRs: getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()),
105 DynamicVGPRBlockSize));
106 const auto OtherSGPROcc = std::min(a: MaxOccupancy,
107 b: ST.getOccupancyWithNumSGPRs(SGPRs: O.getSGPRNum()));
108 const auto OtherVGPROcc =
109 std::min(a: MaxOccupancy,
110 b: ST.getOccupancyWithNumVGPRs(VGPRs: O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()),
111 DynamicVGPRBlockSize));
112
113 const auto Occ = std::min(a: SGPROcc, b: VGPROcc);
114 const auto OtherOcc = std::min(a: OtherSGPROcc, b: OtherVGPROcc);
115
116 // Give first precedence to the better occupancy.
117 if (Occ != OtherOcc)
118 return Occ > OtherOcc;
119
120 unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF);
121 unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF);
122
123 // SGPR excess pressure conditions
124 unsigned ExcessSGPR = std::max(a: static_cast<int>(getSGPRNum() - MaxSGPRs), b: 0);
125 unsigned OtherExcessSGPR =
126 std::max(a: static_cast<int>(O.getSGPRNum() - MaxSGPRs), b: 0);
127
128 auto WaveSize = ST.getWavefrontSize();
129 // The number of virtual VGPRs required to handle excess SGPR
130 unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize;
131 unsigned OtherVGPRForSGPRSpills =
132 (OtherExcessSGPR + (WaveSize - 1)) / WaveSize;
133
134 unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs();
135
136 // Unified excess pressure conditions, accounting for VGPRs used for SGPR
137 // spills
138 unsigned ExcessVGPR =
139 std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) +
140 VGPRForSGPRSpills - MaxVGPRs),
141 b: 0);
142 unsigned OtherExcessVGPR =
143 std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) +
144 OtherVGPRForSGPRSpills - MaxVGPRs),
145 b: 0);
146 // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR
147 // spills
148 unsigned ExcessArchVGPR = std::max(
149 a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: false) + VGPRForSGPRSpills - MaxArchVGPRs),
150 b: 0);
151 unsigned OtherExcessArchVGPR =
152 std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: false) + OtherVGPRForSGPRSpills -
153 MaxArchVGPRs),
154 b: 0);
155 // AGPR excess pressure conditions
156 unsigned ExcessAGPR = std::max(
157 a: static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs)
158 : (getAGPRNum() - MaxVGPRs)),
159 b: 0);
160 unsigned OtherExcessAGPR = std::max(
161 a: static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs)
162 : (O.getAGPRNum() - MaxVGPRs)),
163 b: 0);
164
165 bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR;
166 bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR ||
167 OtherExcessArchVGPR || OtherExcessAGPR;
168
169 // Give second precedence to the reduced number of spills to hold the register
170 // pressure.
171 if (ExcessRP || OtherExcessRP) {
172 // The difference in excess VGPR pressure, after including VGPRs used for
173 // SGPR spills
174 int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) -
175 (ExcessVGPR + ExcessArchVGPR + ExcessAGPR));
176
177 int SGPRDiff = OtherExcessSGPR - ExcessSGPR;
178
179 if (VGPRDiff != 0)
180 return VGPRDiff > 0;
181 if (SGPRDiff != 0) {
182 unsigned PureExcessVGPR =
183 std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs),
184 b: 0) +
185 std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0);
186 unsigned OtherPureExcessVGPR =
187 std::max(
188 a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs),
189 b: 0) +
190 std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0);
191
192 // If we have a special case where there is a tie in excess VGPR, but one
193 // of the pressures has VGPR usage from SGPR spills, prefer the pressure
194 // with SGPR spills.
195 if (PureExcessVGPR != OtherPureExcessVGPR)
196 return SGPRDiff < 0;
197 // If both pressures have the same excess pressure before and after
198 // accounting for SGPR spills, prefer fewer SGPR spills.
199 return SGPRDiff > 0;
200 }
201 }
202
203 bool SGPRImportant = SGPROcc < VGPROcc;
204 const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
205
206 // If both pressures disagree on what is more important compare vgprs.
207 if (SGPRImportant != OtherSGPRImportant) {
208 SGPRImportant = false;
209 }
210
211 // Give third precedence to lower register tuple pressure.
212 bool SGPRFirst = SGPRImportant;
213 for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
214 if (SGPRFirst) {
215 auto SW = getSGPRTuplesWeight();
216 auto OtherSW = O.getSGPRTuplesWeight();
217 if (SW != OtherSW)
218 return SW < OtherSW;
219 } else {
220 auto VW = getVGPRTuplesWeight();
221 auto OtherVW = O.getVGPRTuplesWeight();
222 if (VW != OtherVW)
223 return VW < OtherVW;
224 }
225 }
226
227 // Give final precedence to lower general RP.
228 return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
229 (getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) <
230 O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()));
231}
232
233Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST,
234 unsigned DynamicVGPRBlockSize) {
235 return Printable([&RP, ST, DynamicVGPRBlockSize](raw_ostream &OS) {
236 OS << "VGPRs: " << RP.getArchVGPRNum() << ' '
237 << "AGPRs: " << RP.getAGPRNum();
238 if (ST)
239 OS << "(O"
240 << ST->getOccupancyWithNumVGPRs(VGPRs: RP.getVGPRNum(UnifiedVGPRFile: ST->hasGFX90AInsts()),
241 DynamicVGPRBlockSize)
242 << ')';
243 OS << ", SGPRs: " << RP.getSGPRNum();
244 if (ST)
245 OS << "(O" << ST->getOccupancyWithNumSGPRs(SGPRs: RP.getSGPRNum()) << ')';
246 OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
247 << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
248 if (ST)
249 OS << " -> Occ: " << RP.getOccupancy(ST: *ST, DynamicVGPRBlockSize);
250 OS << '\n';
251 });
252}
253
254static LaneBitmask getDefRegMask(const MachineOperand &MO,
255 const MachineRegisterInfo &MRI) {
256 assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
257
258 // We don't rely on read-undef flag because in case of tentative schedule
259 // tracking it isn't set correctly yet. This works correctly however since
260 // use mask has been tracked before using LIS.
261 return MO.getSubReg() == 0 ?
262 MRI.getMaxLaneMaskForVReg(Reg: MO.getReg()) :
263 MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubIdx: MO.getSubReg());
264}
265
266static void
267collectVirtualRegUses(SmallVectorImpl<VRegMaskOrUnit> &VRegMaskOrUnits,
268 const MachineInstr &MI, const LiveIntervals &LIS,
269 const MachineRegisterInfo &MRI) {
270
271 auto &TRI = *MRI.getTargetRegisterInfo();
272 for (const auto &MO : MI.operands()) {
273 if (!MO.isReg() || !MO.getReg().isVirtual())
274 continue;
275 if (!MO.isUse() || !MO.readsReg())
276 continue;
277
278 Register Reg = MO.getReg();
279 auto I = llvm::find_if(Range&: VRegMaskOrUnits, P: [Reg](const VRegMaskOrUnit &RM) {
280 return RM.RegUnit == Reg;
281 });
282
283 auto &P = I == VRegMaskOrUnits.end()
284 ? VRegMaskOrUnits.emplace_back(Args&: Reg, Args: LaneBitmask::getNone())
285 : *I;
286
287 P.LaneMask |= MO.getSubReg() ? TRI.getSubRegIndexLaneMask(SubIdx: MO.getSubReg())
288 : MRI.getMaxLaneMaskForVReg(Reg);
289 }
290
291 SlotIndex InstrSI;
292 for (auto &P : VRegMaskOrUnits) {
293 auto &LI = LIS.getInterval(Reg: P.RegUnit);
294 if (!LI.hasSubRanges())
295 continue;
296
297 // For a tentative schedule LIS isn't updated yet but livemask should
298 // remain the same on any schedule. Subreg defs can be reordered but they
299 // all must dominate uses anyway.
300 if (!InstrSI)
301 InstrSI = LIS.getInstructionIndex(Instr: MI).getBaseIndex();
302
303 P.LaneMask = getLiveLaneMask(LI, SI: InstrSI, MRI, LaneMaskFilter: P.LaneMask);
304 }
305}
306
307/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
308static LaneBitmask getLanesWithProperty(
309 const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
310 bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
311 LaneBitmask SafeDefault,
312 function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
313 if (RegUnit.isVirtual()) {
314 const LiveInterval &LI = LIS.getInterval(Reg: RegUnit);
315 LaneBitmask Result;
316 if (TrackLaneMasks && LI.hasSubRanges()) {
317 for (const LiveInterval::SubRange &SR : LI.subranges()) {
318 if (Property(SR, Pos))
319 Result |= SR.LaneMask;
320 }
321 } else if (Property(LI, Pos)) {
322 Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(Reg: RegUnit)
323 : LaneBitmask::getAll();
324 }
325
326 return Result;
327 }
328
329 const LiveRange *LR = LIS.getCachedRegUnit(Unit: RegUnit);
330 if (LR == nullptr)
331 return SafeDefault;
332 return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
333}
334
335/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
336/// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
337/// The query starts with a lane bitmask which gets lanes/bits removed for every
338/// use we find.
339static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
340 SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
341 const MachineRegisterInfo &MRI,
342 const SIRegisterInfo *TRI,
343 const LiveIntervals *LIS,
344 bool Upward = false) {
345 for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
346 if (MO.isUndef())
347 continue;
348 const MachineInstr *MI = MO.getParent();
349 SlotIndex InstSlot = LIS->getInstructionIndex(Instr: *MI).getRegSlot();
350 bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
351 : (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
352 if (!InRange)
353 continue;
354
355 unsigned SubRegIdx = MO.getSubReg();
356 LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubIdx: SubRegIdx);
357 LastUseMask &= ~UseMask;
358 if (LastUseMask.none())
359 return LaneBitmask::getNone();
360 }
361 return LastUseMask;
362}
363
364////////////////////////////////////////////////////////////////////////////////
365// GCNRPTarget
366
367GCNRPTarget::GCNRPTarget(const MachineFunction &MF, const GCNRegPressure &RP,
368 bool CombineVGPRSavings)
369 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
370 const Function &F = MF.getFunction();
371 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
372 setRegLimits(MaxSGPRs: ST.getMaxNumSGPRs(F), MaxVGPRs: ST.getMaxNumVGPRs(F), MF);
373}
374
375GCNRPTarget::GCNRPTarget(unsigned NumSGPRs, unsigned NumVGPRs,
376 const MachineFunction &MF, const GCNRegPressure &RP,
377 bool CombineVGPRSavings)
378 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
379 setRegLimits(MaxSGPRs: NumSGPRs, MaxVGPRs: NumVGPRs, MF);
380}
381
382GCNRPTarget::GCNRPTarget(unsigned Occupancy, const MachineFunction &MF,
383 const GCNRegPressure &RP, bool CombineVGPRSavings)
384 : RP(RP), CombineVGPRSavings(CombineVGPRSavings) {
385 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
386 unsigned DynamicVGPRBlockSize =
387 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
388 setRegLimits(MaxSGPRs: ST.getMaxNumSGPRs(WavesPerEU: Occupancy, /*Addressable=*/false),
389 MaxVGPRs: ST.getMaxNumVGPRs(WavesPerEU: Occupancy, DynamicVGPRBlockSize), MF);
390}
391
392void GCNRPTarget::setRegLimits(unsigned NumSGPRs, unsigned NumVGPRs,
393 const MachineFunction &MF) {
394 const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
395 unsigned DynamicVGPRBlockSize =
396 MF.getInfo<SIMachineFunctionInfo>()->getDynamicVGPRBlockSize();
397 MaxSGPRs = std::min(a: ST.getAddressableNumSGPRs(), b: NumSGPRs);
398 MaxVGPRs = std::min(a: ST.getAddressableNumArchVGPRs(), b: NumVGPRs);
399 MaxUnifiedVGPRs =
400 ST.hasGFX90AInsts()
401 ? std::min(a: ST.getAddressableNumVGPRs(DynamicVGPRBlockSize), b: NumVGPRs)
402 : 0;
403}
404
405bool GCNRPTarget::isSaveBeneficial(Register Reg,
406 const MachineRegisterInfo &MRI) const {
407 const TargetRegisterClass *RC = MRI.getRegClass(Reg);
408 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
409 const SIRegisterInfo *SRI = static_cast<const SIRegisterInfo *>(TRI);
410
411 if (SRI->isSGPRClass(RC))
412 return RP.getSGPRNum() > MaxSGPRs;
413 unsigned NumVGPRs =
414 SRI->isAGPRClass(RC) ? RP.getAGPRNum() : RP.getArchVGPRNum();
415 return isVGPRBankSaveBeneficial(NumVGPRs);
416}
417
418bool GCNRPTarget::satisfied() const {
419 if (RP.getSGPRNum() > MaxSGPRs)
420 return false;
421 if (RP.getVGPRNum(UnifiedVGPRFile: false) > MaxVGPRs &&
422 (!CombineVGPRSavings || !satisifiesVGPRBanksTarget()))
423 return false;
424 return satisfiesUnifiedTarget();
425}
426
427///////////////////////////////////////////////////////////////////////////////
428// GCNRPTracker
429
430LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI,
431 const LiveIntervals &LIS,
432 const MachineRegisterInfo &MRI,
433 LaneBitmask LaneMaskFilter) {
434 return getLiveLaneMask(LI: LIS.getInterval(Reg), SI, MRI, LaneMaskFilter);
435}
436
437LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
438 const MachineRegisterInfo &MRI,
439 LaneBitmask LaneMaskFilter) {
440 LaneBitmask LiveMask;
441 if (LI.hasSubRanges()) {
442 for (const auto &S : LI.subranges())
443 if ((S.LaneMask & LaneMaskFilter).any() && S.liveAt(index: SI)) {
444 LiveMask |= S.LaneMask;
445 assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg())));
446 }
447 } else if (LI.liveAt(index: SI)) {
448 LiveMask = MRI.getMaxLaneMaskForVReg(Reg: LI.reg());
449 }
450 LiveMask &= LaneMaskFilter;
451 return LiveMask;
452}
453
454GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
455 const LiveIntervals &LIS,
456 const MachineRegisterInfo &MRI) {
457 GCNRPTracker::LiveRegSet LiveRegs;
458 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
459 auto Reg = Register::index2VirtReg(Index: I);
460 if (!LIS.hasInterval(Reg))
461 continue;
462 auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
463 if (LiveMask.any())
464 LiveRegs[Reg] = LiveMask;
465 }
466 return LiveRegs;
467}
468
469void GCNRPTracker::reset(const MachineInstr &MI,
470 const LiveRegSet *LiveRegsCopy,
471 bool After) {
472 const MachineFunction &MF = *MI.getMF();
473 MRI = &MF.getRegInfo();
474 if (LiveRegsCopy) {
475 if (&LiveRegs != LiveRegsCopy)
476 LiveRegs = *LiveRegsCopy;
477 } else {
478 LiveRegs = After ? getLiveRegsAfter(MI, LIS)
479 : getLiveRegsBefore(MI, LIS);
480 }
481
482 MaxPressure = CurPressure = getRegPressure(MRI: *MRI, LiveRegs);
483}
484
485void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
486 const LiveRegSet &LiveRegs_) {
487 MRI = &MRI_;
488 LiveRegs = LiveRegs_;
489 LastTrackedMI = nullptr;
490 MaxPressure = CurPressure = getRegPressure(MRI: MRI_, LiveRegs: LiveRegs_);
491}
492
493/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
494LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
495 SlotIndex Pos) const {
496 return getLanesWithProperty(
497 LIS, MRI: *MRI, TrackLaneMasks: true, RegUnit, Pos: Pos.getBaseIndex(), SafeDefault: LaneBitmask::getNone(),
498 Property: [](const LiveRange &LR, SlotIndex Pos) {
499 const LiveRange::Segment *S = LR.getSegmentContaining(Idx: Pos);
500 return S != nullptr && S->end == Pos.getRegSlot();
501 });
502}
503
504////////////////////////////////////////////////////////////////////////////////
505// GCNUpwardRPTracker
506
507void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
508 assert(MRI && "call reset first");
509
510 LastTrackedMI = &MI;
511
512 if (MI.isDebugInstr())
513 return;
514
515 // Kill all defs.
516 GCNRegPressure DefPressure, ECDefPressure;
517 bool HasECDefs = false;
518 for (const MachineOperand &MO : MI.all_defs()) {
519 if (!MO.getReg().isVirtual())
520 continue;
521
522 Register Reg = MO.getReg();
523 LaneBitmask DefMask = getDefRegMask(MO, MRI: *MRI);
524
525 // Treat a def as fully live at the moment of definition: keep a record.
526 if (MO.isEarlyClobber()) {
527 ECDefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI);
528 HasECDefs = true;
529 } else
530 DefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI);
531
532 auto I = LiveRegs.find(Val: Reg);
533 if (I == LiveRegs.end())
534 continue;
535
536 LaneBitmask &LiveMask = I->second;
537 LaneBitmask PrevMask = LiveMask;
538 LiveMask &= ~DefMask;
539 CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI);
540 if (LiveMask.none())
541 LiveRegs.erase(I);
542 }
543
544 // Update MaxPressure with defs pressure.
545 DefPressure += CurPressure;
546 if (HasECDefs)
547 DefPressure += ECDefPressure;
548 MaxPressure = max(P1: DefPressure, P2: MaxPressure);
549
550 // Make uses alive.
551 SmallVector<VRegMaskOrUnit, 8> RegUses;
552 collectVirtualRegUses(VRegMaskOrUnits&: RegUses, MI, LIS, MRI: *MRI);
553 for (const VRegMaskOrUnit &U : RegUses) {
554 LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
555 LaneBitmask PrevMask = LiveMask;
556 LiveMask |= U.LaneMask;
557 CurPressure.inc(Reg: U.RegUnit, PrevMask, NewMask: LiveMask, MRI: *MRI);
558 }
559
560 // Update MaxPressure with uses plus early-clobber defs pressure.
561 MaxPressure = HasECDefs ? max(P1: CurPressure + ECDefPressure, P2: MaxPressure)
562 : max(P1: CurPressure, P2: MaxPressure);
563
564 assert(CurPressure == getRegPressure(*MRI, LiveRegs));
565}
566
567////////////////////////////////////////////////////////////////////////////////
568// GCNDownwardRPTracker
569
570bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
571 const LiveRegSet *LiveRegsCopy) {
572 MRI = &MI.getParent()->getParent()->getRegInfo();
573 LastTrackedMI = nullptr;
574 MBBEnd = MI.getParent()->end();
575 NextMI = &MI;
576 NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd);
577 if (NextMI == MBBEnd)
578 return false;
579 GCNRPTracker::reset(MI: *NextMI, LiveRegsCopy, After: false);
580 return true;
581}
582
583bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
584 bool UseInternalIterator) {
585 assert(MRI && "call reset first");
586 SlotIndex SI;
587 const MachineInstr *CurrMI;
588 if (UseInternalIterator) {
589 if (!LastTrackedMI)
590 return NextMI == MBBEnd;
591
592 assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
593 CurrMI = LastTrackedMI;
594
595 SI = NextMI == MBBEnd
596 ? LIS.getInstructionIndex(Instr: *LastTrackedMI).getDeadSlot()
597 : LIS.getInstructionIndex(Instr: *NextMI).getBaseIndex();
598 } else { //! UseInternalIterator
599 SI = LIS.getInstructionIndex(Instr: *MI).getBaseIndex();
600 CurrMI = MI;
601 }
602
603 assert(SI.isValid());
604
605 // Remove dead registers or mask bits.
606 SmallSet<Register, 8> SeenRegs;
607 for (auto &MO : CurrMI->operands()) {
608 if (!MO.isReg() || !MO.getReg().isVirtual())
609 continue;
610 if (MO.isUse() && !MO.readsReg())
611 continue;
612 if (!UseInternalIterator && MO.isDef())
613 continue;
614 if (!SeenRegs.insert(V: MO.getReg()).second)
615 continue;
616 const LiveInterval &LI = LIS.getInterval(Reg: MO.getReg());
617 if (LI.hasSubRanges()) {
618 auto It = LiveRegs.end();
619 for (const auto &S : LI.subranges()) {
620 if (!S.liveAt(index: SI)) {
621 if (It == LiveRegs.end()) {
622 It = LiveRegs.find(Val: MO.getReg());
623 if (It == LiveRegs.end())
624 llvm_unreachable("register isn't live");
625 }
626 auto PrevMask = It->second;
627 It->second &= ~S.LaneMask;
628 CurPressure.inc(Reg: MO.getReg(), PrevMask, NewMask: It->second, MRI: *MRI);
629 }
630 }
631 if (It != LiveRegs.end() && It->second.none())
632 LiveRegs.erase(I: It);
633 } else if (!LI.liveAt(index: SI)) {
634 auto It = LiveRegs.find(Val: MO.getReg());
635 if (It == LiveRegs.end())
636 llvm_unreachable("register isn't live");
637 CurPressure.inc(Reg: MO.getReg(), PrevMask: It->second, NewMask: LaneBitmask::getNone(), MRI: *MRI);
638 LiveRegs.erase(I: It);
639 }
640 }
641
642 MaxPressure = max(P1: MaxPressure, P2: CurPressure);
643
644 LastTrackedMI = nullptr;
645
646 return UseInternalIterator && (NextMI == MBBEnd);
647}
648
649void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
650 bool UseInternalIterator) {
651 if (UseInternalIterator) {
652 LastTrackedMI = &*NextMI++;
653 NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd);
654 } else {
655 LastTrackedMI = MI;
656 }
657
658 const MachineInstr *CurrMI = LastTrackedMI;
659
660 // Add new registers or mask bits.
661 for (const auto &MO : CurrMI->all_defs()) {
662 Register Reg = MO.getReg();
663 if (!Reg.isVirtual())
664 continue;
665 auto &LiveMask = LiveRegs[Reg];
666 auto PrevMask = LiveMask;
667 LiveMask |= getDefRegMask(MO, MRI: *MRI);
668 CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI);
669 }
670
671 MaxPressure = max(P1: MaxPressure, P2: CurPressure);
672}
673
674bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
675 if (UseInternalIterator && NextMI == MBBEnd)
676 return false;
677
678 advanceBeforeNext(MI, UseInternalIterator);
679 advanceToNext(MI, UseInternalIterator);
680 if (!UseInternalIterator) {
681 // We must remove any dead def lanes from the current RP
682 advanceBeforeNext(MI, UseInternalIterator: true);
683 }
684 return true;
685}
686
687bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
688 while (NextMI != End)
689 if (!advance()) return false;
690 return true;
691}
692
693bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
694 MachineBasicBlock::const_iterator End,
695 const LiveRegSet *LiveRegsCopy) {
696 reset(MI: *Begin, LiveRegsCopy);
697 return advance(End);
698}
699
700Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
701 const GCNRPTracker::LiveRegSet &TrackedLR,
702 const TargetRegisterInfo *TRI, StringRef Pfx) {
703 return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) {
704 for (auto const &P : TrackedLR) {
705 auto I = LISLR.find(Val: P.first);
706 if (I == LISLR.end()) {
707 OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second)
708 << " isn't found in LIS reported set\n";
709 } else if (I->second != P.second) {
710 OS << Pfx << printReg(Reg: P.first, TRI)
711 << " masks doesn't match: LIS reported " << PrintLaneMask(LaneMask: I->second)
712 << ", tracked " << PrintLaneMask(LaneMask: P.second) << '\n';
713 }
714 }
715 for (auto const &P : LISLR) {
716 auto I = TrackedLR.find(Val: P.first);
717 if (I == TrackedLR.end()) {
718 OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second)
719 << " isn't found in tracked set\n";
720 }
721 }
722 });
723}
724
725GCNRegPressure
726GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
727 const SIRegisterInfo *TRI) const {
728 assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");
729
730 SlotIndex SlotIdx;
731 SlotIdx = LIS.getInstructionIndex(Instr: *MI).getRegSlot();
732
733 // Account for register pressure similar to RegPressureTracker::recede().
734 RegisterOperands RegOpers;
735 RegOpers.collect(MI: *MI, TRI: *TRI, MRI: *MRI, TrackLaneMasks: true, /*IgnoreDead=*/false);
736 RegOpers.adjustLaneLiveness(LIS, MRI: *MRI, Pos: SlotIdx);
737 GCNRegPressure TempPressure = CurPressure;
738
739 for (const VRegMaskOrUnit &Use : RegOpers.Uses) {
740 Register Reg = Use.RegUnit;
741 if (!Reg.isVirtual())
742 continue;
743 LaneBitmask LastUseMask = getLastUsedLanes(RegUnit: Reg, Pos: SlotIdx);
744 if (LastUseMask.none())
745 continue;
746 // The LastUseMask is queried from the liveness information of instruction
747 // which may be further down the schedule. Some lanes may actually not be
748 // last uses for the current position.
749 // FIXME: allow the caller to pass in the list of vreg uses that remain
750 // to be bottom-scheduled to avoid searching uses at each query.
751 SlotIndex CurrIdx;
752 const MachineBasicBlock *MBB = MI->getParent();
753 MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
754 It: LastTrackedMI ? LastTrackedMI : MBB->begin(), End: MBB->end());
755 if (IdxPos == MBB->end()) {
756 CurrIdx = LIS.getMBBEndIdx(mbb: MBB);
757 } else {
758 CurrIdx = LIS.getInstructionIndex(Instr: *IdxPos).getRegSlot();
759 }
760
761 LastUseMask =
762 findUseBetween(Reg, LastUseMask, PriorUseIdx: CurrIdx, NextUseIdx: SlotIdx, MRI: *MRI, TRI, LIS: &LIS);
763 if (LastUseMask.none())
764 continue;
765
766 auto It = LiveRegs.find(Val: Reg);
767 LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
768 LaneBitmask NewMask = LiveMask & ~LastUseMask;
769 TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI);
770 }
771
772 // Generate liveness for defs.
773 for (const VRegMaskOrUnit &Def : RegOpers.Defs) {
774 Register Reg = Def.RegUnit;
775 if (!Reg.isVirtual())
776 continue;
777 auto It = LiveRegs.find(Val: Reg);
778 LaneBitmask LiveMask = It != LiveRegs.end() ? It->second : LaneBitmask(0);
779 LaneBitmask NewMask = LiveMask | Def.LaneMask;
780 TempPressure.inc(Reg, PrevMask: LiveMask, NewMask, MRI: *MRI);
781 }
782
783 return TempPressure;
784}
785
786bool GCNUpwardRPTracker::isValid() const {
787 const auto &SI = LIS.getInstructionIndex(Instr: *LastTrackedMI).getBaseIndex();
788 const auto LISLR = llvm::getLiveRegs(SI, LIS, MRI: *MRI);
789 const auto &TrackedLR = LiveRegs;
790
791 if (!isEqual(S1: LISLR, S2: TrackedLR)) {
792 dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
793 " LIS reported livesets mismatch:\n"
794 << print(LiveRegs: LISLR, MRI: *MRI);
795 reportMismatch(LISLR, TrackedLR, TRI: MRI->getTargetRegisterInfo());
796 return false;
797 }
798
799 auto LISPressure = getRegPressure(MRI: *MRI, LiveRegs: LISLR);
800 if (LISPressure != CurPressure) {
801 dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
802 << print(RP: CurPressure) << "LIS rpt: " << print(RP: LISPressure);
803 return false;
804 }
805 return true;
806}
807
808Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
809 const MachineRegisterInfo &MRI) {
810 return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
811 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
812 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
813 Register Reg = Register::index2VirtReg(Index: I);
814 auto It = LiveRegs.find(Val: Reg);
815 if (It != LiveRegs.end() && It->second.any())
816 OS << ' ' << printVRegOrUnit(VRegOrUnit: Reg, TRI) << ':'
817 << PrintLaneMask(LaneMask: It->second);
818 }
819 OS << '\n';
820 });
821}
822
823void GCNRegPressure::dump() const { dbgs() << print(RP: *this); }
824
825static cl::opt<bool> UseDownwardTracker(
826 "amdgpu-print-rp-downward",
827 cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass"),
828 cl::init(Val: false), cl::Hidden);
829
830char llvm::GCNRegPressurePrinter::ID = 0;
831char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID;
832
833INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp", "", true, true)
834
835// Return lanemask of Reg's subregs that are live-through at [Begin, End] and
836// are fully covered by Mask.
837static LaneBitmask
838getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS,
839 Register Reg, SlotIndex Begin, SlotIndex End,
840 LaneBitmask Mask = LaneBitmask::getAll()) {
841
842 auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool {
843 auto *Segment = LR.getSegmentContaining(Idx: Begin);
844 return Segment && Segment->contains(I: End);
845 };
846
847 LaneBitmask LiveThroughMask;
848 const LiveInterval &LI = LIS.getInterval(Reg);
849 if (LI.hasSubRanges()) {
850 for (auto &SR : LI.subranges()) {
851 if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR))
852 LiveThroughMask |= SR.LaneMask;
853 }
854 } else {
855 LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg);
856 if ((RegMask & Mask) == RegMask && IsInOneSegment(LI))
857 LiveThroughMask = RegMask;
858 }
859
860 return LiveThroughMask;
861}
862
863bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
864 const MachineRegisterInfo &MRI = MF.getRegInfo();
865 const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
866 const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS();
867
868 auto &OS = dbgs();
869
870// Leading spaces are important for YAML syntax.
871#define PFX " "
872
873 OS << "---\nname: " << MF.getName() << "\nbody: |\n";
874
875 auto printRP = [](const GCNRegPressure &RP) {
876 return Printable([&RP](raw_ostream &OS) {
877 OS << format(PFX " %-5d", Vals: RP.getSGPRNum())
878 << format(Fmt: " %-5d", Vals: RP.getVGPRNum(UnifiedVGPRFile: false));
879 });
880 };
881
882 auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR,
883 const GCNRPTracker::LiveRegSet &LISLR) {
884 if (LISLR != TrackedLR) {
885 OS << PFX " mis LIS: " << llvm::print(LiveRegs: LISLR, MRI)
886 << reportMismatch(LISLR, TrackedLR, TRI, PFX " ");
887 }
888 };
889
890 // Register pressure before and at an instruction (in program order).
891 SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP;
892
893 for (auto &MBB : MF) {
894 RP.clear();
895 RP.reserve(N: MBB.size());
896
897 OS << PFX;
898 MBB.printName(os&: OS);
899 OS << ":\n";
900
901 SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(mbb: &MBB);
902 SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(mbb: &MBB);
903
904 GCNRPTracker::LiveRegSet LiveIn, LiveOut;
905 GCNRegPressure RPAtMBBEnd;
906
907 if (UseDownwardTracker) {
908 if (MBB.empty()) {
909 LiveIn = LiveOut = getLiveRegs(SI: MBBStartSlot, LIS, MRI);
910 RPAtMBBEnd = getRegPressure(MRI, LiveRegs&: LiveIn);
911 } else {
912 GCNDownwardRPTracker RPT(LIS);
913 RPT.reset(MI: MBB.front());
914
915 LiveIn = RPT.getLiveRegs();
916
917 while (!RPT.advanceBeforeNext()) {
918 GCNRegPressure RPBeforeMI = RPT.getPressure();
919 RPT.advanceToNext();
920 RP.emplace_back(Args&: RPBeforeMI, Args: RPT.getPressure());
921 }
922
923 LiveOut = RPT.getLiveRegs();
924 RPAtMBBEnd = RPT.getPressure();
925 }
926 } else {
927 GCNUpwardRPTracker RPT(LIS);
928 RPT.reset(MRI, SI: MBBEndSlot);
929
930 LiveOut = RPT.getLiveRegs();
931 RPAtMBBEnd = RPT.getPressure();
932
933 for (auto &MI : reverse(C&: MBB)) {
934 RPT.resetMaxPressure();
935 RPT.recede(MI);
936 if (!MI.isDebugInstr())
937 RP.emplace_back(Args: RPT.getPressure(), Args: RPT.getMaxPressure());
938 }
939
940 LiveIn = RPT.getLiveRegs();
941 }
942
943 OS << PFX " Live-in: " << llvm::print(LiveRegs: LiveIn, MRI);
944 if (!UseDownwardTracker)
945 ReportLISMismatchIfAny(LiveIn, getLiveRegs(SI: MBBStartSlot, LIS, MRI));
946
947 OS << PFX " SGPR VGPR\n";
948 int I = 0;
949 for (auto &MI : MBB) {
950 if (!MI.isDebugInstr()) {
951 auto &[RPBeforeInstr, RPAtInstr] =
952 RP[UseDownwardTracker ? I : (RP.size() - 1 - I)];
953 ++I;
954 OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " ";
955 } else
956 OS << PFX " ";
957 MI.print(OS);
958 }
959 OS << printRP(RPAtMBBEnd) << '\n';
960
961 OS << PFX " Live-out:" << llvm::print(LiveRegs: LiveOut, MRI);
962 if (UseDownwardTracker)
963 ReportLISMismatchIfAny(LiveOut, getLiveRegs(SI: MBBEndSlot, LIS, MRI));
964
965 GCNRPTracker::LiveRegSet LiveThrough;
966 for (auto [Reg, Mask] : LiveIn) {
967 LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Val: Reg);
968 if (MaskIntersection.any()) {
969 LaneBitmask LTMask = getRegLiveThroughMask(
970 MRI, LIS, Reg, Begin: MBBStartSlot, End: MBBEndSlot, Mask: MaskIntersection);
971 if (LTMask.any())
972 LiveThrough[Reg] = LTMask;
973 }
974 }
975 OS << PFX " Live-thr:" << llvm::print(LiveRegs: LiveThrough, MRI);
976 OS << printRP(getRegPressure(MRI, LiveRegs&: LiveThrough)) << '\n';
977 }
978 OS << "...\n";
979 return false;
980
981#undef PFX
982}
983