1//===- GCNRegPressure.h -----------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
/// This file defines the GCNRegPressure class, which tracks register pressure
/// by bookkeeping the number of SGPRs/VGPRs used and the weights of large
/// (tuple) SGPRs/VGPRs. It also implements a compare function, which compares
/// different register pressures and declares the one with max occupancy the
/// winner.
14///
15//===----------------------------------------------------------------------===//
16
17#ifndef LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
18#define LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
19
20#include "GCNSubtarget.h"
21#include "llvm/CodeGen/LiveIntervals.h"
22#include <algorithm>
23
24namespace llvm {
25
26class MachineRegisterInfo;
27class raw_ostream;
28class SlotIndex;
29
30struct GCNRegPressure {
31 enum RegKind {
32 SGPR32,
33 SGPR_TUPLE,
34 VGPR32,
35 VGPR_TUPLE,
36 AGPR32,
37 AGPR_TUPLE,
38 TOTAL_KINDS
39 };
40
41 GCNRegPressure() {
42 clear();
43 }
44
45 bool empty() const { return getSGPRNum() == 0 && getVGPRNum(UnifiedVGPRFile: false) == 0; }
46
47 void clear() { std::fill(first: &Value[0], last: &Value[TOTAL_KINDS], value: 0); }
48
49 unsigned getSGPRNum() const { return Value[SGPR32]; }
50 unsigned getVGPRNum(bool UnifiedVGPRFile) const {
51 if (UnifiedVGPRFile) {
52 return Value[AGPR32] ? alignTo(Value: Value[VGPR32], Align: 4) + Value[AGPR32]
53 : Value[VGPR32] + Value[AGPR32];
54 }
55 return std::max(a: Value[VGPR32], b: Value[AGPR32]);
56 }
57 unsigned getAGPRNum() const { return Value[AGPR32]; }
58
59 unsigned getVGPRTuplesWeight() const { return std::max(a: Value[VGPR_TUPLE],
60 b: Value[AGPR_TUPLE]); }
61 unsigned getSGPRTuplesWeight() const { return Value[SGPR_TUPLE]; }
62
63 unsigned getOccupancy(const GCNSubtarget &ST) const {
64 return std::min(a: ST.getOccupancyWithNumSGPRs(SGPRs: getSGPRNum()),
65 b: ST.getOccupancyWithNumVGPRs(VGPRs: getVGPRNum(UnifiedVGPRFile: ST.hasGFX90AInsts())));
66 }
67
68 void inc(unsigned Reg,
69 LaneBitmask PrevMask,
70 LaneBitmask NewMask,
71 const MachineRegisterInfo &MRI);
72
73 bool higherOccupancy(const GCNSubtarget &ST, const GCNRegPressure& O) const {
74 return getOccupancy(ST) > O.getOccupancy(ST);
75 }
76
77 /// Compares \p this GCNRegpressure to \p O, returning true if \p this is
78 /// less. Since GCNRegpressure contains different types of pressures, and due
79 /// to target-specific pecularities (e.g. we care about occupancy rather than
80 /// raw register usage), we determine if \p this GCNRegPressure is less than
81 /// \p O based on the following tiered comparisons (in order order of
82 /// precedence):
83 /// 1. Better occupancy
84 /// 2. Less spilling (first preference to VGPR spills, then to SGPR spills)
85 /// 3. Less tuple register pressure (first preference to VGPR tuples if we
86 /// determine that SGPR pressure is not important)
87 /// 4. Less raw register pressure (first preference to VGPR tuples if we
88 /// determine that SGPR pressure is not important)
89 bool less(const MachineFunction &MF, const GCNRegPressure &O,
90 unsigned MaxOccupancy = std::numeric_limits<unsigned>::max()) const;
91
92 bool operator==(const GCNRegPressure &O) const {
93 return std::equal(first1: &Value[0], last1: &Value[TOTAL_KINDS], first2: O.Value);
94 }
95
96 bool operator!=(const GCNRegPressure &O) const {
97 return !(*this == O);
98 }
99
100 GCNRegPressure &operator+=(const GCNRegPressure &RHS) {
101 for (unsigned I = 0; I < TOTAL_KINDS; ++I)
102 Value[I] += RHS.Value[I];
103 return *this;
104 }
105
106 GCNRegPressure &operator-=(const GCNRegPressure &RHS) {
107 for (unsigned I = 0; I < TOTAL_KINDS; ++I)
108 Value[I] -= RHS.Value[I];
109 return *this;
110 }
111
112 void dump() const;
113
114private:
115 unsigned Value[TOTAL_KINDS];
116
117 static unsigned getRegKind(Register Reg, const MachineRegisterInfo &MRI);
118
119 friend GCNRegPressure max(const GCNRegPressure &P1,
120 const GCNRegPressure &P2);
121
122 friend Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST);
123};
124
125inline GCNRegPressure max(const GCNRegPressure &P1, const GCNRegPressure &P2) {
126 GCNRegPressure Res;
127 for (unsigned I = 0; I < GCNRegPressure::TOTAL_KINDS; ++I)
128 Res.Value[I] = std::max(a: P1.Value[I], b: P2.Value[I]);
129 return Res;
130}
131
132inline GCNRegPressure operator+(const GCNRegPressure &P1,
133 const GCNRegPressure &P2) {
134 GCNRegPressure Sum = P1;
135 Sum += P2;
136 return Sum;
137}
138
139inline GCNRegPressure operator-(const GCNRegPressure &P1,
140 const GCNRegPressure &P2) {
141 GCNRegPressure Diff = P1;
142 Diff -= P2;
143 return Diff;
144}
145
146class GCNRPTracker {
147public:
148 using LiveRegSet = DenseMap<unsigned, LaneBitmask>;
149
150protected:
151 const LiveIntervals &LIS;
152 LiveRegSet LiveRegs;
153 GCNRegPressure CurPressure, MaxPressure;
154 const MachineInstr *LastTrackedMI = nullptr;
155 mutable const MachineRegisterInfo *MRI = nullptr;
156
157 GCNRPTracker(const LiveIntervals &LIS_) : LIS(LIS_) {}
158
159 void reset(const MachineInstr &MI, const LiveRegSet *LiveRegsCopy,
160 bool After);
161
162public:
163 // live regs for the current state
164 const decltype(LiveRegs) &getLiveRegs() const { return LiveRegs; }
165 const MachineInstr *getLastTrackedMI() const { return LastTrackedMI; }
166
167 void clearMaxPressure() { MaxPressure.clear(); }
168
169 GCNRegPressure getPressure() const { return CurPressure; }
170
171 decltype(LiveRegs) moveLiveRegs() {
172 return std::move(LiveRegs);
173 }
174};
175
176GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
177 const MachineRegisterInfo &MRI);
178
179class GCNUpwardRPTracker : public GCNRPTracker {
180public:
181 GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
182
183 // reset tracker and set live register set to the specified value.
184 void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
185
186 // reset tracker at the specified slot index.
187 void reset(const MachineRegisterInfo &MRI, SlotIndex SI) {
188 reset(MRI_: MRI, LiveRegs_: llvm::getLiveRegs(SI, LIS, MRI));
189 }
190
191 // reset tracker to the end of the MBB.
192 void reset(const MachineBasicBlock &MBB) {
193 reset(MRI: MBB.getParent()->getRegInfo(),
194 SI: LIS.getSlotIndexes()->getMBBEndIdx(mbb: &MBB));
195 }
196
197 // reset tracker to the point just after MI (in program order).
198 void reset(const MachineInstr &MI) {
199 reset(MRI: MI.getMF()->getRegInfo(), SI: LIS.getInstructionIndex(Instr: MI).getDeadSlot());
200 }
201
202 // move to the state just before the MI (in program order).
203 void recede(const MachineInstr &MI);
204
205 // checks whether the tracker's state after receding MI corresponds
206 // to reported by LIS.
207 bool isValid() const;
208
209 const GCNRegPressure &getMaxPressure() const { return MaxPressure; }
210
211 void resetMaxPressure() { MaxPressure = CurPressure; }
212
213 GCNRegPressure getMaxPressureAndReset() {
214 GCNRegPressure RP = MaxPressure;
215 resetMaxPressure();
216 return RP;
217 }
218};
219
220class GCNDownwardRPTracker : public GCNRPTracker {
221 // Last position of reset or advanceBeforeNext
222 MachineBasicBlock::const_iterator NextMI;
223
224 MachineBasicBlock::const_iterator MBBEnd;
225
226public:
227 GCNDownwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
228
229 MachineBasicBlock::const_iterator getNext() const { return NextMI; }
230
231 // Return MaxPressure and clear it.
232 GCNRegPressure moveMaxPressure() {
233 auto Res = MaxPressure;
234 MaxPressure.clear();
235 return Res;
236 }
237
238 // Reset tracker to the point before the MI
239 // filling live regs upon this point using LIS.
240 // Returns false if block is empty except debug values.
241 bool reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
242
243 // Move to the state right before the next MI or after the end of MBB.
244 // Returns false if reached end of the block.
245 bool advanceBeforeNext();
246
247 // Move to the state at the MI, advanceBeforeNext has to be called first.
248 void advanceToNext();
249
250 // Move to the state at the next MI. Returns false if reached end of block.
251 bool advance();
252
253 // Advance instructions until before End.
254 bool advance(MachineBasicBlock::const_iterator End);
255
256 // Reset to Begin and advance to End.
257 bool advance(MachineBasicBlock::const_iterator Begin,
258 MachineBasicBlock::const_iterator End,
259 const LiveRegSet *LiveRegsCopy = nullptr);
260};
261
262LaneBitmask getLiveLaneMask(unsigned Reg,
263 SlotIndex SI,
264 const LiveIntervals &LIS,
265 const MachineRegisterInfo &MRI);
266
267LaneBitmask getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
268 const MachineRegisterInfo &MRI);
269
270GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
271 const MachineRegisterInfo &MRI);
272
273/// creates a map MachineInstr -> LiveRegSet
274/// R - range of iterators on instructions
275/// After - upon entry or exit of every instruction
276/// Note: there is no entry in the map for instructions with empty live reg set
277/// Complexity = O(NumVirtRegs * averageLiveRangeSegmentsPerReg * lg(R))
278template <typename Range>
279DenseMap<MachineInstr*, GCNRPTracker::LiveRegSet>
280getLiveRegMap(Range &&R, bool After, LiveIntervals &LIS) {
281 std::vector<SlotIndex> Indexes;
282 Indexes.reserve(n: std::distance(R.begin(), R.end()));
283 auto &SII = *LIS.getSlotIndexes();
284 for (MachineInstr *I : R) {
285 auto SI = SII.getInstructionIndex(MI: *I);
286 Indexes.push_back(x: After ? SI.getDeadSlot() : SI.getBaseIndex());
287 }
288 llvm::sort(C&: Indexes);
289
290 auto &MRI = (*R.begin())->getParent()->getParent()->getRegInfo();
291 DenseMap<MachineInstr *, GCNRPTracker::LiveRegSet> LiveRegMap;
292 SmallVector<SlotIndex, 32> LiveIdxs, SRLiveIdxs;
293 for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
294 auto Reg = Register::index2VirtReg(Index: I);
295 if (!LIS.hasInterval(Reg))
296 continue;
297 auto &LI = LIS.getInterval(Reg);
298 LiveIdxs.clear();
299 if (!LI.findIndexesLiveAt(R&: Indexes, O: std::back_inserter(x&: LiveIdxs)))
300 continue;
301 if (!LI.hasSubRanges()) {
302 for (auto SI : LiveIdxs)
303 LiveRegMap[SII.getInstructionFromIndex(index: SI)][Reg] =
304 MRI.getMaxLaneMaskForVReg(Reg);
305 } else
306 for (const auto &S : LI.subranges()) {
307 // constrain search for subranges by indexes live at main range
308 SRLiveIdxs.clear();
309 S.findIndexesLiveAt(R&: LiveIdxs, O: std::back_inserter(x&: SRLiveIdxs));
310 for (auto SI : SRLiveIdxs)
311 LiveRegMap[SII.getInstructionFromIndex(index: SI)][Reg] |= S.LaneMask;
312 }
313 }
314 return LiveRegMap;
315}
316
317inline GCNRPTracker::LiveRegSet getLiveRegsAfter(const MachineInstr &MI,
318 const LiveIntervals &LIS) {
319 return getLiveRegs(SI: LIS.getInstructionIndex(Instr: MI).getDeadSlot(), LIS,
320 MRI: MI.getParent()->getParent()->getRegInfo());
321}
322
323inline GCNRPTracker::LiveRegSet getLiveRegsBefore(const MachineInstr &MI,
324 const LiveIntervals &LIS) {
325 return getLiveRegs(SI: LIS.getInstructionIndex(Instr: MI).getBaseIndex(), LIS,
326 MRI: MI.getParent()->getParent()->getRegInfo());
327}
328
329template <typename Range>
330GCNRegPressure getRegPressure(const MachineRegisterInfo &MRI,
331 Range &&LiveRegs) {
332 GCNRegPressure Res;
333 for (const auto &RM : LiveRegs)
334 Res.inc(Reg: RM.first, PrevMask: LaneBitmask::getNone(), NewMask: RM.second, MRI);
335 return Res;
336}
337
338bool isEqual(const GCNRPTracker::LiveRegSet &S1,
339 const GCNRPTracker::LiveRegSet &S2);
340
341Printable print(const GCNRegPressure &RP, const GCNSubtarget *ST = nullptr);
342
343Printable print(const GCNRPTracker::LiveRegSet &LiveRegs,
344 const MachineRegisterInfo &MRI);
345
346Printable reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
347 const GCNRPTracker::LiveRegSet &TrackedL,
348 const TargetRegisterInfo *TRI, StringRef Pfx = " ");
349
350struct GCNRegPressurePrinter : public MachineFunctionPass {
351 static char ID;
352
353public:
354 GCNRegPressurePrinter() : MachineFunctionPass(ID) {}
355
356 bool runOnMachineFunction(MachineFunction &MF) override;
357
358 void getAnalysisUsage(AnalysisUsage &AU) const override {
359 AU.addRequired<LiveIntervalsWrapperPass>();
360 AU.setPreservesAll();
361 MachineFunctionPass::getAnalysisUsage(AU);
362 }
363};
364
365} // end namespace llvm
366
367#endif // LLVM_LIB_TARGET_AMDGPU_GCNREGPRESSURE_H
368