1 | //===- GCNRegPressure.cpp -------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | /// |
9 | /// \file |
10 | /// This file implements the GCNRegPressure class. |
11 | /// |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #include "GCNRegPressure.h" |
15 | #include "AMDGPU.h" |
16 | #include "llvm/CodeGen/RegisterPressure.h" |
17 | |
18 | using namespace llvm; |
19 | |
20 | #define DEBUG_TYPE "machine-scheduler" |
21 | |
22 | bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1, |
23 | const GCNRPTracker::LiveRegSet &S2) { |
24 | if (S1.size() != S2.size()) |
25 | return false; |
26 | |
27 | for (const auto &P : S1) { |
28 | auto I = S2.find(Val: P.first); |
29 | if (I == S2.end() || I->second != P.second) |
30 | return false; |
31 | } |
32 | return true; |
33 | } |
34 | |
35 | /////////////////////////////////////////////////////////////////////////////// |
36 | // GCNRegPressure |
37 | |
38 | unsigned GCNRegPressure::getRegKind(Register Reg, |
39 | const MachineRegisterInfo &MRI) { |
40 | assert(Reg.isVirtual()); |
41 | const auto RC = MRI.getRegClass(Reg); |
42 | auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo()); |
43 | return STI->isSGPRClass(RC) |
44 | ? (STI->getRegSizeInBits(RC: *RC) == 32 ? SGPR32 : SGPR_TUPLE) |
45 | : STI->isAGPRClass(RC) |
46 | ? (STI->getRegSizeInBits(RC: *RC) == 32 ? AGPR32 : AGPR_TUPLE) |
47 | : (STI->getRegSizeInBits(RC: *RC) == 32 ? VGPR32 : VGPR_TUPLE); |
48 | } |
49 | |
50 | void GCNRegPressure::inc(unsigned Reg, |
51 | LaneBitmask PrevMask, |
52 | LaneBitmask NewMask, |
53 | const MachineRegisterInfo &MRI) { |
54 | if (SIRegisterInfo::getNumCoveredRegs(LM: NewMask) == |
55 | SIRegisterInfo::getNumCoveredRegs(LM: PrevMask)) |
56 | return; |
57 | |
58 | int Sign = 1; |
59 | if (NewMask < PrevMask) { |
60 | std::swap(a&: NewMask, b&: PrevMask); |
61 | Sign = -1; |
62 | } |
63 | |
64 | switch (auto Kind = getRegKind(Reg, MRI)) { |
65 | case SGPR32: |
66 | case VGPR32: |
67 | case AGPR32: |
68 | Value[Kind] += Sign; |
69 | break; |
70 | |
71 | case SGPR_TUPLE: |
72 | case VGPR_TUPLE: |
73 | case AGPR_TUPLE: |
74 | assert(PrevMask < NewMask); |
75 | |
76 | Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] += |
77 | Sign * SIRegisterInfo::getNumCoveredRegs(LM: ~PrevMask & NewMask); |
78 | |
79 | if (PrevMask.none()) { |
80 | assert(NewMask.any()); |
81 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
82 | Value[Kind] += |
83 | Sign * TRI->getRegClassWeight(RC: MRI.getRegClass(Reg)).RegWeight; |
84 | } |
85 | break; |
86 | |
87 | default: llvm_unreachable("Unknown register kind" ); |
88 | } |
89 | } |
90 | |
91 | bool GCNRegPressure::less(const MachineFunction &MF, const GCNRegPressure &O, |
92 | unsigned MaxOccupancy) const { |
93 | const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>(); |
94 | |
95 | const auto SGPROcc = std::min(a: MaxOccupancy, |
96 | b: ST.getOccupancyWithNumSGPRs(SGPRs: getSGPRNum())); |
97 | const auto VGPROcc = |
98 | std::min(a: MaxOccupancy, |
99 | b: ST.getOccupancyWithNumVGPRs(VGPRs: getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()))); |
100 | const auto OtherSGPROcc = std::min(a: MaxOccupancy, |
101 | b: ST.getOccupancyWithNumSGPRs(SGPRs: O.getSGPRNum())); |
102 | const auto OtherVGPROcc = |
103 | std::min(a: MaxOccupancy, |
104 | b: ST.getOccupancyWithNumVGPRs(VGPRs: O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()))); |
105 | |
106 | const auto Occ = std::min(a: SGPROcc, b: VGPROcc); |
107 | const auto OtherOcc = std::min(a: OtherSGPROcc, b: OtherVGPROcc); |
108 | |
109 | // Give first precedence to the better occupancy. |
110 | if (Occ != OtherOcc) |
111 | return Occ > OtherOcc; |
112 | |
113 | unsigned MaxVGPRs = ST.getMaxNumVGPRs(MF); |
114 | unsigned MaxSGPRs = ST.getMaxNumSGPRs(MF); |
115 | |
116 | // SGPR excess pressure conditions |
117 | unsigned ExcessSGPR = std::max(a: static_cast<int>(getSGPRNum() - MaxSGPRs), b: 0); |
118 | unsigned OtherExcessSGPR = |
119 | std::max(a: static_cast<int>(O.getSGPRNum() - MaxSGPRs), b: 0); |
120 | |
121 | auto WaveSize = ST.getWavefrontSize(); |
122 | // The number of virtual VGPRs required to handle excess SGPR |
123 | unsigned VGPRForSGPRSpills = (ExcessSGPR + (WaveSize - 1)) / WaveSize; |
124 | unsigned OtherVGPRForSGPRSpills = |
125 | (OtherExcessSGPR + (WaveSize - 1)) / WaveSize; |
126 | |
127 | unsigned MaxArchVGPRs = ST.getAddressableNumArchVGPRs(); |
128 | |
129 | // Unified excess pressure conditions, accounting for VGPRs used for SGPR |
130 | // spills |
131 | unsigned ExcessVGPR = |
132 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) + |
133 | VGPRForSGPRSpills - MaxVGPRs), |
134 | b: 0); |
135 | unsigned OtherExcessVGPR = |
136 | std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) + |
137 | OtherVGPRForSGPRSpills - MaxVGPRs), |
138 | b: 0); |
139 | // Arch VGPR excess pressure conditions, accounting for VGPRs used for SGPR |
140 | // spills |
141 | unsigned ExcessArchVGPR = std::max( |
142 | a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: false) + VGPRForSGPRSpills - MaxArchVGPRs), |
143 | b: 0); |
144 | unsigned OtherExcessArchVGPR = |
145 | std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: false) + OtherVGPRForSGPRSpills - |
146 | MaxArchVGPRs), |
147 | b: 0); |
148 | // AGPR excess pressure conditions |
149 | unsigned ExcessAGPR = std::max( |
150 | a: static_cast<int>(ST.hasGFX90AInsts() ? (getAGPRNum() - MaxArchVGPRs) |
151 | : (getAGPRNum() - MaxVGPRs)), |
152 | b: 0); |
153 | unsigned OtherExcessAGPR = std::max( |
154 | a: static_cast<int>(ST.hasGFX90AInsts() ? (O.getAGPRNum() - MaxArchVGPRs) |
155 | : (O.getAGPRNum() - MaxVGPRs)), |
156 | b: 0); |
157 | |
158 | bool ExcessRP = ExcessSGPR || ExcessVGPR || ExcessArchVGPR || ExcessAGPR; |
159 | bool OtherExcessRP = OtherExcessSGPR || OtherExcessVGPR || |
160 | OtherExcessArchVGPR || OtherExcessAGPR; |
161 | |
162 | // Give second precedence to the reduced number of spills to hold the register |
163 | // pressure. |
164 | if (ExcessRP || OtherExcessRP) { |
165 | // The difference in excess VGPR pressure, after including VGPRs used for |
166 | // SGPR spills |
167 | int VGPRDiff = ((OtherExcessVGPR + OtherExcessArchVGPR + OtherExcessAGPR) - |
168 | (ExcessVGPR + ExcessArchVGPR + ExcessAGPR)); |
169 | |
170 | int SGPRDiff = OtherExcessSGPR - ExcessSGPR; |
171 | |
172 | if (VGPRDiff != 0) |
173 | return VGPRDiff > 0; |
174 | if (SGPRDiff != 0) { |
175 | unsigned PureExcessVGPR = |
176 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs), |
177 | b: 0) + |
178 | std::max(a: static_cast<int>(getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0); |
179 | unsigned OtherPureExcessVGPR = |
180 | std::max( |
181 | a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) - MaxVGPRs), |
182 | b: 0) + |
183 | std::max(a: static_cast<int>(O.getVGPRNum(UnifiedVGPRFile: false) - MaxArchVGPRs), b: 0); |
184 | |
185 | // If we have a special case where there is a tie in excess VGPR, but one |
186 | // of the pressures has VGPR usage from SGPR spills, prefer the pressure |
187 | // with SGPR spills. |
188 | if (PureExcessVGPR != OtherPureExcessVGPR) |
189 | return SGPRDiff < 0; |
190 | // If both pressures have the same excess pressure before and after |
191 | // accounting for SGPR spills, prefer fewer SGPR spills. |
192 | return SGPRDiff > 0; |
193 | } |
194 | } |
195 | |
196 | bool SGPRImportant = SGPROcc < VGPROcc; |
197 | const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc; |
198 | |
199 | // If both pressures disagree on what is more important compare vgprs. |
200 | if (SGPRImportant != OtherSGPRImportant) { |
201 | SGPRImportant = false; |
202 | } |
203 | |
204 | // Give third precedence to lower register tuple pressure. |
205 | bool SGPRFirst = SGPRImportant; |
206 | for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) { |
207 | if (SGPRFirst) { |
208 | auto SW = getSGPRTuplesWeight(); |
209 | auto OtherSW = O.getSGPRTuplesWeight(); |
210 | if (SW != OtherSW) |
211 | return SW < OtherSW; |
212 | } else { |
213 | auto VW = getVGPRTuplesWeight(); |
214 | auto OtherVW = O.getVGPRTuplesWeight(); |
215 | if (VW != OtherVW) |
216 | return VW < OtherVW; |
217 | } |
218 | } |
219 | |
220 | // Give final precedence to lower general RP. |
221 | return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()): |
222 | (getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts()) < |
223 | O.getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts())); |
224 | } |
225 | |
226 | Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) { |
227 | return Printable([&RP, ST](raw_ostream &OS) { |
228 | OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' ' |
229 | << "AGPRs: " << RP.getAGPRNum(); |
230 | if (ST) |
231 | OS << "(O" |
232 | << ST->getOccupancyWithNumVGPRs(VGPRs: RP.getVGPRNum(UnifiedVGPRFile: ST->hasGFX90AInsts())) |
233 | << ')'; |
234 | OS << ", SGPRs: " << RP.getSGPRNum(); |
235 | if (ST) |
236 | OS << "(O" << ST->getOccupancyWithNumSGPRs(SGPRs: RP.getSGPRNum()) << ')'; |
237 | OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight() |
238 | << ", LSGPR WT: " << RP.getSGPRTuplesWeight(); |
239 | if (ST) |
240 | OS << " -> Occ: " << RP.getOccupancy(ST: *ST); |
241 | OS << '\n'; |
242 | }); |
243 | } |
244 | |
245 | static LaneBitmask getDefRegMask(const MachineOperand &MO, |
246 | const MachineRegisterInfo &MRI) { |
247 | assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual()); |
248 | |
249 | // We don't rely on read-undef flag because in case of tentative schedule |
250 | // tracking it isn't set correctly yet. This works correctly however since |
251 | // use mask has been tracked before using LIS. |
252 | return MO.getSubReg() == 0 ? |
253 | MRI.getMaxLaneMaskForVReg(Reg: MO.getReg()) : |
254 | MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubIdx: MO.getSubReg()); |
255 | } |
256 | |
257 | static void |
258 | collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs, |
259 | const MachineInstr &MI, const LiveIntervals &LIS, |
260 | const MachineRegisterInfo &MRI) { |
261 | SlotIndex InstrSI; |
262 | for (const auto &MO : MI.operands()) { |
263 | if (!MO.isReg() || !MO.getReg().isVirtual()) |
264 | continue; |
265 | if (!MO.isUse() || !MO.readsReg()) |
266 | continue; |
267 | |
268 | Register Reg = MO.getReg(); |
269 | if (llvm::any_of(Range&: RegMaskPairs, P: [Reg](const RegisterMaskPair &RM) { |
270 | return RM.RegUnit == Reg; |
271 | })) |
272 | continue; |
273 | |
274 | LaneBitmask UseMask; |
275 | auto &LI = LIS.getInterval(Reg); |
276 | if (!LI.hasSubRanges()) |
277 | UseMask = MRI.getMaxLaneMaskForVReg(Reg); |
278 | else { |
279 | // For a tentative schedule LIS isn't updated yet but livemask should |
280 | // remain the same on any schedule. Subreg defs can be reordered but they |
281 | // all must dominate uses anyway. |
282 | if (!InstrSI) |
283 | InstrSI = LIS.getInstructionIndex(Instr: *MO.getParent()).getBaseIndex(); |
284 | UseMask = getLiveLaneMask(LI, SI: InstrSI, MRI); |
285 | } |
286 | |
287 | RegMaskPairs.emplace_back(Args&: Reg, Args&: UseMask); |
288 | } |
289 | } |
290 | |
291 | /////////////////////////////////////////////////////////////////////////////// |
292 | // GCNRPTracker |
293 | |
294 | LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI, |
295 | const LiveIntervals &LIS, |
296 | const MachineRegisterInfo &MRI) { |
297 | return getLiveLaneMask(LI: LIS.getInterval(Reg), SI, MRI); |
298 | } |
299 | |
300 | LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI, |
301 | const MachineRegisterInfo &MRI) { |
302 | LaneBitmask LiveMask; |
303 | if (LI.hasSubRanges()) { |
304 | for (const auto &S : LI.subranges()) |
305 | if (S.liveAt(index: SI)) { |
306 | LiveMask |= S.LaneMask; |
307 | assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg()))); |
308 | } |
309 | } else if (LI.liveAt(index: SI)) { |
310 | LiveMask = MRI.getMaxLaneMaskForVReg(Reg: LI.reg()); |
311 | } |
312 | return LiveMask; |
313 | } |
314 | |
315 | GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI, |
316 | const LiveIntervals &LIS, |
317 | const MachineRegisterInfo &MRI) { |
318 | GCNRPTracker::LiveRegSet LiveRegs; |
319 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
320 | auto Reg = Register::index2VirtReg(Index: I); |
321 | if (!LIS.hasInterval(Reg)) |
322 | continue; |
323 | auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI); |
324 | if (LiveMask.any()) |
325 | LiveRegs[Reg] = LiveMask; |
326 | } |
327 | return LiveRegs; |
328 | } |
329 | |
330 | void GCNRPTracker::reset(const MachineInstr &MI, |
331 | const LiveRegSet *LiveRegsCopy, |
332 | bool After) { |
333 | const MachineFunction &MF = *MI.getMF(); |
334 | MRI = &MF.getRegInfo(); |
335 | if (LiveRegsCopy) { |
336 | if (&LiveRegs != LiveRegsCopy) |
337 | LiveRegs = *LiveRegsCopy; |
338 | } else { |
339 | LiveRegs = After ? getLiveRegsAfter(MI, LIS) |
340 | : getLiveRegsBefore(MI, LIS); |
341 | } |
342 | |
343 | MaxPressure = CurPressure = getRegPressure(MRI: *MRI, LiveRegs); |
344 | } |
345 | |
346 | //////////////////////////////////////////////////////////////////////////////// |
347 | // GCNUpwardRPTracker |
348 | |
349 | void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_, |
350 | const LiveRegSet &LiveRegs_) { |
351 | MRI = &MRI_; |
352 | LiveRegs = LiveRegs_; |
353 | LastTrackedMI = nullptr; |
354 | MaxPressure = CurPressure = getRegPressure(MRI: MRI_, LiveRegs: LiveRegs_); |
355 | } |
356 | |
357 | void GCNUpwardRPTracker::recede(const MachineInstr &MI) { |
358 | assert(MRI && "call reset first" ); |
359 | |
360 | LastTrackedMI = &MI; |
361 | |
362 | if (MI.isDebugInstr()) |
363 | return; |
364 | |
365 | // Kill all defs. |
366 | GCNRegPressure DefPressure, ECDefPressure; |
367 | bool HasECDefs = false; |
368 | for (const MachineOperand &MO : MI.all_defs()) { |
369 | if (!MO.getReg().isVirtual()) |
370 | continue; |
371 | |
372 | Register Reg = MO.getReg(); |
373 | LaneBitmask DefMask = getDefRegMask(MO, MRI: *MRI); |
374 | |
375 | // Treat a def as fully live at the moment of definition: keep a record. |
376 | if (MO.isEarlyClobber()) { |
377 | ECDefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI); |
378 | HasECDefs = true; |
379 | } else |
380 | DefPressure.inc(Reg, PrevMask: LaneBitmask::getNone(), NewMask: DefMask, MRI: *MRI); |
381 | |
382 | auto I = LiveRegs.find(Val: Reg); |
383 | if (I == LiveRegs.end()) |
384 | continue; |
385 | |
386 | LaneBitmask &LiveMask = I->second; |
387 | LaneBitmask PrevMask = LiveMask; |
388 | LiveMask &= ~DefMask; |
389 | CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI); |
390 | if (LiveMask.none()) |
391 | LiveRegs.erase(I); |
392 | } |
393 | |
394 | // Update MaxPressure with defs pressure. |
395 | DefPressure += CurPressure; |
396 | if (HasECDefs) |
397 | DefPressure += ECDefPressure; |
398 | MaxPressure = max(P1: DefPressure, P2: MaxPressure); |
399 | |
400 | // Make uses alive. |
401 | SmallVector<RegisterMaskPair, 8> RegUses; |
402 | collectVirtualRegUses(RegMaskPairs&: RegUses, MI, LIS, MRI: *MRI); |
403 | for (const RegisterMaskPair &U : RegUses) { |
404 | LaneBitmask &LiveMask = LiveRegs[U.RegUnit]; |
405 | LaneBitmask PrevMask = LiveMask; |
406 | LiveMask |= U.LaneMask; |
407 | CurPressure.inc(Reg: U.RegUnit, PrevMask, NewMask: LiveMask, MRI: *MRI); |
408 | } |
409 | |
410 | // Update MaxPressure with uses plus early-clobber defs pressure. |
411 | MaxPressure = HasECDefs ? max(P1: CurPressure + ECDefPressure, P2: MaxPressure) |
412 | : max(P1: CurPressure, P2: MaxPressure); |
413 | |
414 | assert(CurPressure == getRegPressure(*MRI, LiveRegs)); |
415 | } |
416 | |
417 | //////////////////////////////////////////////////////////////////////////////// |
418 | // GCNDownwardRPTracker |
419 | |
420 | bool GCNDownwardRPTracker::reset(const MachineInstr &MI, |
421 | const LiveRegSet *LiveRegsCopy) { |
422 | MRI = &MI.getParent()->getParent()->getRegInfo(); |
423 | LastTrackedMI = nullptr; |
424 | MBBEnd = MI.getParent()->end(); |
425 | NextMI = &MI; |
426 | NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd); |
427 | if (NextMI == MBBEnd) |
428 | return false; |
429 | GCNRPTracker::reset(MI: *NextMI, LiveRegsCopy, After: false); |
430 | return true; |
431 | } |
432 | |
433 | bool GCNDownwardRPTracker::advanceBeforeNext() { |
434 | assert(MRI && "call reset first" ); |
435 | if (!LastTrackedMI) |
436 | return NextMI == MBBEnd; |
437 | |
438 | assert(NextMI == MBBEnd || !NextMI->isDebugInstr()); |
439 | |
440 | SlotIndex SI = NextMI == MBBEnd |
441 | ? LIS.getInstructionIndex(Instr: *LastTrackedMI).getDeadSlot() |
442 | : LIS.getInstructionIndex(Instr: *NextMI).getBaseIndex(); |
443 | assert(SI.isValid()); |
444 | |
445 | // Remove dead registers or mask bits. |
446 | SmallSet<Register, 8> SeenRegs; |
447 | for (auto &MO : LastTrackedMI->operands()) { |
448 | if (!MO.isReg() || !MO.getReg().isVirtual()) |
449 | continue; |
450 | if (MO.isUse() && !MO.readsReg()) |
451 | continue; |
452 | if (!SeenRegs.insert(V: MO.getReg()).second) |
453 | continue; |
454 | const LiveInterval &LI = LIS.getInterval(Reg: MO.getReg()); |
455 | if (LI.hasSubRanges()) { |
456 | auto It = LiveRegs.end(); |
457 | for (const auto &S : LI.subranges()) { |
458 | if (!S.liveAt(index: SI)) { |
459 | if (It == LiveRegs.end()) { |
460 | It = LiveRegs.find(Val: MO.getReg()); |
461 | if (It == LiveRegs.end()) |
462 | llvm_unreachable("register isn't live" ); |
463 | } |
464 | auto PrevMask = It->second; |
465 | It->second &= ~S.LaneMask; |
466 | CurPressure.inc(Reg: MO.getReg(), PrevMask, NewMask: It->second, MRI: *MRI); |
467 | } |
468 | } |
469 | if (It != LiveRegs.end() && It->second.none()) |
470 | LiveRegs.erase(I: It); |
471 | } else if (!LI.liveAt(index: SI)) { |
472 | auto It = LiveRegs.find(Val: MO.getReg()); |
473 | if (It == LiveRegs.end()) |
474 | llvm_unreachable("register isn't live" ); |
475 | CurPressure.inc(Reg: MO.getReg(), PrevMask: It->second, NewMask: LaneBitmask::getNone(), MRI: *MRI); |
476 | LiveRegs.erase(I: It); |
477 | } |
478 | } |
479 | |
480 | MaxPressure = max(P1: MaxPressure, P2: CurPressure); |
481 | |
482 | LastTrackedMI = nullptr; |
483 | |
484 | return NextMI == MBBEnd; |
485 | } |
486 | |
487 | void GCNDownwardRPTracker::advanceToNext() { |
488 | LastTrackedMI = &*NextMI++; |
489 | NextMI = skipDebugInstructionsForward(It: NextMI, End: MBBEnd); |
490 | |
491 | // Add new registers or mask bits. |
492 | for (const auto &MO : LastTrackedMI->all_defs()) { |
493 | Register Reg = MO.getReg(); |
494 | if (!Reg.isVirtual()) |
495 | continue; |
496 | auto &LiveMask = LiveRegs[Reg]; |
497 | auto PrevMask = LiveMask; |
498 | LiveMask |= getDefRegMask(MO, MRI: *MRI); |
499 | CurPressure.inc(Reg, PrevMask, NewMask: LiveMask, MRI: *MRI); |
500 | } |
501 | |
502 | MaxPressure = max(P1: MaxPressure, P2: CurPressure); |
503 | } |
504 | |
505 | bool GCNDownwardRPTracker::advance() { |
506 | if (NextMI == MBBEnd) |
507 | return false; |
508 | advanceBeforeNext(); |
509 | advanceToNext(); |
510 | return true; |
511 | } |
512 | |
513 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) { |
514 | while (NextMI != End) |
515 | if (!advance()) return false; |
516 | return true; |
517 | } |
518 | |
519 | bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin, |
520 | MachineBasicBlock::const_iterator End, |
521 | const LiveRegSet *LiveRegsCopy) { |
522 | reset(MI: *Begin, LiveRegsCopy); |
523 | return advance(End); |
524 | } |
525 | |
526 | Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR, |
527 | const GCNRPTracker::LiveRegSet &TrackedLR, |
528 | const TargetRegisterInfo *TRI, StringRef Pfx) { |
529 | return Printable([&LISLR, &TrackedLR, TRI, Pfx](raw_ostream &OS) { |
530 | for (auto const &P : TrackedLR) { |
531 | auto I = LISLR.find(Val: P.first); |
532 | if (I == LISLR.end()) { |
533 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
534 | << " isn't found in LIS reported set\n" ; |
535 | } else if (I->second != P.second) { |
536 | OS << Pfx << printReg(Reg: P.first, TRI) |
537 | << " masks doesn't match: LIS reported " << PrintLaneMask(LaneMask: I->second) |
538 | << ", tracked " << PrintLaneMask(LaneMask: P.second) << '\n'; |
539 | } |
540 | } |
541 | for (auto const &P : LISLR) { |
542 | auto I = TrackedLR.find(Val: P.first); |
543 | if (I == TrackedLR.end()) { |
544 | OS << Pfx << printReg(Reg: P.first, TRI) << ":L" << PrintLaneMask(LaneMask: P.second) |
545 | << " isn't found in tracked set\n" ; |
546 | } |
547 | } |
548 | }); |
549 | } |
550 | |
551 | bool GCNUpwardRPTracker::isValid() const { |
552 | const auto &SI = LIS.getInstructionIndex(Instr: *LastTrackedMI).getBaseIndex(); |
553 | const auto LISLR = llvm::getLiveRegs(SI, LIS, MRI: *MRI); |
554 | const auto &TrackedLR = LiveRegs; |
555 | |
556 | if (!isEqual(S1: LISLR, S2: TrackedLR)) { |
557 | dbgs() << "\nGCNUpwardRPTracker error: Tracked and" |
558 | " LIS reported livesets mismatch:\n" |
559 | << print(LiveRegs: LISLR, MRI: *MRI); |
560 | reportMismatch(LISLR, TrackedLR, TRI: MRI->getTargetRegisterInfo()); |
561 | return false; |
562 | } |
563 | |
564 | auto LISPressure = getRegPressure(MRI: *MRI, LiveRegs: LISLR); |
565 | if (LISPressure != CurPressure) { |
566 | dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: " |
567 | << print(RP: CurPressure) << "LIS rpt: " << print(RP: LISPressure); |
568 | return false; |
569 | } |
570 | return true; |
571 | } |
572 | |
573 | Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs, |
574 | const MachineRegisterInfo &MRI) { |
575 | return Printable([&LiveRegs, &MRI](raw_ostream &OS) { |
576 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
577 | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { |
578 | Register Reg = Register::index2VirtReg(Index: I); |
579 | auto It = LiveRegs.find(Val: Reg); |
580 | if (It != LiveRegs.end() && It->second.any()) |
581 | OS << ' ' << printVRegOrUnit(VRegOrUnit: Reg, TRI) << ':' |
582 | << PrintLaneMask(LaneMask: It->second); |
583 | } |
584 | OS << '\n'; |
585 | }); |
586 | } |
587 | |
588 | void GCNRegPressure::dump() const { dbgs() << print(RP: *this); } |
589 | |
// Hidden boolean command-line option, false by default. When set,
// GCNRegPressurePrinter::runOnMachineFunction walks each block with
// GCNDownwardRPTracker; otherwise it uses GCNUpwardRPTracker.
590 | static cl::opt<bool> UseDownwardTracker( |
591 | "amdgpu-print-rp-downward" , |
592 | cl::desc("Use GCNDownwardRPTracker for GCNRegPressurePrinter pass" ), |
593 | cl::init(Val: false), cl::Hidden); |
594 | |
// Storage for the pass ID and the reference through which it is exposed.
595 | char llvm::GCNRegPressurePrinter::ID = 0; |
596 | char &llvm::GCNRegPressurePrinterID = GCNRegPressurePrinter::ID; |
597 | |
// Registers the printer pass under the argument "amdgpu-print-rp" with an
// empty description.
// NOTE(review): the two trailing `true` flags are presumably the macro's
// CFG-only and is-analysis arguments - confirm against INITIALIZE_PASS.
598 | INITIALIZE_PASS(GCNRegPressurePrinter, "amdgpu-print-rp" , "" , true, true) |
599 | |
600 | // Return lanemask of Reg's subregs that are live-through at [Begin, End] and |
601 | // are fully covered by Mask. |
602 | static LaneBitmask |
603 | getRegLiveThroughMask(const MachineRegisterInfo &MRI, const LiveIntervals &LIS, |
604 | Register Reg, SlotIndex Begin, SlotIndex End, |
605 | LaneBitmask Mask = LaneBitmask::getAll()) { |
606 | |
607 | auto IsInOneSegment = [Begin, End](const LiveRange &LR) -> bool { |
608 | auto *Segment = LR.getSegmentContaining(Idx: Begin); |
609 | return Segment && Segment->contains(I: End); |
610 | }; |
611 | |
612 | LaneBitmask LiveThroughMask; |
613 | const LiveInterval &LI = LIS.getInterval(Reg); |
614 | if (LI.hasSubRanges()) { |
615 | for (auto &SR : LI.subranges()) { |
616 | if ((SR.LaneMask & Mask) == SR.LaneMask && IsInOneSegment(SR)) |
617 | LiveThroughMask |= SR.LaneMask; |
618 | } |
619 | } else { |
620 | LaneBitmask RegMask = MRI.getMaxLaneMaskForVReg(Reg); |
621 | if ((RegMask & Mask) == RegMask && IsInOneSegment(LI)) |
622 | LiveThroughMask = RegMask; |
623 | } |
624 | |
625 | return LiveThroughMask; |
626 | } |
627 | |
628 | bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) { |
629 | const MachineRegisterInfo &MRI = MF.getRegInfo(); |
630 | const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo(); |
631 | const LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); |
632 | |
633 | auto &OS = dbgs(); |
634 | |
635 | // Leading spaces are important for YAML syntax. |
636 | #define PFX " " |
637 | |
638 | OS << "---\nname: " << MF.getName() << "\nbody: |\n" ; |
639 | |
640 | auto printRP = [](const GCNRegPressure &RP) { |
641 | return Printable([&RP](raw_ostream &OS) { |
642 | OS << format(PFX " %-5d" , Vals: RP.getSGPRNum()) |
643 | << format(Fmt: " %-5d" , Vals: RP.getVGPRNum(UnifiedVGPRFile: false)); |
644 | }); |
645 | }; |
646 | |
647 | auto ReportLISMismatchIfAny = [&](const GCNRPTracker::LiveRegSet &TrackedLR, |
648 | const GCNRPTracker::LiveRegSet &LISLR) { |
649 | if (LISLR != TrackedLR) { |
650 | OS << PFX " mis LIS: " << llvm::print(LiveRegs: LISLR, MRI) |
651 | << reportMismatch(LISLR, TrackedLR, TRI, PFX " " ); |
652 | } |
653 | }; |
654 | |
655 | // Register pressure before and at an instruction (in program order). |
656 | SmallVector<std::pair<GCNRegPressure, GCNRegPressure>, 16> RP; |
657 | |
658 | for (auto &MBB : MF) { |
659 | RP.clear(); |
660 | RP.reserve(N: MBB.size()); |
661 | |
662 | OS << PFX; |
663 | MBB.printName(os&: OS); |
664 | OS << ":\n" ; |
665 | |
666 | SlotIndex MBBStartSlot = LIS.getSlotIndexes()->getMBBStartIdx(mbb: &MBB); |
667 | SlotIndex MBBEndSlot = LIS.getSlotIndexes()->getMBBEndIdx(mbb: &MBB); |
668 | |
669 | GCNRPTracker::LiveRegSet LiveIn, LiveOut; |
670 | GCNRegPressure RPAtMBBEnd; |
671 | |
672 | if (UseDownwardTracker) { |
673 | if (MBB.empty()) { |
674 | LiveIn = LiveOut = getLiveRegs(SI: MBBStartSlot, LIS, MRI); |
675 | RPAtMBBEnd = getRegPressure(MRI, LiveRegs&: LiveIn); |
676 | } else { |
677 | GCNDownwardRPTracker RPT(LIS); |
678 | RPT.reset(MI: MBB.front()); |
679 | |
680 | LiveIn = RPT.getLiveRegs(); |
681 | |
682 | while (!RPT.advanceBeforeNext()) { |
683 | GCNRegPressure RPBeforeMI = RPT.getPressure(); |
684 | RPT.advanceToNext(); |
685 | RP.emplace_back(Args&: RPBeforeMI, Args: RPT.getPressure()); |
686 | } |
687 | |
688 | LiveOut = RPT.getLiveRegs(); |
689 | RPAtMBBEnd = RPT.getPressure(); |
690 | } |
691 | } else { |
692 | GCNUpwardRPTracker RPT(LIS); |
693 | RPT.reset(MRI, SI: MBBEndSlot); |
694 | |
695 | LiveOut = RPT.getLiveRegs(); |
696 | RPAtMBBEnd = RPT.getPressure(); |
697 | |
698 | for (auto &MI : reverse(C&: MBB)) { |
699 | RPT.resetMaxPressure(); |
700 | RPT.recede(MI); |
701 | if (!MI.isDebugInstr()) |
702 | RP.emplace_back(Args: RPT.getPressure(), Args: RPT.getMaxPressure()); |
703 | } |
704 | |
705 | LiveIn = RPT.getLiveRegs(); |
706 | } |
707 | |
708 | OS << PFX " Live-in: " << llvm::print(LiveRegs: LiveIn, MRI); |
709 | if (!UseDownwardTracker) |
710 | ReportLISMismatchIfAny(LiveIn, getLiveRegs(SI: MBBStartSlot, LIS, MRI)); |
711 | |
712 | OS << PFX " SGPR VGPR\n" ; |
713 | int I = 0; |
714 | for (auto &MI : MBB) { |
715 | if (!MI.isDebugInstr()) { |
716 | auto &[RPBeforeInstr, RPAtInstr] = |
717 | RP[UseDownwardTracker ? I : (RP.size() - 1 - I)]; |
718 | ++I; |
719 | OS << printRP(RPBeforeInstr) << '\n' << printRP(RPAtInstr) << " " ; |
720 | } else |
721 | OS << PFX " " ; |
722 | MI.print(OS); |
723 | } |
724 | OS << printRP(RPAtMBBEnd) << '\n'; |
725 | |
726 | OS << PFX " Live-out:" << llvm::print(LiveRegs: LiveOut, MRI); |
727 | if (UseDownwardTracker) |
728 | ReportLISMismatchIfAny(LiveOut, getLiveRegs(SI: MBBEndSlot, LIS, MRI)); |
729 | |
730 | GCNRPTracker::LiveRegSet LiveThrough; |
731 | for (auto [Reg, Mask] : LiveIn) { |
732 | LaneBitmask MaskIntersection = Mask & LiveOut.lookup(Val: Reg); |
733 | if (MaskIntersection.any()) { |
734 | LaneBitmask LTMask = getRegLiveThroughMask( |
735 | MRI, LIS, Reg, Begin: MBBStartSlot, End: MBBEndSlot, Mask: MaskIntersection); |
736 | if (LTMask.any()) |
737 | LiveThrough[Reg] = LTMask; |
738 | } |
739 | } |
740 | OS << PFX " Live-thr:" << llvm::print(LiveRegs: LiveThrough, MRI); |
741 | OS << printRP(getRegPressure(MRI, LiveRegs&: LiveThrough)) << '\n'; |
742 | } |
743 | OS << "...\n" ; |
744 | return false; |
745 | |
746 | #undef PFX |
747 | } |