1//=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//==-----------------------------------------------------------------------===//
8//
9/// \file
10/// Implements helpers for target-independent rematerialization at the MIR
11/// level.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/CodeGen/Rematerializer.h"
16#include "llvm/ADT/MapVector.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SetVector.h"
19#include "llvm/CodeGen/LiveIntervals.h"
20#include "llvm/CodeGen/MachineBasicBlock.h"
21#include "llvm/CodeGen/MachineOperand.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/Register.h"
24#include "llvm/CodeGen/TargetOpcodes.h"
25#include "llvm/CodeGen/TargetRegisterInfo.h"
26#include "llvm/Support/Debug.h"
27#include <optional>
28
29#define DEBUG_TYPE "rematerializer"
30
31using namespace llvm;
32using RegisterIdx = Rematerializer::RegisterIdx;
33
34// Pin the vtable to this file.
35void Rematerializer::Listener::anchor() {}
36
37/// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this
38/// implies it is also live there). When \p LI has sub-ranges, checks that
39/// all sub-ranges intersecting with \p Mask are also live at \p UseIdx.
40static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask,
41 SlotIndex UseIdx, const LiveInterval &LI) {
42 if (&OVNI != LI.getVNInfoAt(Idx: UseIdx))
43 return false;
44
45 if (LI.hasSubRanges()) {
46 // Check that intersecting subranges are live at user.
47 for (const LiveInterval::SubRange &SR : LI.subranges()) {
48 if ((SR.LaneMask & Mask).none())
49 continue;
50 if (!SR.liveAt(index: UseIdx))
51 return false;
52
53 // Early exit if all used lanes are checked. No need to continue.
54 Mask &= ~SR.LaneMask;
55 if (Mask.none())
56 break;
57 }
58 }
59 return true;
60}
61
62/// If \p MO is a virtual read register, returns it. Otherwise returns the
63/// sentinel register.
64static Register getRegDependency(const MachineOperand &MO) {
65 if (!MO.isReg() || !MO.readsReg())
66 return Register();
67 Register Reg = MO.getReg();
68 if (Reg.isPhysical()) {
69 // By the requirements on trivially rematerializable instructions, a
70 // physical register use is either constant or ignorable.
71 return Register();
72 }
73 return Reg;
74}
75
76RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx,
77 unsigned UseRegion,
78 DependencyReuseInfo &DRI) {
79 MachineInstr *FirstMI =
80 getReg(RegIdx: RootIdx).getRegionUseBounds(UseRegion, LIS).first;
81 // If there are no users in the region, rematerialize the register at the very
82 // end of the region.
83 MachineBasicBlock::iterator InsertPos =
84 FirstMI ? FirstMI : Regions[UseRegion].second;
85 RegisterIdx NewRegIdx =
86 rematerializeToPos(RootIdx, UseRegion, InsertPos, DRI);
87 transferRegionUsers(FromRegIdx: RootIdx, ToRegIdx: NewRegIdx, UseRegion);
88 return NewRegIdx;
89}
90
91RegisterIdx
92Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion,
93 MachineBasicBlock::iterator InsertPos,
94 DependencyReuseInfo &DRI) {
95 assert(!DRI.DependencyMap.contains(RootIdx));
96 LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << '\n');
97
98 // Traverse the root's dependency DAG depth-first to find the set of
99 // registers we must rematerialize along with it and a legal order to
100 // rematerialize them in.
101 SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
102 SmallSetVector<RegisterIdx, 8> RematOrder;
103 RematOrder.insert(X: RootIdx);
104 do {
105 RegisterIdx RegIdx = DepDAG.pop_back_val();
106 for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies) {
107 // The dependency may already have a rematerialization ready to use.
108 if (DRI.DependencyMap.contains(Val: Dep.RegIdx))
109 continue;
110 // We may have already seen the dependency in the dependency DAG.
111 if (RematOrder.contains(key: Dep.RegIdx))
112 continue;
113 DepDAG.push_back(Elt: Dep.RegIdx);
114 RematOrder.insert(X: Dep.RegIdx);
115 }
116 } while (!DepDAG.empty());
117
118 // Rematerialize all necessary registers in the root's dependency DAG. At each
119 // rematerialization, dependencies should already be available.
120 RegisterIdx LastNewIdx;
121 for (RegisterIdx RegIdx : reverse(C&: RematOrder)) {
122 assert(!DRI.DependencyMap.contains(RegIdx) && "useless remat");
123 SmallVector<Reg::Dependency, 2> Dependencies;
124 for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
125 Dependencies.emplace_back(Args: Dep.MOIdx, Args&: DRI.DependencyMap.at(Val: Dep.RegIdx));
126 LastNewIdx =
127 rematerializeReg(RegIdx, UseRegion, InsertPos, Dependencies: std::move(Dependencies));
128 DRI.DependencyMap.insert(KV: {RegIdx, LastNewIdx});
129 }
130
131 return LastNewIdx;
132}
133
134void Rematerializer::rollbackRematsOf(RegisterIdx RootIdx) {
135 auto Remats = Rematerializations.find(Val: RootIdx);
136 if (Remats == Rematerializations.end())
137 return;
138
139 LLVM_DEBUG(dbgs() << "Rolling back rematerializations of " << printID(RootIdx)
140 << '\n');
141
142 reviveRegIfDead(RootIdx);
143 // All of the rematerialization's users must use the revived register.
144 for (RegisterIdx RematRegIdx : Remats->getSecond()) {
145 for (const auto &[UseRegion, RegionUsers] : Regs[RematRegIdx].Uses)
146 transferRegionUsers(FromRegIdx: RematRegIdx, ToRegIdx: RootIdx, UseRegion);
147 }
148 Rematerializations.erase(Val: RootIdx);
149
150 LLVM_DEBUG(dbgs() << "** Rolled back rematerializations of "
151 << printID(RootIdx) << '\n');
152}
153
154void Rematerializer::rollback(RegisterIdx RematIdx) {
155 assert(getReg(RematIdx).DefMI && !Revivable.contains(RematIdx) &&
156 "cannot rollback dead register");
157 const RegisterIdx OriginRegIdx = getOriginOf(RematRegIdx: RematIdx);
158 reviveRegIfDead(RootIdx: OriginRegIdx);
159 for (const auto &[UseRegion, RegionUsers] : Regs[RematIdx].Uses)
160 transferRegionUsers(FromRegIdx: RematIdx, ToRegIdx: OriginRegIdx, UseRegion);
161}
162
163void Rematerializer::reviveRegIfDead(RegisterIdx RootIdx) {
164 if (getReg(RegIdx: RootIdx).isAlive())
165 return;
166 assert(Revivable.contains(RootIdx) && "not revivable");
167
168 // Traverse the root's dependency DAG depth-first to find the set of
169 // registers we must revive and a legal order to revive them in.
170 SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
171 SmallSetVector<RegisterIdx, 8> ReviveOrder;
172 ReviveOrder.insert(X: RootIdx);
173 do {
174 // All dependencies of a revived register need to be alive too.
175 const Reg &ReviveReg = getReg(RegIdx: DepDAG.pop_back_val());
176 for (const Reg::Dependency &Dep : ReviveReg.Dependencies) {
177 // We may have already seen the dependency in the dependency DAG.
178 if (ReviveOrder.contains(key: Dep.RegIdx))
179 continue;
180
181 // Dead dependencies need to be revived.
182 Reg &DepReg = Regs[Dep.RegIdx];
183 if (!DepReg.isAlive()) {
184 assert(Revivable.contains(Dep.RegIdx) && "not revivable");
185 ReviveOrder.insert(X: Dep.RegIdx);
186 DepDAG.push_back(Elt: Dep.RegIdx);
187 }
188
189 // All dependencies get a new user (the revived register).
190 DepReg.addUser(MI: ReviveReg.DefMI, Region: ReviveReg.DefRegion);
191 LISUpdates.insert(V: Dep.RegIdx);
192 }
193 } while (!DepDAG.empty());
194
195 for (RegisterIdx RegIdx : reverse(C&: ReviveOrder)) {
196 // Pick any rematerialization to retrieve the original opcode from.
197 Reg &ReviveReg = Regs[RegIdx];
198 assert(Rematerializations.contains(RegIdx) && "no remats");
199 RegisterIdx RematIdx = *Rematerializations.at(Val: RegIdx).begin();
200 ReviveReg.DefMI->setDesc(getReg(RegIdx: RematIdx).DefMI->getDesc());
201 for (const auto &[MOIdx, Reg] : Revivable.at(Val: RegIdx))
202 ReviveReg.DefMI->getOperand(i: MOIdx).setReg(Reg);
203 Revivable.erase(Val: RegIdx);
204 LISUpdates.insert(V: RegIdx);
205
206 LLVM_DEBUG({
207 dbgs() << "** Revived " << printID(RegIdx) << " @ ";
208 LIS.getInstructionIndex(*ReviveReg.DefMI).print(dbgs());
209 dbgs() << '\n';
210 });
211 }
212}
213
214void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
215 unsigned UserRegion, MachineInstr &UserMI) {
216 transferUserImpl(FromRegIdx, ToRegIdx, UserMI);
217 Regs[FromRegIdx].eraseUser(MI: &UserMI, Region: UserRegion);
218 Regs[ToRegIdx].addUser(MI: &UserMI, Region: UserRegion);
219 deleteRegIfUnused(RootIdx: FromRegIdx);
220}
221
222void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx,
223 RegisterIdx ToRegIdx,
224 unsigned UseRegion) {
225 auto &FromRegUsers = Regs[FromRegIdx].Uses;
226 auto UsesIt = FromRegUsers.find(Val: UseRegion);
227 if (UsesIt == FromRegUsers.end())
228 return;
229
230 const SmallDenseSet<MachineInstr *, 4> &RegionUsers = UsesIt->getSecond();
231 for (MachineInstr *UserMI : RegionUsers)
232 transferUserImpl(FromRegIdx, ToRegIdx, UserMI&: *UserMI);
233 Regs[ToRegIdx].addUsers(NewUsers: RegionUsers, Region: UseRegion);
234 FromRegUsers.erase(Val: UseRegion);
235 deleteRegIfUnused(RootIdx: FromRegIdx);
236}
237
238void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx,
239 RegisterIdx ToRegIdx,
240 MachineInstr &UserMI) {
241 assert(FromRegIdx != ToRegIdx && "identical registers");
242 assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) &&
243 "unrelated registers");
244
245 LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to "
246 << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n');
247
248 UserMI.substituteRegister(FromReg: getReg(RegIdx: FromRegIdx).getDefReg(),
249 ToReg: getReg(RegIdx: ToRegIdx).getDefReg(), SubIdx: 0, RegInfo: TRI);
250 LISUpdates.insert(V: FromRegIdx);
251 LISUpdates.insert(V: ToRegIdx);
252
253 // If the user is rematerializable, we must change its dependency to the
254 // new register.
255 if (RegisterIdx UserRegIdx = getDefRegIdx(MI: UserMI); UserRegIdx != NoReg) {
256 // Look for the user's dependency that matches the register.
257 for (Reg::Dependency &Dep : Regs[UserRegIdx].Dependencies) {
258 if (Dep.RegIdx == FromRegIdx) {
259 Dep.RegIdx = ToRegIdx;
260 return;
261 }
262 }
263 llvm_unreachable("broken dependency");
264 }
265}
266
267void Rematerializer::updateLiveIntervals() {
268 DenseSet<Register> SeenUnrematRegs;
269 for (RegisterIdx RegIdx : LISUpdates) {
270 const Reg &UpdateReg = getReg(RegIdx);
271 assert((UpdateReg.DefMI || Revivable.contains(RegIdx)) && "dead reg");
272
273 Register DefReg = UpdateReg.getDefReg();
274 if (LIS.hasInterval(Reg: DefReg))
275 LIS.removeInterval(Reg: DefReg);
276 LIS.createAndComputeVirtRegInterval(Reg: DefReg);
277
278 LLVM_DEBUG({
279 dbgs() << "Re-computed interval for " << printID(RegIdx) << ": ";
280 LIS.getInterval(DefReg).print(dbgs());
281 dbgs() << '\n' << printRegUsers(RegIdx);
282 });
283
284 // Update intervals for unrematerializable operands.
285 for (unsigned MOIdx : getUnrematableOprds(RegIdx)) {
286 Register UnrematReg = UpdateReg.DefMI->getOperand(i: MOIdx).getReg();
287 if (!SeenUnrematRegs.insert(V: UnrematReg).second)
288 continue;
289 LIS.removeInterval(Reg: UnrematReg);
290 LIS.createAndComputeVirtRegInterval(Reg: UnrematReg);
291 LLVM_DEBUG(
292 dbgs() << " Re-computed interval for register "
293 << printReg(UnrematReg, &TRI,
294 UpdateReg.DefMI->getOperand(MOIdx).getSubReg(),
295 &MRI)
296 << '\n');
297 }
298 }
299 LISUpdates.clear();
300}
301
302void Rematerializer::commitRematerializations() {
303 for (auto &[RegIdx, _] : Revivable)
304 deleteReg(RegIdx);
305 Revivable.clear();
306}
307
308bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO,
309 ArrayRef<SlotIndex> Uses) const {
310 if (Uses.empty())
311 return true;
312 Register Reg = MO.getReg();
313 unsigned SubIdx = MO.getSubReg();
314 LaneBitmask Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
315 : MRI.getMaxLaneMaskForVReg(Reg);
316 const LiveInterval &LI = LIS.getInterval(Reg);
317 const VNInfo *DefVN =
318 LI.getVNInfoAt(Idx: LIS.getInstructionIndex(Instr: *MO.getParent()).getRegSlot(EC: true));
319 for (SlotIndex Use : Uses) {
320 if (!isIdenticalAtUse(OVNI: *DefVN, Mask, UseIdx: Use, LI))
321 return false;
322 }
323 return true;
324}
325
326RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx,
327 unsigned Region,
328 SlotIndex Before) const {
329 auto It = Rematerializations.find(Val: getOriginOrSelf(RegIdx));
330 if (It == Rematerializations.end())
331 return NoReg;
332 const RematsOf &Remats = It->getSecond();
333
334 SlotIndex BestSlot;
335 RegisterIdx BestRegIdx = NoReg;
336 for (RegisterIdx RematRegIdx : Remats) {
337 const Reg &RematReg = getReg(RegIdx: RematRegIdx);
338 if (RematReg.DefRegion != Region || RematReg.Uses.empty())
339 continue;
340 SlotIndex RematRegSlot =
341 LIS.getInstructionIndex(Instr: *RematReg.DefMI).getRegSlot();
342 if (RematRegSlot < Before &&
343 (BestRegIdx == NoReg || RematRegSlot > BestSlot)) {
344 BestSlot = RematRegSlot;
345 BestRegIdx = RematRegIdx;
346 }
347 }
348 return BestRegIdx;
349}
350
351void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
352 if (!getReg(RegIdx: RootIdx).Uses.empty())
353 return;
354
355 // Traverse the root's dependency DAG depth-first to find the set of registers
356 // we can delete and a legal order to delete them in.
357 SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
358 SmallSetVector<RegisterIdx, 8> DeleteOrder;
359 DeleteOrder.insert(X: RootIdx);
360 do {
361 // A deleted register's dependencies may be deletable too.
362 const Reg &DeleteReg = getReg(RegIdx: DepDAG.pop_back_val());
363 for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
364 // All dependencies loose a user (the delete register).
365 Reg &DepReg = Regs[Dep.RegIdx];
366 DepReg.eraseUser(MI: DeleteReg.DefMI, Region: DeleteReg.DefRegion);
367 if (DepReg.Uses.empty()) {
368 DeleteOrder.insert(X: Dep.RegIdx);
369 DepDAG.push_back(Elt: Dep.RegIdx);
370 }
371 }
372 } while (!DepDAG.empty());
373
374 for (RegisterIdx RegIdx : reverse(C&: DeleteOrder)) {
375 Reg &DeleteReg = Regs[RegIdx];
376 LIS.removeInterval(Reg: DeleteReg.getDefReg());
377 LISUpdates.erase(V: RegIdx);
378 const bool IsRematerializedReg = isRematerializedRegister(RegIdx);
379 if (SupportRollback && !IsRematerializedReg) {
380 // Replace all read registers with the null one to prevent them from
381 // showing up in use-lists, which is disallowed for debug instructions in
382 // live interval calculations. Store mappings between operand indices and
383 // original registers for potential rollback.
384 DenseMap<unsigned, Register> &RegMap =
385 Revivable.try_emplace(Key: RegIdx).first->getSecond();
386 for (auto [Idx, MO] : enumerate(First: DeleteReg.DefMI->operands())) {
387 if (MO.isReg() && MO.readsReg()) {
388 RegMap.insert(KV: {Idx, MO.getReg()});
389 MO.setReg(Register());
390 }
391 }
392 DeleteReg.DefMI->setDesc(TII.get(Opcode: TargetOpcode::DBG_VALUE));
393 } else {
394 deleteReg(RegIdx);
395 }
396 if (IsRematerializedReg) {
397 // Delete rematerialized register from its origin's rematerializations.
398 RematsOf &OriginRemats = Rematerializations.at(Val: getOriginOf(RematRegIdx: RegIdx));
399 assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link");
400 OriginRemats.erase(V: RegIdx);
401 if (OriginRemats.empty())
402 Rematerializations.erase(Val: RegIdx);
403 }
404 LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n");
405 }
406}
407
408void Rematerializer::deleteReg(RegisterIdx RegIdx) {
409 noteRegDeleted(RegIdx);
410
411 Reg &DeleteReg = Regs[RegIdx];
412 assert(DeleteReg.DefMI && "register was already deleted");
413 // It is not possible for the deleted instruction to be the upper region
414 // boundary since we don't ever consider them rematerializable.
415 MachineBasicBlock::iterator &RegionBegin = Regions[DeleteReg.DefRegion].first;
416 if (RegionBegin == DeleteReg.DefMI)
417 RegionBegin = std::next(x: MachineBasicBlock::iterator(DeleteReg.DefMI));
418 LIS.RemoveMachineInstrFromMaps(MI&: *DeleteReg.DefMI);
419 DeleteReg.DefMI->eraseFromParent();
420 DeleteReg.DefMI = nullptr;
421}
422
423Rematerializer::Rematerializer(MachineFunction &MF,
424 SmallVectorImpl<RegionBoundaries> &Regions,
425 LiveIntervals &LIS)
426 : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS),
427 TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) {
428#ifdef EXPENSIVE_CHECKS
429 // Check that regions are valid.
430 DenseSet<MachineInstr *> SeenMIs;
431 for (const auto &[RegionBegin, RegionEnd] : Regions) {
432 assert(RegionBegin != RegionEnd && "empty region");
433 for (auto MI = RegionBegin; MI != RegionEnd; ++MI) {
434 bool IsNewMI = SeenMIs.insert(&*MI).second;
435 assert(IsNewMI && "overlapping regions");
436 assert(!MI->isTerminator() && "terminator in region");
437 }
438 if (RegionEnd != RegionBegin->getParent()->end()) {
439 bool IsNewMI = SeenMIs.insert(&*RegionEnd).second;
440 assert(IsNewMI && "overlapping regions (upper bound)");
441 }
442 }
443#endif
444}
445
446bool Rematerializer::analyze(bool SupportRollback) {
447 Regs.clear();
448 UnrematableOprds.clear();
449 Origins.clear();
450 Rematerializations.clear();
451 RegionMBB.clear();
452 RegToIdx.clear();
453 LISUpdates.clear();
454 Revivable.clear();
455 this->SupportRollback = SupportRollback;
456 if (Regions.empty())
457 return false;
458
459 /// Maps all MIs to their parent region. Region terminators are considered
460 /// part of the region they terminate.
461 DenseMap<MachineInstr *, unsigned> MIRegion;
462
463 // Initialize MI to containing region mapping.
464 RegionMBB.reserve(N: Regions.size());
465 for (unsigned I = 0, E = Regions.size(); I < E; ++I) {
466 RegionBoundaries Region = Regions[I];
467 assert(Region.first != Region.second && "empty cannot be region");
468 for (auto MI = Region.first; MI != Region.second; ++MI) {
469 assert(!MIRegion.contains(&*MI) && "regions should not intersect");
470 MIRegion.insert(KV: {&*MI, I});
471 }
472 MachineBasicBlock &MBB = *Region.first->getParent();
473 RegionMBB.push_back(Elt: &MBB);
474
475 // A terminator instruction is considered part of the region it terminates.
476 if (Region.second != MBB.end()) {
477 MachineInstr *RegionTerm = &*Region.second;
478 assert(!MIRegion.contains(RegionTerm) && "regions should not intersect");
479 MIRegion.insert(KV: {RegionTerm, I});
480 }
481 }
482
483 const unsigned NumVirtRegs = MRI.getNumVirtRegs();
484 BitVector SeenRegs(NumVirtRegs);
485 for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) {
486 if (!SeenRegs[I])
487 addRegIfRematerializable(VirtRegIdx: I, MIRegion, SeenRegs);
488 }
489 assert(Regs.size() == UnrematableOprds.size());
490
491 LLVM_DEBUG({
492 for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I)
493 dbgs() << printDependencyDAG(I) << '\n';
494 });
495 return !Regs.empty();
496}
497
498void Rematerializer::addRegIfRematerializable(
499 unsigned VirtRegIdx, const DenseMap<MachineInstr *, unsigned> &MIRegion,
500 BitVector &SeenRegs) {
501 assert(!SeenRegs[VirtRegIdx] && "register already seen");
502 Register DefReg = Register::index2VirtReg(Index: VirtRegIdx);
503 SeenRegs.set(VirtRegIdx);
504
505 MachineOperand *MO = MRI.getOneDef(Reg: DefReg);
506 if (!MO)
507 return;
508 MachineInstr &DefMI = *MO->getParent();
509 if (!isMIRematerializable(MI: DefMI))
510 return;
511 auto DefRegion = MIRegion.find(Val: &DefMI);
512 if (DefRegion == MIRegion.end())
513 return;
514
515 Reg RematReg;
516 RematReg.DefMI = &DefMI;
517 RematReg.DefRegion = DefRegion->second;
518 unsigned SubIdx = DefMI.getOperand(i: 0).getSubReg();
519 RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
520 : MRI.getMaxLaneMaskForVReg(Reg: DefReg);
521
522 // Collect the candidate's direct users, both rematerializable and
523 // unrematerializable. MIs outside provided regions cannot be tracked so the
524 // registers they use are not safely rematerializable.
525 for (MachineInstr &UseMI : MRI.use_nodbg_instructions(Reg: DefReg)) {
526 if (auto UseRegion = MIRegion.find(Val: &UseMI); UseRegion != MIRegion.end())
527 RematReg.addUser(MI: &UseMI, Region: UseRegion->second);
528 else
529 return;
530 }
531 if (RematReg.Uses.empty())
532 return;
533
534 // Collect the candidate's dependencies. If the same register is used
535 // multiple times we just need to consider it once.
536 SmallDenseSet<Register, 4> AllDepRegs;
537 SmallVector<unsigned, 2> UnrematDeps;
538 for (const auto &[MOIdx, MO] : enumerate(First: RematReg.DefMI->operands())) {
539 Register DepReg = getRegDependency(MO);
540 if (!DepReg || !AllDepRegs.insert(V: DepReg).second)
541 continue;
542 unsigned DepRegIdx = DepReg.virtRegIndex();
543 if (!SeenRegs[DepRegIdx])
544 addRegIfRematerializable(VirtRegIdx: DepRegIdx, MIRegion, SeenRegs);
545 if (auto DepIt = RegToIdx.find(Val: DepReg); DepIt != RegToIdx.end())
546 RematReg.Dependencies.push_back(Elt: Reg::Dependency(MOIdx, DepIt->second));
547 else
548 UnrematDeps.push_back(Elt: MOIdx);
549 }
550
551 // The register is rematerializable.
552 RegToIdx.insert(KV: {DefReg, Regs.size()});
553 Regs.push_back(Elt: RematReg);
554 UnrematableOprds.push_back(Elt: UnrematDeps);
555}
556
557bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const {
558 if (!TII.isReMaterializable(MI))
559 return false;
560
561 assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual");
562 assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def");
563
564 for (const MachineOperand &MO : MI.all_uses()) {
565 // We can't remat physreg uses, unless it is a constant or an ignorable
566 // use (e.g. implicit exec use on VALU instructions)
567 if (MO.getReg().isPhysical()) {
568 if (MRI.isConstantPhysReg(PhysReg: MO.getReg()) || TII.isIgnorableUse(MO))
569 continue;
570 return false;
571 }
572 }
573
574 return true;
575}
576
577RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const {
578 if (!MI.getNumOperands() || !MI.getOperand(i: 0).isReg() ||
579 MI.getOperand(i: 0).readsReg())
580 return NoReg;
581 Register Reg = MI.getOperand(i: 0).getReg();
582 auto UserRegIt = RegToIdx.find(Val: Reg);
583 if (UserRegIt == RegToIdx.end())
584 return NoReg;
585 return UserRegIt->second;
586}
587
588RegisterIdx Rematerializer::rematerializeReg(
589 RegisterIdx RegIdx, unsigned UseRegion,
590 MachineBasicBlock::iterator InsertPos,
591 SmallVectorImpl<Reg::Dependency> &&Dependencies) {
592 RegisterIdx NewRegIdx = Regs.size();
593
594 Reg &NewReg = Regs.emplace_back();
595 Reg &FromReg = Regs[RegIdx];
596 NewReg.Mask = FromReg.Mask;
597 NewReg.DefRegion = UseRegion;
598 NewReg.Dependencies = std::move(Dependencies);
599
600 // Track rematerialization link between registers. Origins are always
601 // registers that existed originally, and rematerializations are always
602 // attached to them.
603 const RegisterIdx OriginIdx = getOriginOrSelf(RegIdx);
604 Origins.push_back(Elt: OriginIdx);
605 Rematerializations[OriginIdx].insert(V: NewRegIdx);
606
607 // Use the TII to rematerialize the defining instruction with a new defined
608 // register.
609 Register NewDefReg = MRI.cloneVirtualRegister(VReg: FromReg.getDefReg());
610 TII.reMaterialize(MBB&: *RegionMBB[UseRegion], MI: InsertPos, DestReg: NewDefReg, SubIdx: 0,
611 Orig: *FromReg.DefMI);
612 NewReg.DefMI = &*std::prev(x: InsertPos);
613 RegToIdx.insert(KV: {NewDefReg, NewRegIdx});
614
615 // Update the DAG.
616 RegionBoundaries &Bounds = Regions[UseRegion];
617 if (Bounds.first == std::next(x: MachineBasicBlock::iterator(NewReg.DefMI)))
618 Bounds.first = NewReg.DefMI;
619 LIS.InsertMachineInstrInMaps(MI&: *NewReg.DefMI);
620 LISUpdates.insert(V: NewRegIdx);
621
622 // Replace dependencies as needed in the rematerialized MI. All dependencies
623 // of the latter gain a new user.
624 auto ZipedDeps = zip_equal(t&: FromReg.Dependencies, u&: NewReg.Dependencies);
625 for (const auto &[OldDep, NewDep] : ZipedDeps) {
626 assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch");
627 LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": "
628 << printID(OldDep.RegIdx) << " -> "
629 << printID(NewDep.RegIdx) << '\n');
630
631 Reg &NewDepReg = Regs[NewDep.RegIdx];
632 if (OldDep.RegIdx != NewDep.RegIdx) {
633 Register OldDefReg = FromReg.DefMI->getOperand(i: OldDep.MOIdx).getReg();
634 NewReg.DefMI->substituteRegister(FromReg: OldDefReg, ToReg: NewDepReg.getDefReg(), SubIdx: 0,
635 RegInfo: TRI);
636 LISUpdates.insert(V: OldDep.RegIdx);
637 }
638 NewDepReg.addUser(MI: NewReg.DefMI, Region: UseRegion);
639 LISUpdates.insert(V: NewDep.RegIdx);
640 }
641
642 noteRegCreated(RegIdx: NewRegIdx);
643
644 LLVM_DEBUG({
645 dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
646 << printRematReg(NewRegIdx) << '\n';
647 });
648 return NewRegIdx;
649}
650
651std::pair<MachineInstr *, MachineInstr *>
652Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion,
653 const LiveIntervals &LIS) const {
654 auto It = Uses.find(Val: UseRegion);
655 if (It == Uses.end())
656 return {nullptr, nullptr};
657 const RegionUsers &RegionUsers = It->getSecond();
658 assert(!RegionUsers.empty() && "empty userset in region");
659
660 auto User = RegionUsers.begin(), UserEnd = RegionUsers.end();
661 MachineInstr *FirstMI = *User, *LastMI = FirstMI;
662 SlotIndex FirstIndex = LIS.getInstructionIndex(Instr: *FirstMI),
663 LastIndex = FirstIndex;
664
665 while (++User != UserEnd) {
666 SlotIndex UserIndex = LIS.getInstructionIndex(Instr: **User);
667 if (UserIndex < FirstIndex) {
668 FirstIndex = UserIndex;
669 FirstMI = *User;
670 } else if (UserIndex > LastIndex) {
671 LastIndex = UserIndex;
672 LastMI = *User;
673 }
674 }
675
676 return {FirstMI, LastMI};
677}
678
679void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) {
680 Uses[Region].insert(V: MI);
681}
682
683void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers,
684 unsigned Region) {
685 Uses[Region].insert_range(R: NewUsers);
686}
687
688void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) {
689 RegionUsers &RUsers = Uses.at(Val: Region);
690 assert(RUsers.contains(MI) && "user not in region");
691 if (RUsers.size() == 1)
692 Uses.erase(Val: Region);
693 else
694 RUsers.erase(V: MI);
695}
696
697Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const {
698 return Printable([&, RootIdx](raw_ostream &OS) {
699 DenseMap<RegisterIdx, unsigned> RegDepths;
700 std::function<void(RegisterIdx, unsigned)> WalkTree =
701 [&](RegisterIdx RegIdx, unsigned Depth) -> void {
702 unsigned MaxDepth = std::max(a: RegDepths.lookup_or(Val: RegIdx, Default&: Depth), b: Depth);
703 RegDepths.emplace_or_assign(Key: RegIdx, Args&: MaxDepth);
704 for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
705 WalkTree(Dep.RegIdx, Depth + 1);
706 };
707 WalkTree(RootIdx, 0);
708
709 // Sort in decreasing depth order to print root at the bottom.
710 SmallVector<std::pair<RegisterIdx, unsigned>> Regs(RegDepths.begin(),
711 RegDepths.end());
712 sort(C&: Regs, Comp: [](const auto &LHS, const auto &RHS) {
713 return LHS.second > RHS.second;
714 });
715
716 OS << printID(RegIdx: RootIdx) << " has " << Regs.size() - 1 << " dependencies\n";
717 for (const auto &[RegIdx, Depth] : Regs) {
718 OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' '
719 << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n';
720 }
721 OS << printRegUsers(RegIdx: RootIdx);
722 });
723}
724
725Printable Rematerializer::printID(RegisterIdx RegIdx) const {
726 return Printable([&, RegIdx](raw_ostream &OS) {
727 const Reg &PrintReg = getReg(RegIdx);
728 OS << '(' << RegIdx << '/';
729 if (!PrintReg.DefMI) {
730 OS << "<dead>";
731 } else {
732 OS << printReg(Reg: PrintReg.getDefReg(), TRI: &TRI,
733 SubIdx: PrintReg.DefMI->getOperand(i: 0).getSubReg(), MRI: &MRI);
734 }
735 OS << ")[" << PrintReg.DefRegion << "]";
736 });
737}
738
739Printable Rematerializer::printRematReg(RegisterIdx RegIdx,
740 bool SkipRegions) const {
741 return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) {
742 const Reg &PrintReg = getReg(RegIdx);
743 if (!SkipRegions) {
744 OS << printID(RegIdx) << " [" << PrintReg.DefRegion;
745 if (!PrintReg.Uses.empty()) {
746 assert(PrintReg.DefMI && "dead register cannot have uses");
747 const LiveInterval &LI = LIS.getInterval(Reg: PrintReg.getDefReg());
748 // First display all regions in which the register is live-through and
749 // not used.
750 bool First = true;
751 for (const auto [I, Bounds] : enumerate(First&: Regions)) {
752 if (Bounds.first == Bounds.second)
753 continue;
754 if (!PrintReg.Uses.contains(Val: I) &&
755 LI.liveAt(index: LIS.getInstructionIndex(Instr: *Bounds.first)) &&
756 LI.liveAt(index: LIS.getInstructionIndex(Instr: *std::prev(x: Bounds.second))
757 .getRegSlot())) {
758 OS << (First ? " - " : ",") << I;
759 First = false;
760 }
761 }
762 OS << (First ? " --> " : " -> ");
763
764 // Then display regions in which the register is used.
765 auto It = PrintReg.Uses.begin();
766 OS << It->first;
767 while (++It != PrintReg.Uses.end())
768 OS << "," << It->first;
769 }
770 OS << "] ";
771 }
772 OS << printID(RegIdx) << ' ';
773 PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
774 /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
775 OS << " @ ";
776 LIS.getInstructionIndex(Instr: *PrintReg.DefMI).print(os&: OS);
777 });
778}
779
780Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const {
781 return Printable([&, RegIdx](raw_ostream &OS) {
782 for (const auto &[UseRegion, Users] : getReg(RegIdx).Uses) {
783 for (MachineInstr *MI : Users)
784 OS << " User " << printUser(MI, UseRegion) << '\n';
785 }
786 });
787}
788
789Printable Rematerializer::printUser(const MachineInstr *MI,
790 std::optional<unsigned> UseRegion) const {
791 return Printable([&, MI, UseRegion](raw_ostream &OS) {
792 RegisterIdx RegIdx = getDefRegIdx(MI: *MI);
793 if (RegIdx != NoReg) {
794 OS << printID(RegIdx);
795 } else {
796 OS << "(-/-)[";
797 if (UseRegion)
798 OS << *UseRegion;
799 else
800 OS << '?';
801 OS << ']';
802 }
803 OS << ' ';
804 MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
805 /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
806 OS << " @ ";
807 LIS.getInstructionIndex(Instr: *MI).print(os&: OS);
808 });
809}
810