//===- RDFRegisters.cpp ---------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "llvm/ADT/BitVector.h"
10#include "llvm/CodeGen/MachineFunction.h"
11#include "llvm/CodeGen/MachineInstr.h"
12#include "llvm/CodeGen/MachineOperand.h"
13#include "llvm/CodeGen/RDFRegisters.h"
14#include "llvm/CodeGen/TargetRegisterInfo.h"
15#include "llvm/MC/LaneBitmask.h"
16#include "llvm/MC/MCRegisterInfo.h"
17#include "llvm/Support/ErrorHandling.h"
18#include "llvm/Support/Format.h"
19#include "llvm/Support/MathExtras.h"
20#include "llvm/Support/raw_ostream.h"
21#include <cassert>
22#include <cstdint>
23#include <set>
24#include <utility>
25
26namespace llvm::rdf {
27
28PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
29 const MachineFunction &mf)
30 : TRI(tri) {
31 RegInfos.resize(new_size: TRI.getNumRegs());
32
33 BitVector BadRC(TRI.getNumRegs());
34 for (const TargetRegisterClass *RC : TRI.regclasses()) {
35 for (MCPhysReg R : *RC) {
36 RegInfo &RI = RegInfos[R];
37 if (RI.RegClass != nullptr && !BadRC[R]) {
38 if (RC->LaneMask != RI.RegClass->LaneMask) {
39 BadRC.set(R);
40 RI.RegClass = nullptr;
41 }
42 } else
43 RI.RegClass = RC;
44 }
45 }
46
47 UnitInfos.resize(S: TRI.getNumRegUnits());
48
49 for (MCRegUnit U : TRI.regunits()) {
50 if (UnitInfos[U].Reg != 0)
51 continue;
52 MCRegUnitRootIterator R(U, &TRI);
53 assert(R.isValid());
54 RegisterId F = *R;
55 ++R;
56 if (R.isValid()) {
57 UnitInfos[U].Mask = LaneBitmask::getAll();
58 UnitInfos[U].Reg = F;
59 } else {
60 for (MCRegUnitMaskIterator I(F, &TRI); I.isValid(); ++I) {
61 std::pair<MCRegUnit, LaneBitmask> P = *I;
62 UnitInfo &UI = UnitInfos[P.first];
63 UI.Reg = F;
64 UI.Mask = P.second;
65 }
66 }
67 }
68
69 for (const uint32_t *RM : TRI.getRegMasks())
70 RegMasks.insert(Val: RM);
71 for (const MachineBasicBlock &B : mf)
72 for (const MachineInstr &In : B)
73 for (const MachineOperand &Op : In.operands())
74 if (Op.isRegMask())
75 RegMasks.insert(Val: Op.getRegMask());
76
77 MaskInfos.resize(new_size: RegMasks.size() + 1);
78 for (uint32_t M = 1, NM = RegMasks.size(); M <= NM; ++M) {
79 BitVector PU(TRI.getNumRegUnits());
80 const uint32_t *MB = RegMasks.get(Idx: M);
81 for (unsigned I = 1, E = TRI.getNumRegs(); I != E; ++I) {
82 if (!(MB[I / 32] & (1u << (I % 32))))
83 continue;
84 for (MCRegUnit Unit : TRI.regunits(Reg: MCRegister::from(Val: I)))
85 PU.set(static_cast<unsigned>(Unit));
86 }
87 MaskInfos[M].Units = PU.flip();
88 }
89
90 AliasInfos.resize(S: TRI.getNumRegUnits());
91 for (MCRegUnit U : TRI.regunits()) {
92 BitVector AS(TRI.getNumRegs());
93 for (MCRegUnitRootIterator R(U, &TRI); R.isValid(); ++R)
94 for (MCPhysReg S : TRI.superregs_inclusive(Reg: *R))
95 AS.set(S);
96 AliasInfos[U].Regs = AS;
97 }
98}
99
100bool PhysicalRegisterInfo::alias(RegisterRef RA, RegisterRef RB) const {
101 return !disjoint(A: getUnits(RR: RA), B: getUnits(RR: RB));
102}
103
104std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterRef RR) const {
105 // Do not include Reg in the alias set.
106 std::set<RegisterId> AS;
107 assert(!RR.isUnit() && "No units allowed");
108 if (RR.isMask()) {
109 // XXX SLOW
110 const uint32_t *MB = getRegMaskBits(RR);
111 for (unsigned i = 1, e = TRI.getNumRegs(); i != e; ++i) {
112 if (MB[i / 32] & (1u << (i % 32)))
113 continue;
114 AS.insert(x: i);
115 }
116 return AS;
117 }
118
119 assert(RR.isReg());
120 for (MCRegAliasIterator AI(RR.asMCReg(), &TRI, false); AI.isValid(); ++AI)
121 AS.insert(x: *AI);
122
123 return AS;
124}
125
126std::set<RegisterId> PhysicalRegisterInfo::getUnits(RegisterRef RR) const {
127 std::set<RegisterId> Units;
128
129 if (RR.isReg()) {
130 if (RR.Mask.none())
131 return Units; // Empty
132 for (MCRegUnitMaskIterator UM(RR.asMCReg(), &TRI); UM.isValid(); ++UM) {
133 auto [U, M] = *UM;
134 if ((M & RR.Mask).any())
135 Units.insert(x: static_cast<unsigned>(U));
136 }
137 return Units;
138 }
139
140 assert(RR.isMask());
141 unsigned NumRegs = TRI.getNumRegs();
142 const uint32_t *MB = getRegMaskBits(RR);
143 for (unsigned I = 0, E = (NumRegs + 31) / 32; I != E; ++I) {
144 uint32_t C = ~MB[I]; // Clobbered regs
145 if (I == 0) // Reg 0 should be ignored
146 C &= maskLeadingOnes<unsigned>(N: 31);
147 if (I + 1 == E && NumRegs % 32 != 0) // Last word may be partial
148 C &= maskTrailingOnes<unsigned>(N: NumRegs % 32);
149 if (C == 0)
150 continue;
151 while (C != 0) {
152 unsigned T = llvm::countr_zero(Val: C);
153 unsigned CR = 32 * I + T; // Clobbered reg
154 for (MCRegUnit U : TRI.regunits(Reg: CR))
155 Units.insert(x: static_cast<unsigned>(U));
156 C &= ~(1u << T);
157 }
158 }
159 return Units;
160}
161
162RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, RegisterId R) const {
163 if (RR.Id == R)
164 return RR;
165 if (unsigned Idx = TRI.getSubRegIndex(RegNo: RegisterRef(R).asMCReg(), SubRegNo: RR.asMCReg()))
166 return RegisterRef(R, TRI.composeSubRegIndexLaneMask(IdxA: Idx, Mask: RR.Mask));
167 if (unsigned Idx =
168 TRI.getSubRegIndex(RegNo: RR.asMCReg(), SubRegNo: RegisterRef(R).asMCReg())) {
169 const RegInfo &RI = RegInfos[R];
170 LaneBitmask RCM =
171 RI.RegClass ? RI.RegClass->LaneMask : LaneBitmask::getAll();
172 LaneBitmask M = TRI.reverseComposeSubRegIndexLaneMask(IdxA: Idx, LaneMask: RR.Mask);
173 return RegisterRef(R, M & RCM);
174 }
175 llvm_unreachable("Invalid arguments: unrelated registers?");
176}
177
178bool PhysicalRegisterInfo::equal_to(RegisterRef A, RegisterRef B) const {
179 if (!A.isReg() || !B.isReg()) {
180 // For non-regs, or comparing reg and non-reg, use only the Id member.
181 return A.Id == B.Id;
182 }
183
184 if (A.Id == B.Id)
185 return A.Mask == B.Mask;
186
187 // Compare reg units lexicographically.
188 MCRegUnitMaskIterator AI(A.asMCReg(), &getTRI());
189 MCRegUnitMaskIterator BI(B.asMCReg(), &getTRI());
190 while (AI.isValid() && BI.isValid()) {
191 auto [AReg, AMask] = *AI;
192 auto [BReg, BMask] = *BI;
193
194 // If both iterators point to a unit contained in both A and B, then
195 // compare the units.
196 if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
197 if (AReg != BReg)
198 return false;
199 // Units are equal, move on to the next ones.
200 ++AI;
201 ++BI;
202 continue;
203 }
204
205 if ((AMask & A.Mask).none())
206 ++AI;
207 if ((BMask & B.Mask).none())
208 ++BI;
209 }
210 // One or both have reached the end.
211 return static_cast<int>(AI.isValid()) == static_cast<int>(BI.isValid());
212}
213
214bool PhysicalRegisterInfo::less(RegisterRef A, RegisterRef B) const {
215 if (!A.isReg() || !B.isReg()) {
216 // For non-regs, or comparing reg and non-reg, use only the Id member.
217 return A.Id < B.Id;
218 }
219
220 if (A.Id == B.Id)
221 return A.Mask < B.Mask;
222 if (A.Mask == B.Mask)
223 return A.Id < B.Id;
224
225 // Compare reg units lexicographically.
226 llvm::MCRegUnitMaskIterator AI(A.asMCReg(), &getTRI());
227 llvm::MCRegUnitMaskIterator BI(B.asMCReg(), &getTRI());
228 while (AI.isValid() && BI.isValid()) {
229 auto [AReg, AMask] = *AI;
230 auto [BReg, BMask] = *BI;
231
232 // If both iterators point to a unit contained in both A and B, then
233 // compare the units.
234 if ((AMask & A.Mask).any() && (BMask & B.Mask).any()) {
235 if (AReg != BReg)
236 return AReg < BReg;
237 // Units are equal, move on to the next ones.
238 ++AI;
239 ++BI;
240 continue;
241 }
242
243 if ((AMask & A.Mask).none())
244 ++AI;
245 if ((BMask & B.Mask).none())
246 ++BI;
247 }
248 // One or both have reached the end: assume invalid < valid.
249 return static_cast<int>(AI.isValid()) < static_cast<int>(BI.isValid());
250}
251
252void PhysicalRegisterInfo::print(raw_ostream &OS, RegisterRef A) const {
253 if (A.isReg()) {
254 MCRegister Reg = A.asMCReg();
255 if (Reg && Reg.id() < TRI.getNumRegs())
256 OS << TRI.getName(RegNo: Reg);
257 else
258 OS << printReg(Reg, TRI: &TRI);
259 OS << PrintLaneMaskShort(A.Mask);
260 } else if (A.isUnit()) {
261 OS << printRegUnit(Unit: A.asMCRegUnit(), TRI: &TRI);
262 } else {
263 unsigned Idx = A.asMaskIdx();
264 const char *Fmt = Idx < 0x10000 ? "%04x" : "%08x";
265 OS << "M#" << format(Fmt, Vals: Idx);
266 }
267}
268
269void PhysicalRegisterInfo::print(raw_ostream &OS, const RegisterAggr &A) const {
270 OS << '{';
271 for (unsigned U : A.units())
272 OS << ' ' << printRegUnit(Unit: static_cast<MCRegUnit>(U), TRI: &TRI);
273 OS << " }";
274}
275
276bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
277 if (RR.isMask())
278 return Units.anyCommon(RHS: PRI.getMaskUnits(RR));
279
280 for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
281 auto [Unit, LaneMask] = *U;
282 if ((LaneMask & RR.Mask).any())
283 if (Units.test(Idx: static_cast<unsigned>(Unit)))
284 return true;
285 }
286 return false;
287}
288
289bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
290 if (RR.isMask())
291 return PRI.getMaskUnits(RR).subsetOf(RHS: Units);
292
293 for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
294 auto [Unit, LaneMask] = *U;
295 if ((LaneMask & RR.Mask).any())
296 if (!Units.test(Idx: static_cast<unsigned>(Unit)))
297 return false;
298 }
299 return true;
300}
301
302RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
303 if (RR.isMask()) {
304 Units |= PRI.getMaskUnits(RR);
305 return *this;
306 }
307
308 for (MCRegUnitMaskIterator U(RR.asMCReg(), &PRI.getTRI()); U.isValid(); ++U) {
309 auto [Unit, LaneMask] = *U;
310 if ((LaneMask & RR.Mask).any())
311 Units.set(static_cast<unsigned>(Unit));
312 }
313 return *this;
314}
315
316RegisterAggr &RegisterAggr::insert(const RegisterAggr &RG) {
317 Units |= RG.Units;
318 return *this;
319}
320
321RegisterAggr &RegisterAggr::intersect(RegisterRef RR) {
322 return intersect(RG: RegisterAggr(PRI).insert(RR));
323}
324
325RegisterAggr &RegisterAggr::intersect(const RegisterAggr &RG) {
326 Units &= RG.Units;
327 return *this;
328}
329
330RegisterAggr &RegisterAggr::clear(RegisterRef RR) {
331 return clear(RG: RegisterAggr(PRI).insert(RR));
332}
333
334RegisterAggr &RegisterAggr::clear(const RegisterAggr &RG) {
335 Units.reset(RHS: RG.Units);
336 return *this;
337}
338
339RegisterRef RegisterAggr::intersectWith(RegisterRef RR) const {
340 RegisterAggr T(PRI);
341 T.insert(RR).intersect(RG: *this);
342 if (T.empty())
343 return RegisterRef();
344 RegisterRef NR = T.makeRegRef();
345 assert(NR);
346 return NR;
347}
348
349RegisterRef RegisterAggr::clearIn(RegisterRef RR) const {
350 return RegisterAggr(PRI).insert(RR).clear(RG: *this).makeRegRef();
351}
352
353RegisterRef RegisterAggr::makeRegRef() const {
354 int U = Units.find_first();
355 if (U < 0)
356 return RegisterRef();
357
358 // Find the set of all registers that are aliased to all the units
359 // in this aggregate.
360
361 // Get all the registers aliased to the first unit in the bit vector.
362 BitVector Regs = PRI.getUnitAliases(U: static_cast<MCRegUnit>(U));
363 U = Units.find_next(Prev: U);
364
365 // For each other unit, intersect it with the set of all registers
366 // aliased that unit.
367 while (U >= 0) {
368 Regs &= PRI.getUnitAliases(U: static_cast<MCRegUnit>(U));
369 U = Units.find_next(Prev: U);
370 }
371
372 // If there is at least one register remaining, pick the first one,
373 // and consolidate the masks of all of its units contained in this
374 // aggregate.
375
376 int F = Regs.find_first();
377 if (F <= 0)
378 return RegisterRef();
379
380 LaneBitmask M;
381 for (MCRegUnitMaskIterator I(F, &PRI.getTRI()); I.isValid(); ++I) {
382 auto [Unit, LaneMask] = *I;
383 if (Units.test(Idx: static_cast<unsigned>(Unit)))
384 M |= LaneMask;
385 }
386 return RegisterRef(F, M);
387}
388
389RegisterAggr::ref_iterator::ref_iterator(const RegisterAggr &RG, bool End)
390 : Owner(&RG) {
391 for (int U = RG.Units.find_first(); U >= 0; U = RG.Units.find_next(Prev: U)) {
392 RegisterRef R = RG.PRI.getRefForUnit(U: static_cast<MCRegUnit>(U));
393 Masks[R.Id] |= R.Mask;
394 }
395 Pos = End ? Masks.end() : Masks.begin();
396 Index = End ? Masks.size() : 0;
397}
398
399raw_ostream &operator<<(raw_ostream &OS, const RegisterAggr &A) {
400 A.getPRI().print(OS, A);
401 return OS;
402}
403
404raw_ostream &operator<<(raw_ostream &OS, const PrintLaneMaskShort &P) {
405 if (P.Mask.all())
406 return OS;
407 if (P.Mask.none())
408 return OS << ":*none*";
409
410 LaneBitmask::Type Val = P.Mask.getAsInteger();
411 if ((Val & 0xffff) == Val)
412 return OS << ':' << format(Fmt: "%04llX", Vals: Val);
413 if ((Val & 0xffffffff) == Val)
414 return OS << ':' << format(Fmt: "%08llX", Vals: Val);
415 return OS << ':' << PrintLaneMask(LaneMask: P.Mask);
416}
417
418} // namespace llvm::rdf
419