1//=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//==-----------------------------------------------------------------------===//
8//
9/// \file
10/// Implements helpers for target-independent rematerialization at the MIR
11/// level.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/CodeGen/Rematerializer.h"
16#include "llvm/ADT/MapVector.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SetVector.h"
19#include "llvm/CodeGen/LiveIntervals.h"
20#include "llvm/CodeGen/MachineBasicBlock.h"
21#include "llvm/CodeGen/MachineOperand.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/Register.h"
24#include "llvm/CodeGen/TargetOpcodes.h"
25#include "llvm/CodeGen/TargetRegisterInfo.h"
26#include "llvm/Support/Debug.h"
27
28#define DEBUG_TYPE "rematerializer"
29
30using namespace llvm;
31using RegisterIdx = Rematerializer::RegisterIdx;
32
33/// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this
34/// implies it is also live there). When \p LI has sub-ranges, checks that
35/// all sub-ranges intersecting with \p Mask are also live at \p UseIdx.
36static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask,
37 SlotIndex UseIdx, const LiveInterval &LI) {
38 if (&OVNI != LI.getVNInfoAt(Idx: UseIdx))
39 return false;
40
41 if (LI.hasSubRanges()) {
42 // Check that intersecting subranges are live at user.
43 for (const LiveInterval::SubRange &SR : LI.subranges()) {
44 if ((SR.LaneMask & Mask).none())
45 continue;
46 if (!SR.liveAt(index: UseIdx))
47 return false;
48
49 // Early exit if all used lanes are checked. No need to continue.
50 Mask &= ~SR.LaneMask;
51 if (Mask.none())
52 break;
53 }
54 }
55 return true;
56}
57
58/// If \p MO is a virtual read register, returns it. Otherwise returns the
59/// sentinel register.
60static Register getRegDependency(const MachineOperand &MO) {
61 if (!MO.isReg() || !MO.readsReg())
62 return Register();
63 Register Reg = MO.getReg();
64 if (Reg.isPhysical()) {
65 // By the requirements on trivially rematerializable instructions, a
66 // physical register use is either constant or ignorable.
67 return Register();
68 }
69 return Reg;
70}
71
72RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx,
73 unsigned UseRegion,
74 DependencyReuseInfo &DRI) {
75 MachineInstr *FirstMI =
76 getReg(RegIdx: RootIdx).getRegionUseBounds(UseRegion, LIS).first;
77 RegisterIdx NewRegIdx = rematerializeToPos(RootIdx, InsertPos: FirstMI, DRI);
78 transferRegionUsers(FromRegIdx: RootIdx, ToRegIdx: NewRegIdx, UseRegion);
79 return NewRegIdx;
80}
81
82RegisterIdx
83Rematerializer::rematerializeToPos(RegisterIdx RootIdx,
84 MachineBasicBlock::iterator InsertPos,
85 DependencyReuseInfo &DRI) {
86 assert(!DRI.DependencyMap.contains(RootIdx));
87 LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << " to "
88 << printUser(&*InsertPos) << '\n');
89
90 // Traverse the root's dependency DAG depth-first to find the set of
91 // registers we must rematerialize along with it and a legal order to
92 // rematerialize them in.
93 SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
94 SmallSetVector<RegisterIdx, 8> RematOrder;
95 RematOrder.insert(X: RootIdx);
96 do {
97 RegisterIdx RegIdx = DepDAG.pop_back_val();
98 for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies) {
99 // The dependency may already have a rematerialization ready to use.
100 if (DRI.DependencyMap.contains(Val: Dep.RegIdx))
101 continue;
102 // We may have already seen the dependency in the dependency DAG.
103 if (RematOrder.contains(key: Dep.RegIdx))
104 continue;
105 DepDAG.push_back(Elt: Dep.RegIdx);
106 RematOrder.insert(X: Dep.RegIdx);
107 }
108 } while (!DepDAG.empty());
109
110 // Rematerialize all necessary registers in the root's dependency DAG. At each
111 // rematerialization, dependencies should already be available.
112 RegisterIdx LastNewIdx;
113 for (RegisterIdx RegIdx : reverse(C&: RematOrder)) {
114 assert(!DRI.DependencyMap.contains(RegIdx) && "useless remat");
115 SmallVector<Reg::Dependency, 2> Dependencies;
116 for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
117 Dependencies.emplace_back(Args: Dep.MOIdx, Args&: DRI.DependencyMap.at(Val: Dep.RegIdx));
118 LastNewIdx = rematerializeReg(RegIdx, InsertPos, Dependencies: std::move(Dependencies));
119 DRI.DependencyMap.insert(KV: {RegIdx, LastNewIdx});
120 }
121
122 return LastNewIdx;
123}
124
125void Rematerializer::rollbackRematsOf(RegisterIdx RootIdx) {
126 auto Remats = Rematerializations.find(Val: RootIdx);
127 if (Remats == Rematerializations.end())
128 return;
129
130 LLVM_DEBUG(dbgs() << "Rolling back rematerializations of " << printID(RootIdx)
131 << '\n');
132
133 reviveRegIfDead(RootIdx);
134 // All of the rematerialization's users must use the revived register.
135 for (RegisterIdx RematRegIdx : Remats->getSecond()) {
136 for (const auto &[UseRegion, RegionUsers] : Regs[RematRegIdx].Uses)
137 transferRegionUsers(FromRegIdx: RematRegIdx, ToRegIdx: RootIdx, UseRegion);
138 }
139 Rematerializations.erase(Val: RootIdx);
140
141 LLVM_DEBUG(dbgs() << "** Rolled back rematerializations of "
142 << printID(RootIdx) << '\n');
143}
144
145void Rematerializer::rollback(RegisterIdx RematIdx) {
146 assert(getReg(RematIdx).DefMI && !Revivable.contains(RematIdx) &&
147 "cannot rollback dead register");
148 const RegisterIdx OriginRegIdx = getOriginOf(RematRegIdx: RematIdx);
149 reviveRegIfDead(RootIdx: OriginRegIdx);
150 for (const auto &[UseRegion, RegionUsers] : Regs[RematIdx].Uses)
151 transferRegionUsers(FromRegIdx: RematIdx, ToRegIdx: OriginRegIdx, UseRegion);
152}
153
154void Rematerializer::reviveRegIfDead(RegisterIdx RootIdx) {
155 if (getReg(RegIdx: RootIdx).isAlive())
156 return;
157 assert(Revivable.contains(RootIdx) && "not revivable");
158
159 // Traverse the root's dependency DAG depth-first to find the set of
160 // registers we must revive and a legal order to revive them in.
161 SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
162 SmallSetVector<RegisterIdx, 8> ReviveOrder;
163 ReviveOrder.insert(X: RootIdx);
164 do {
165 // All dependencies of a revived register need to be alive too.
166 const Reg &ReviveReg = getReg(RegIdx: DepDAG.pop_back_val());
167 for (const Reg::Dependency &Dep : ReviveReg.Dependencies) {
168 // We may have already seen the dependency in the dependency DAG.
169 if (ReviveOrder.contains(key: Dep.RegIdx))
170 continue;
171
172 // Dead dependencies need to be revived.
173 Reg &DepReg = Regs[Dep.RegIdx];
174 if (!DepReg.isAlive()) {
175 assert(Revivable.contains(Dep.RegIdx) && "not revivable");
176 ReviveOrder.insert(X: Dep.RegIdx);
177 DepDAG.push_back(Elt: Dep.RegIdx);
178 }
179
180 // All dependencies get a new user (the revived register).
181 DepReg.addUser(MI: ReviveReg.DefMI, Region: ReviveReg.DefRegion);
182 LISUpdates.insert(V: Dep.RegIdx);
183 }
184 } while (!DepDAG.empty());
185
186 for (RegisterIdx RegIdx : reverse(C&: ReviveOrder)) {
187 // Pick any rematerialization to retrieve the original opcode from.
188 Reg &ReviveReg = Regs[RegIdx];
189 assert(Rematerializations.contains(RegIdx) && "no remats");
190 RegisterIdx RematIdx = *Rematerializations.at(Val: RegIdx).begin();
191 ReviveReg.DefMI->setDesc(getReg(RegIdx: RematIdx).DefMI->getDesc());
192 for (const auto &[MOIdx, Reg] : Revivable.at(Val: RegIdx))
193 ReviveReg.DefMI->getOperand(i: MOIdx).setReg(Reg);
194 Revivable.erase(Val: RegIdx);
195 LISUpdates.insert(V: RegIdx);
196
197 LLVM_DEBUG({
198 dbgs() << "** Revived " << printID(RegIdx) << " @ ";
199 LIS.getInstructionIndex(*ReviveReg.DefMI).print(dbgs());
200 dbgs() << '\n';
201 });
202 }
203}
204
205void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
206 MachineInstr &UserMI) {
207 transferUserImpl(FromRegIdx, ToRegIdx, UserMI);
208 unsigned UserRegion = MIRegion.at(Val: &UserMI);
209 Regs[FromRegIdx].eraseUser(MI: &UserMI, Region: UserRegion);
210 Regs[ToRegIdx].addUser(MI: &UserMI, Region: UserRegion);
211 deleteRegIfUnused(RootIdx: FromRegIdx);
212}
213
214void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx,
215 RegisterIdx ToRegIdx,
216 unsigned UseRegion) {
217 auto &FromRegUsers = Regs[FromRegIdx].Uses;
218 auto UsesIt = FromRegUsers.find(Val: UseRegion);
219 if (UsesIt == FromRegUsers.end())
220 return;
221
222 const SmallDenseSet<MachineInstr *, 4> &RegionUsers = UsesIt->getSecond();
223 for (MachineInstr *UserMI : RegionUsers)
224 transferUserImpl(FromRegIdx, ToRegIdx, UserMI&: *UserMI);
225 Regs[ToRegIdx].addUsers(NewUsers: RegionUsers, Region: UseRegion);
226 FromRegUsers.erase(Val: UseRegion);
227 deleteRegIfUnused(RootIdx: FromRegIdx);
228}
229
230void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx,
231 RegisterIdx ToRegIdx,
232 MachineInstr &UserMI) {
233 assert(MIRegion.contains(&UserMI) && "unknown user");
234 assert(getReg(FromRegIdx).Uses.at(MIRegion.at(&UserMI)).contains(&UserMI) &&
235 "not a user");
236 assert(FromRegIdx != ToRegIdx && "identical registers");
237 assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) &&
238 "unrelated registers");
239
240 LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to "
241 << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n');
242
243 UserMI.substituteRegister(FromReg: getReg(RegIdx: FromRegIdx).getDefReg(),
244 ToReg: getReg(RegIdx: ToRegIdx).getDefReg(), SubIdx: 0, RegInfo: TRI);
245 LISUpdates.insert(V: FromRegIdx);
246 LISUpdates.insert(V: ToRegIdx);
247
248 // If the user is rematerializable, we must change its dependency to the
249 // new register.
250 if (RegisterIdx UserRegIdx = getDefRegIdx(MI: UserMI); UserRegIdx != NoReg) {
251 // Look for the user's dependency that matches the register.
252 for (Reg::Dependency &Dep : Regs[UserRegIdx].Dependencies) {
253 if (Dep.RegIdx == FromRegIdx) {
254 Dep.RegIdx = ToRegIdx;
255 return;
256 }
257 }
258 llvm_unreachable("broken dependency");
259 }
260}
261
262void Rematerializer::updateLiveIntervals() {
263 DenseSet<Register> SeenUnrematRegs;
264 for (RegisterIdx RegIdx : LISUpdates) {
265 const Reg &UpdateReg = getReg(RegIdx);
266 assert((UpdateReg.DefMI || Revivable.contains(RegIdx)) && "dead reg");
267
268 Register DefReg = UpdateReg.getDefReg();
269 if (LIS.hasInterval(Reg: DefReg))
270 LIS.removeInterval(Reg: DefReg);
271 LIS.createAndComputeVirtRegInterval(Reg: DefReg);
272
273 LLVM_DEBUG({
274 dbgs() << "Re-computed interval for " << printID(RegIdx) << ": ";
275 LIS.getInterval(DefReg).print(dbgs());
276 dbgs() << '\n' << printRegUsers(RegIdx);
277 });
278
279 // Update intervals for unrematerializable operands.
280 for (unsigned MOIdx : getUnrematableOprds(RegIdx)) {
281 Register UnrematReg = UpdateReg.DefMI->getOperand(i: MOIdx).getReg();
282 if (!SeenUnrematRegs.insert(V: UnrematReg).second)
283 continue;
284 LIS.removeInterval(Reg: UnrematReg);
285 LIS.createAndComputeVirtRegInterval(Reg: UnrematReg);
286 LLVM_DEBUG(
287 dbgs() << " Re-computed interval for register "
288 << printReg(UnrematReg, &TRI,
289 UpdateReg.DefMI->getOperand(MOIdx).getSubReg(),
290 &MRI)
291 << '\n');
292 }
293 }
294 LISUpdates.clear();
295}
296
297void Rematerializer::commitRematerializations() {
298 for (auto &[RegIdx, _] : Revivable)
299 deleteReg(RegIdx);
300 Revivable.clear();
301}
302
303bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO,
304 ArrayRef<SlotIndex> Uses) const {
305 if (Uses.empty())
306 return true;
307 Register Reg = MO.getReg();
308 unsigned SubIdx = MO.getSubReg();
309 LaneBitmask Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
310 : MRI.getMaxLaneMaskForVReg(Reg);
311 const LiveInterval &LI = LIS.getInterval(Reg);
312 const VNInfo *DefVN =
313 LI.getVNInfoAt(Idx: LIS.getInstructionIndex(Instr: *MO.getParent()).getRegSlot(EC: true));
314 for (SlotIndex Use : Uses) {
315 if (!isIdenticalAtUse(OVNI: *DefVN, Mask, UseIdx: Use, LI))
316 return false;
317 }
318 return true;
319}
320
321RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx,
322 unsigned Region,
323 SlotIndex Before) const {
324 auto It = Rematerializations.find(Val: getOriginOrSelf(RegIdx));
325 if (It == Rematerializations.end())
326 return NoReg;
327 const RematsOf &Remats = It->getSecond();
328
329 SlotIndex BestSlot;
330 RegisterIdx BestRegIdx = NoReg;
331 for (RegisterIdx RematRegIdx : Remats) {
332 const Reg &RematReg = getReg(RegIdx: RematRegIdx);
333 if (RematReg.DefRegion != Region || RematReg.Uses.empty())
334 continue;
335 SlotIndex RematRegSlot =
336 LIS.getInstructionIndex(Instr: *RematReg.DefMI).getRegSlot();
337 if (RematRegSlot < Before &&
338 (BestRegIdx == NoReg || RematRegSlot > BestSlot)) {
339 BestSlot = RematRegSlot;
340 BestRegIdx = RematRegIdx;
341 }
342 }
343 return BestRegIdx;
344}
345
346void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
347 if (!getReg(RegIdx: RootIdx).Uses.empty())
348 return;
349
350 // Traverse the root's dependency DAG depth-first to find the set of registers
351 // we can delete and a legal order to delete them in.
352 SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
353 SmallSetVector<RegisterIdx, 8> DeleteOrder;
354 DeleteOrder.insert(X: RootIdx);
355 do {
356 // A deleted register's dependencies may be deletable too.
357 const Reg &DeleteReg = getReg(RegIdx: DepDAG.pop_back_val());
358 for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
359 // All dependencies loose a user (the delete register).
360 Reg &DepReg = Regs[Dep.RegIdx];
361 DepReg.eraseUser(MI: DeleteReg.DefMI, Region: DeleteReg.DefRegion);
362 if (DepReg.Uses.empty()) {
363 DeleteOrder.insert(X: Dep.RegIdx);
364 DepDAG.push_back(Elt: Dep.RegIdx);
365 }
366 }
367 } while (!DepDAG.empty());
368
369 for (RegisterIdx RegIdx : reverse(C&: DeleteOrder)) {
370 Reg &DeleteReg = Regs[RegIdx];
371 LIS.removeInterval(Reg: DeleteReg.getDefReg());
372 LISUpdates.erase(V: RegIdx);
373 const bool IsRematerializedReg = isRematerializedRegister(RegIdx);
374 if (SupportRollback && !IsRematerializedReg) {
375 // Replace all read registers with the null one to prevent them from
376 // showing up in use-lists, which is disallowed for debug instructions in
377 // live interval calculations. Store mappings between operand indices and
378 // original registers for potential rollback.
379 DenseMap<unsigned, Register> &RegMap =
380 Revivable.try_emplace(Key: RegIdx).first->getSecond();
381 for (auto [Idx, MO] : enumerate(First: DeleteReg.DefMI->operands())) {
382 if (MO.isReg() && MO.readsReg()) {
383 RegMap.insert(KV: {Idx, MO.getReg()});
384 MO.setReg(Register());
385 }
386 }
387 DeleteReg.DefMI->setDesc(TII.get(Opcode: TargetOpcode::DBG_VALUE));
388 } else {
389 deleteReg(RegIdx);
390 }
391 if (IsRematerializedReg) {
392 // Delete rematerialized register from its origin's rematerializations.
393 RematsOf &OriginRemats = Rematerializations.at(Val: getOriginOf(RematRegIdx: RegIdx));
394 assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link");
395 OriginRemats.erase(V: RegIdx);
396 if (OriginRemats.empty())
397 Rematerializations.erase(Val: RegIdx);
398 }
399 LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n");
400 }
401}
402
403void Rematerializer::deleteReg(RegisterIdx RegIdx) {
404 Reg &DeleteReg = Regs[RegIdx];
405 assert(DeleteReg.DefMI && "register was already deleted");
406 // It is not possible for the deleted instruction to be the upper region
407 // boundary since we don't ever consider them rematerializable.
408 MachineBasicBlock::iterator &RegionBegin = Regions[DeleteReg.DefRegion].first;
409 if (RegionBegin == DeleteReg.DefMI)
410 RegionBegin = std::next(x: MachineBasicBlock::iterator(DeleteReg.DefMI));
411 LIS.RemoveMachineInstrFromMaps(MI&: *DeleteReg.DefMI);
412 DeleteReg.DefMI->eraseFromParent();
413 MIRegion.erase(Val: DeleteReg.DefMI);
414 DeleteReg.DefMI = nullptr;
415}
416
417Rematerializer::Rematerializer(MachineFunction &MF,
418 SmallVectorImpl<RegionBoundaries> &Regions,
419 LiveIntervals &LIS)
420 : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS),
421 TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) {
422#ifdef EXPENSIVE_CHECKS
423 // Check that regions are valid.
424 DenseSet<MachineInstr *> SeenMIs;
425 for (const auto &[RegionBegin, RegionEnd] : Regions) {
426 assert(RegionBegin != RegionEnd && "empty region");
427 for (auto MI = RegionBegin; MI != RegionEnd; ++MI) {
428 bool IsNewMI = SeenMIs.insert(&*MI).second;
429 assert(IsNewMI && "overlapping regions");
430 assert(!MI->isTerminator() && "terminator in region");
431 }
432 if (RegionEnd != RegionBegin->getParent()->end()) {
433 bool IsNewMI = SeenMIs.insert(&*RegionEnd).second;
434 assert(IsNewMI && "overlapping regions (upper bound)");
435 }
436 }
437#endif
438}
439
440bool Rematerializer::analyze(bool SupportRollback) {
441 Regs.clear();
442 UnrematableOprds.clear();
443 Origins.clear();
444 Rematerializations.clear();
445 MIRegion.clear();
446 RegToIdx.clear();
447 LISUpdates.clear();
448 Revivable.clear();
449 this->SupportRollback = SupportRollback;
450 if (Regions.empty())
451 return false;
452
453 // Initialize MI to containing region mapping.
454 for (unsigned I = 0, E = Regions.size(); I < E; ++I) {
455 RegionBoundaries Region = Regions[I];
456 assert(Region.first != Region.second && "empty cannot be region");
457 for (auto MI = Region.first; MI != Region.second; ++MI) {
458 assert(!MIRegion.contains(&*MI) && "regions should not intersect");
459 MIRegion.insert(KV: {&*MI, I});
460 }
461
462 // A terminator instruction is considered part of the region it terminates.
463 if (Region.second != Region.first->getParent()->end()) {
464 MachineInstr *RegionTerm = &*Region.second;
465 assert(!MIRegion.contains(RegionTerm) && "regions should not intersect");
466 MIRegion.insert(KV: {RegionTerm, I});
467 }
468 }
469
470 const unsigned NumVirtRegs = MRI.getNumVirtRegs();
471 BitVector SeenRegs(NumVirtRegs);
472 for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) {
473 if (!SeenRegs[I])
474 addRegIfRematerializable(VirtRegIdx: I, SeenRegs);
475 }
476 assert(Regs.size() == UnrematableOprds.size());
477
478 LLVM_DEBUG({
479 for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I)
480 dbgs() << printDependencyDAG(I) << '\n';
481 });
482 return !Regs.empty();
483}
484
485void Rematerializer::addRegIfRematerializable(unsigned VirtRegIdx,
486 BitVector &SeenRegs) {
487 assert(!SeenRegs[VirtRegIdx] && "register already seen");
488 Register DefReg = Register::index2VirtReg(Index: VirtRegIdx);
489 SeenRegs.set(VirtRegIdx);
490
491 MachineOperand *MO = MRI.getOneDef(Reg: DefReg);
492 if (!MO)
493 return;
494 MachineInstr &DefMI = *MO->getParent();
495 if (!isMIRematerializable(MI: DefMI))
496 return;
497 auto DefRegion = MIRegion.find(Val: &DefMI);
498 if (DefRegion == MIRegion.end())
499 return;
500
501 Reg RematReg;
502 RematReg.DefMI = &DefMI;
503 RematReg.DefRegion = DefRegion->second;
504 unsigned SubIdx = DefMI.getOperand(i: 0).getSubReg();
505 RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
506 : MRI.getMaxLaneMaskForVReg(Reg: DefReg);
507
508 // Collect the candidate's direct users, both rematerializable and
509 // unrematerializable. MIs outside provided regions cannot be tracked so the
510 // registers they use are not safely rematerializable.
511 for (MachineInstr &UseMI : MRI.use_nodbg_instructions(Reg: DefReg)) {
512 if (auto UseRegion = MIRegion.find(Val: &UseMI); UseRegion != MIRegion.end())
513 RematReg.addUser(MI: &UseMI, Region: UseRegion->second);
514 else
515 return;
516 }
517 if (RematReg.Uses.empty())
518 return;
519
520 // Collect the candidate's dependencies. If the same register is used
521 // multiple times we just need to consider it once.
522 SmallDenseSet<Register, 4> AllDepRegs;
523 SmallVector<unsigned, 2> UnrematDeps;
524 for (const auto &[MOIdx, MO] : enumerate(First: RematReg.DefMI->operands())) {
525 Register DepReg = getRegDependency(MO);
526 if (!DepReg || !AllDepRegs.insert(V: DepReg).second)
527 continue;
528 unsigned DepRegIdx = DepReg.virtRegIndex();
529 if (!SeenRegs[DepRegIdx])
530 addRegIfRematerializable(VirtRegIdx: DepRegIdx, SeenRegs);
531 if (auto DepIt = RegToIdx.find(Val: DepReg); DepIt != RegToIdx.end())
532 RematReg.Dependencies.push_back(Elt: Reg::Dependency(MOIdx, DepIt->second));
533 else
534 UnrematDeps.push_back(Elt: MOIdx);
535 }
536
537 // The register is rematerializable.
538 RegToIdx.insert(KV: {DefReg, Regs.size()});
539 Regs.push_back(Elt: RematReg);
540 UnrematableOprds.push_back(Elt: UnrematDeps);
541}
542
543bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const {
544 if (!TII.isReMaterializable(MI))
545 return false;
546
547 assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual");
548 assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def");
549
550 for (const MachineOperand &MO : MI.all_uses()) {
551 // We can't remat physreg uses, unless it is a constant or an ignorable
552 // use (e.g. implicit exec use on VALU instructions)
553 if (MO.getReg().isPhysical()) {
554 if (MRI.isConstantPhysReg(PhysReg: MO.getReg()) || TII.isIgnorableUse(MO))
555 continue;
556 return false;
557 }
558 }
559
560 return true;
561}
562
563RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const {
564 if (!MI.getNumOperands() || !MI.getOperand(i: 0).isReg() ||
565 MI.getOperand(i: 0).readsReg())
566 return NoReg;
567 Register Reg = MI.getOperand(i: 0).getReg();
568 auto UserRegIt = RegToIdx.find(Val: Reg);
569 if (UserRegIt == RegToIdx.end())
570 return NoReg;
571 return UserRegIt->second;
572}
573
574RegisterIdx Rematerializer::rematerializeReg(
575 RegisterIdx RegIdx, MachineBasicBlock::iterator InsertPos,
576 SmallVectorImpl<Reg::Dependency> &&Dependencies) {
577 unsigned UseRegion = MIRegion.at(Val: &*InsertPos);
578 RegisterIdx NewRegIdx = Regs.size();
579
580 Reg &NewReg = Regs.emplace_back();
581 Reg &FromReg = Regs[RegIdx];
582 NewReg.Mask = FromReg.Mask;
583 NewReg.DefRegion = UseRegion;
584 NewReg.Dependencies = std::move(Dependencies);
585
586 // Track rematerialization link between registers. Origins are always
587 // registers that existed originally, and rematerializations are always
588 // attached to them.
589 RegisterIdx OriginIdx =
590 isRematerializedRegister(RegIdx) ? getOriginOf(RematRegIdx: RegIdx) : RegIdx;
591 Origins.push_back(Elt: OriginIdx);
592 Rematerializations[OriginIdx].insert(V: NewRegIdx);
593
594 // Use the TII to rematerialize the defining instruction with a new defined
595 // register.
596 Register NewDefReg = MRI.cloneVirtualRegister(VReg: FromReg.getDefReg());
597 TII.reMaterialize(MBB&: *InsertPos->getParent(), MI: InsertPos, DestReg: NewDefReg, SubIdx: 0,
598 Orig: *FromReg.DefMI);
599 NewReg.DefMI = &*std::prev(x: InsertPos);
600 RegToIdx.insert(KV: {NewDefReg, NewRegIdx});
601
602 // Update the DAG.
603 RegionBoundaries &Bounds = Regions[UseRegion];
604 if (Bounds.first == std::next(x: MachineBasicBlock::iterator(NewReg.DefMI)))
605 Bounds.first = NewReg.DefMI;
606 LIS.InsertMachineInstrInMaps(MI&: *NewReg.DefMI);
607 MIRegion.emplace_or_assign(Key: NewReg.DefMI, Args&: UseRegion);
608 LISUpdates.insert(V: NewRegIdx);
609
610 // Replace dependencies as needed in the rematerialized MI. All dependencies
611 // of the latter gain a new user.
612 auto ZipedDeps = zip_equal(t&: FromReg.Dependencies, u&: NewReg.Dependencies);
613 for (const auto &[OldDep, NewDep] : ZipedDeps) {
614 assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch");
615 LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": "
616 << printID(OldDep.RegIdx) << " -> "
617 << printID(NewDep.RegIdx) << '\n');
618
619 Reg &NewDepReg = Regs[NewDep.RegIdx];
620 if (OldDep.RegIdx != NewDep.RegIdx) {
621 Register OldDefReg = FromReg.DefMI->getOperand(i: OldDep.MOIdx).getReg();
622 NewReg.DefMI->substituteRegister(FromReg: OldDefReg, ToReg: NewDepReg.getDefReg(), SubIdx: 0,
623 RegInfo: TRI);
624 LISUpdates.insert(V: OldDep.RegIdx);
625 }
626 NewDepReg.addUser(MI: NewReg.DefMI, Region: UseRegion);
627 LISUpdates.insert(V: NewDep.RegIdx);
628 }
629
630 LLVM_DEBUG({
631 dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
632 << printRematReg(NewRegIdx) << '\n';
633 });
634 return NewRegIdx;
635}
636
637std::pair<MachineInstr *, MachineInstr *>
638Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion,
639 const LiveIntervals &LIS) const {
640 auto It = Uses.find(Val: UseRegion);
641 if (It == Uses.end())
642 return {nullptr, nullptr};
643 const RegionUsers &RegionUsers = It->getSecond();
644 assert(!RegionUsers.empty() && "empty userset in region");
645
646 auto User = RegionUsers.begin(), UserEnd = RegionUsers.end();
647 MachineInstr *FirstMI = *User, *LastMI = FirstMI;
648 SlotIndex FirstIndex = LIS.getInstructionIndex(Instr: *FirstMI),
649 LastIndex = FirstIndex;
650
651 while (++User != UserEnd) {
652 SlotIndex UserIndex = LIS.getInstructionIndex(Instr: **User);
653 if (UserIndex < FirstIndex) {
654 FirstIndex = UserIndex;
655 FirstMI = *User;
656 } else if (UserIndex > LastIndex) {
657 LastIndex = UserIndex;
658 LastMI = *User;
659 }
660 }
661
662 return {FirstMI, LastMI};
663}
664
665void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) {
666 Uses[Region].insert(V: MI);
667}
668
669void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers,
670 unsigned Region) {
671 Uses[Region].insert_range(R: NewUsers);
672}
673
674void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) {
675 assert(Uses.contains(Region) && "no user in region");
676 assert(Uses.at(Region).contains(MI) && "user not in region");
677 RegionUsers &RUsers = Uses[Region];
678 if (RUsers.size() == 1)
679 Uses.erase(Val: Region);
680 else
681 RUsers.erase(V: MI);
682}
683
684Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const {
685 return Printable([&, RootIdx](raw_ostream &OS) {
686 DenseMap<RegisterIdx, unsigned> RegDepths;
687 std::function<void(RegisterIdx, unsigned)> WalkTree =
688 [&](RegisterIdx RegIdx, unsigned Depth) -> void {
689 unsigned MaxDepth = std::max(a: RegDepths.lookup_or(Val: RegIdx, Default&: Depth), b: Depth);
690 RegDepths.emplace_or_assign(Key: RegIdx, Args&: MaxDepth);
691 for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
692 WalkTree(Dep.RegIdx, Depth + 1);
693 };
694 WalkTree(RootIdx, 0);
695
696 // Sort in decreasing depth order to print root at the bottom.
697 SmallVector<std::pair<RegisterIdx, unsigned>> Regs(RegDepths.begin(),
698 RegDepths.end());
699 sort(C&: Regs, Comp: [](const auto &LHS, const auto &RHS) {
700 return LHS.second > RHS.second;
701 });
702
703 OS << printID(RegIdx: RootIdx) << " has " << Regs.size() - 1 << " dependencies\n";
704 for (const auto &[RegIdx, Depth] : Regs) {
705 OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' '
706 << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n';
707 }
708 OS << printRegUsers(RegIdx: RootIdx);
709 });
710}
711
712Printable Rematerializer::printID(RegisterIdx RegIdx) const {
713 return Printable([&, RegIdx](raw_ostream &OS) {
714 const Reg &PrintReg = getReg(RegIdx);
715 OS << '(' << RegIdx << '/';
716 if (!PrintReg.DefMI) {
717 OS << "<dead>";
718 } else {
719 OS << printReg(Reg: PrintReg.getDefReg(), TRI: &TRI,
720 SubIdx: PrintReg.DefMI->getOperand(i: 0).getSubReg(), MRI: &MRI);
721 }
722 OS << ")[" << PrintReg.DefRegion << "]";
723 });
724}
725
726Printable Rematerializer::printRematReg(RegisterIdx RegIdx,
727 bool SkipRegions) const {
728 return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) {
729 const Reg &PrintReg = getReg(RegIdx);
730 if (!SkipRegions) {
731 OS << printID(RegIdx) << " [" << PrintReg.DefRegion;
732 if (!PrintReg.Uses.empty()) {
733 assert(PrintReg.DefMI && "dead register cannot have uses");
734 const LiveInterval &LI = LIS.getInterval(Reg: PrintReg.getDefReg());
735 // First display all regions in which the register is live-through and
736 // not used.
737 bool First = true;
738 for (const auto [I, Bounds] : enumerate(First&: Regions)) {
739 if (Bounds.first == Bounds.second)
740 continue;
741 if (!PrintReg.Uses.contains(Val: I) &&
742 LI.liveAt(index: LIS.getInstructionIndex(Instr: *Bounds.first)) &&
743 LI.liveAt(index: LIS.getInstructionIndex(Instr: *std::prev(x: Bounds.second))
744 .getRegSlot())) {
745 OS << (First ? " - " : ",") << I;
746 First = false;
747 }
748 }
749 OS << (First ? " --> " : " -> ");
750
751 // Then display regions in which the register is used.
752 auto It = PrintReg.Uses.begin();
753 OS << It->first;
754 while (++It != PrintReg.Uses.end())
755 OS << "," << It->first;
756 }
757 OS << "] ";
758 }
759 OS << printID(RegIdx) << ' ';
760 PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
761 /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
762 OS << " @ ";
763 LIS.getInstructionIndex(Instr: *PrintReg.DefMI).print(os&: OS);
764 });
765}
766
767Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const {
768 return Printable([&, RegIdx](raw_ostream &OS) {
769 for (const auto &[_, Users] : getReg(RegIdx).Uses) {
770 for (MachineInstr *MI : Users)
771 dbgs() << " User " << printUser(MI) << '\n';
772 }
773 });
774}
775
776Printable Rematerializer::printUser(const MachineInstr *MI) const {
777 return Printable([&, MI](raw_ostream &OS) {
778 RegisterIdx RegIdx = getDefRegIdx(MI: *MI);
779 if (RegIdx != NoReg)
780 OS << printID(RegIdx);
781 else
782 OS << "(-/-)[" << MIRegion.at(Val: MI) << ']';
783 OS << ' ';
784 MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
785 /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
786 OS << " @ ";
787 LIS.getInstructionIndex(Instr: *MI).print(os&: dbgs());
788 });
789}
790