1//=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//==-----------------------------------------------------------------------===//
8//
9/// \file
10/// Implements helpers for target-independent rematerialization at the MIR
11/// level.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/CodeGen/Rematerializer.h"
16#include "llvm/ADT/MapVector.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SetVector.h"
19#include "llvm/CodeGen/LiveIntervals.h"
20#include "llvm/CodeGen/MachineBasicBlock.h"
21#include "llvm/CodeGen/MachineOperand.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/Register.h"
24#include "llvm/CodeGen/TargetRegisterInfo.h"
25#include "llvm/Support/Debug.h"
26#include <optional>
27
28#define DEBUG_TYPE "rematerializer"
29
30using namespace llvm;
31using RegisterIdx = Rematerializer::RegisterIdx;
32
33// Pin the vtable to this file.
34void Rematerializer::Listener::anchor() {}
35
36/// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this
37/// implies it is also live there). When \p LI has sub-ranges, checks that
38/// all sub-ranges intersecting with \p Mask are also live at \p UseIdx.
39static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask,
40 SlotIndex UseIdx, const LiveInterval &LI) {
41 if (&OVNI != LI.getVNInfoAt(Idx: UseIdx))
42 return false;
43
44 if (LI.hasSubRanges()) {
45 // Check that intersecting subranges are live at user.
46 for (const LiveInterval::SubRange &SR : LI.subranges()) {
47 if ((SR.LaneMask & Mask).none())
48 continue;
49 if (!SR.liveAt(index: UseIdx))
50 return false;
51
52 // Early exit if all used lanes are checked. No need to continue.
53 Mask &= ~SR.LaneMask;
54 if (Mask.none())
55 break;
56 }
57 }
58 return true;
59}
60
61/// If \p MO is a virtual read register, returns it. Otherwise returns the
62/// sentinel register.
63static Register getRegDependency(const MachineOperand &MO) {
64 if (!MO.isReg() || !MO.readsReg())
65 return Register();
66 Register Reg = MO.getReg();
67 if (Reg.isPhysical()) {
68 // By the requirements on trivially rematerializable instructions, a
69 // physical register use is either constant or ignorable.
70 return Register();
71 }
72 return Reg;
73}
74
75RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx,
76 unsigned UseRegion,
77 DependencyReuseInfo &DRI) {
78 MachineInstr *FirstMI =
79 getReg(RegIdx: RootIdx).getRegionUseBounds(UseRegion, LIS).first;
80 // If there are no users in the region, rematerialize the register at the very
81 // end of the region.
82 MachineBasicBlock::iterator InsertPos =
83 FirstMI ? FirstMI : Regions[UseRegion].second;
84 RegisterIdx NewRegIdx =
85 rematerializeToPos(RootIdx, UseRegion, InsertPos, DRI);
86 transferRegionUsers(FromRegIdx: RootIdx, ToRegIdx: NewRegIdx, UseRegion);
87 return NewRegIdx;
88}
89
90RegisterIdx
91Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion,
92 MachineBasicBlock::iterator InsertPos,
93 DependencyReuseInfo &DRI) {
94 assert(!DRI.DependencyMap.contains(RootIdx));
95 LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << '\n');
96
97 SmallVector<Reg::Dependency, 2> NewDeps;
98 // Copy all dependencies because recursive rematerialization of dependencies
99 // may invalidate references to the backing vector of registers.
100 SmallVector<Reg::Dependency, 2> OldDeps(getReg(RegIdx: RootIdx).Dependencies);
101 for (const Reg::Dependency &Dep : OldDeps) {
102 // Recursively rematerialize required dependencies at the same position as
103 // the root. Registers form a DAG so the recursion is guaranteed to
104 // terminate.
105 auto RematIdx = DRI.DependencyMap.find(Val: Dep.RegIdx);
106 RegisterIdx NewDepRegIdx;
107 if (RematIdx == DRI.DependencyMap.end())
108 NewDepRegIdx = rematerializeToPos(RootIdx: Dep.RegIdx, UseRegion, InsertPos, DRI);
109 else
110 NewDepRegIdx = RematIdx->second;
111 NewDeps.emplace_back(Args: Dep.MOIdx, Args&: NewDepRegIdx);
112 }
113 RegisterIdx NewIdx =
114 rematerializeReg(RegIdx: RootIdx, UseRegion, InsertPos, Dependencies: std::move(NewDeps));
115 DRI.DependencyMap.insert(KV: {RootIdx, NewIdx});
116 return NewIdx;
117}
118
119void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx,
120 unsigned UserRegion, MachineInstr &UserMI) {
121 transferUserImpl(FromRegIdx, ToRegIdx, UserMI);
122 Regs[FromRegIdx].eraseUser(MI: &UserMI, Region: UserRegion);
123 Regs[ToRegIdx].addUser(MI: &UserMI, Region: UserRegion);
124 deleteRegIfUnused(RootIdx: FromRegIdx);
125}
126
127void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx,
128 RegisterIdx ToRegIdx,
129 unsigned UseRegion) {
130 auto &FromRegUsers = Regs[FromRegIdx].Uses;
131 auto UsesIt = FromRegUsers.find(Val: UseRegion);
132 if (UsesIt == FromRegUsers.end())
133 return;
134
135 const SmallDenseSet<MachineInstr *, 4> &RegionUsers = UsesIt->getSecond();
136 for (MachineInstr *UserMI : RegionUsers)
137 transferUserImpl(FromRegIdx, ToRegIdx, UserMI&: *UserMI);
138 Regs[ToRegIdx].addUsers(NewUsers: RegionUsers, Region: UseRegion);
139 FromRegUsers.erase(Val: UseRegion);
140 deleteRegIfUnused(RootIdx: FromRegIdx);
141}
142
143void Rematerializer::transferAllUsers(RegisterIdx FromRegIdx,
144 RegisterIdx ToRegIdx) {
145 Reg &FromReg = Regs[FromRegIdx], &ToReg = Regs[ToRegIdx];
146 for (const auto &[UseRegion, RegionUsers] : FromReg.Uses) {
147 for (MachineInstr *UserMI : RegionUsers)
148 transferUserImpl(FromRegIdx, ToRegIdx, UserMI&: *UserMI);
149 ToReg.addUsers(NewUsers: RegionUsers, Region: UseRegion);
150 }
151 FromReg.Uses.clear();
152 deleteRegIfUnused(RootIdx: FromRegIdx);
153}
154
155void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx,
156 RegisterIdx ToRegIdx,
157 MachineInstr &UserMI) {
158 assert(FromRegIdx != ToRegIdx && "identical registers");
159 assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) &&
160 "unrelated registers");
161
162 LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to "
163 << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n');
164
165 UserMI.substituteRegister(FromReg: getReg(RegIdx: FromRegIdx).getDefReg(),
166 ToReg: getReg(RegIdx: ToRegIdx).getDefReg(), SubIdx: 0, RegInfo: TRI);
167 LISUpdates.insert(V: FromRegIdx);
168 LISUpdates.insert(V: ToRegIdx);
169
170 // If the user is rematerializable, we must change its dependency to the
171 // new register.
172 if (RegisterIdx UserRegIdx = getDefRegIdx(MI: UserMI); UserRegIdx != NoReg) {
173 // Look for the user's dependency that matches the register.
174 for (Reg::Dependency &Dep : Regs[UserRegIdx].Dependencies) {
175 if (Dep.RegIdx == FromRegIdx) {
176 Dep.RegIdx = ToRegIdx;
177 return;
178 }
179 }
180 llvm_unreachable("broken dependency");
181 }
182}
183
184void Rematerializer::updateLiveIntervals() {
185 DenseSet<Register> SeenUnrematRegs;
186 for (RegisterIdx RegIdx : LISUpdates) {
187 const Reg &UpdateReg = getReg(RegIdx);
188 assert(UpdateReg.isAlive() && "dead register");
189
190 Register DefReg = UpdateReg.getDefReg();
191 if (LIS.hasInterval(Reg: DefReg))
192 LIS.removeInterval(Reg: DefReg);
193 // Rematerializable registers have a single definition by construction so
194 // re-creating their interval cannot yield a live interval with multiple
195 // connected components.
196 LIS.createAndComputeVirtRegInterval(Reg: DefReg);
197
198 LLVM_DEBUG({
199 dbgs() << "Re-computed interval for " << printID(RegIdx) << ": ";
200 LIS.getInterval(DefReg).print(dbgs());
201 dbgs() << '\n' << printRegUsers(RegIdx);
202 });
203
204 // Update intervals for unrematerializable operands.
205 for (unsigned MOIdx : getUnrematableOprds(RegIdx)) {
206 Register UnrematReg = UpdateReg.DefMI->getOperand(i: MOIdx).getReg();
207 if (!SeenUnrematRegs.insert(V: UnrematReg).second)
208 continue;
209 LIS.removeInterval(Reg: UnrematReg);
210 bool NeedSplit = false;
211
212 // Unrematerializable registers may end up with multiple connected
213 // components in their live interval after it is re-created. It needs to
214 // be split in such cases. We don't track unrematerializable registers by
215 // their actual register index (just by operand index) so we do not need
216 // to update any state in the rematerializer.
217 LiveInterval &LI =
218 LIS.createAndComputeVirtRegInterval(Reg: UnrematReg, NeedSplit);
219 if (NeedSplit) {
220 SmallVector<LiveInterval *> SplitLIs;
221 LIS.splitSeparateComponents(LI, SplitLIs);
222 }
223 LLVM_DEBUG(
224 dbgs() << " Re-computed interval for register "
225 << printReg(UnrematReg, &TRI,
226 UpdateReg.DefMI->getOperand(MOIdx).getSubReg(),
227 &MRI)
228 << '\n');
229 }
230 }
231 LISUpdates.clear();
232}
233
234bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO,
235 ArrayRef<SlotIndex> Uses) const {
236 if (Uses.empty())
237 return true;
238 Register Reg = MO.getReg();
239 unsigned SubIdx = MO.getSubReg();
240 LaneBitmask Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
241 : MRI.getMaxLaneMaskForVReg(Reg);
242 const LiveInterval &LI = LIS.getInterval(Reg);
243 const VNInfo *DefVN =
244 LI.getVNInfoAt(Idx: LIS.getInstructionIndex(Instr: *MO.getParent()).getRegSlot(EC: true));
245 for (SlotIndex Use : Uses) {
246 if (!isIdenticalAtUse(OVNI: *DefVN, Mask, UseIdx: Use, LI))
247 return false;
248 }
249 return true;
250}
251
252RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx,
253 unsigned Region,
254 SlotIndex Before) const {
255 auto It = Rematerializations.find(Val: getOriginOrSelf(RegIdx));
256 if (It == Rematerializations.end())
257 return NoReg;
258 const RematsOf &Remats = It->getSecond();
259
260 SlotIndex BestSlot;
261 RegisterIdx BestRegIdx = NoReg;
262 for (RegisterIdx RematRegIdx : Remats) {
263 const Reg &RematReg = getReg(RegIdx: RematRegIdx);
264 if (RematReg.DefRegion != Region || RematReg.Uses.empty())
265 continue;
266 SlotIndex RematRegSlot =
267 LIS.getInstructionIndex(Instr: *RematReg.DefMI).getRegSlot();
268 if (RematRegSlot < Before &&
269 (BestRegIdx == NoReg || RematRegSlot > BestSlot)) {
270 BestSlot = RematRegSlot;
271 BestRegIdx = RematRegIdx;
272 }
273 }
274 return BestRegIdx;
275}
276
277void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) {
278 if (!getReg(RegIdx: RootIdx).Uses.empty())
279 return;
280
281 // Traverse the root's dependency DAG depth-first to find the set of registers
282 // we can delete and a legal order to delete them in.
283 SmallVector<RegisterIdx, 4> DepDAG{RootIdx};
284 SmallSetVector<RegisterIdx, 8> DeleteOrder;
285 DeleteOrder.insert(X: RootIdx);
286 do {
287 // A deleted register's dependencies may be deletable too.
288 const Reg &DeleteReg = getReg(RegIdx: DepDAG.pop_back_val());
289 for (const Reg::Dependency &Dep : DeleteReg.Dependencies) {
290 // All dependencies loose a user (the deleted register).
291 Reg &DepReg = Regs[Dep.RegIdx];
292 DepReg.eraseUser(MI: DeleteReg.DefMI, Region: DeleteReg.DefRegion);
293 if (DepReg.Uses.empty()) {
294 DeleteOrder.insert(X: Dep.RegIdx);
295 DepDAG.push_back(Elt: Dep.RegIdx);
296 }
297 }
298 } while (!DepDAG.empty());
299
300 for (RegisterIdx RegIdx : reverse(C&: DeleteOrder)) {
301 Reg &DeleteReg = Regs[RegIdx];
302
303 // It is possible that the defined register we are deleting doesn't have an
304 // interval yet if the LIS hasn't been updated since it was created.
305 Register DefReg = DeleteReg.getDefReg();
306 if (LIS.hasInterval(Reg: DefReg))
307 LIS.removeInterval(Reg: DefReg);
308 LISUpdates.erase(V: RegIdx);
309
310 deleteReg(RegIdx);
311 if (isRematerializedRegister(RegIdx)) {
312 // Delete rematerialized register from its origin's rematerializations.
313 const RegisterIdx OriginIdx = getOriginOf(RematRegIdx: RegIdx);
314 RematsOf &OriginRemats = Rematerializations.at(Val: OriginIdx);
315 assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link");
316 OriginRemats.erase(V: RegIdx);
317 if (OriginRemats.empty())
318 Rematerializations.erase(Val: OriginIdx);
319 }
320 LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n");
321 }
322}
323
324void Rematerializer::deleteReg(RegisterIdx RegIdx) {
325 noteRegDeleted(RegIdx);
326
327 Reg &DeleteReg = Regs[RegIdx];
328 assert(DeleteReg.DefMI && "register was already deleted");
329 // It is not possible for the deleted instruction to be the upper region
330 // boundary since we don't ever consider them rematerializable.
331 MachineBasicBlock::iterator &RegionBegin = Regions[DeleteReg.DefRegion].first;
332 if (RegionBegin == DeleteReg.DefMI)
333 RegionBegin = std::next(x: MachineBasicBlock::iterator(DeleteReg.DefMI));
334 LIS.RemoveMachineInstrFromMaps(MI&: *DeleteReg.DefMI);
335 DeleteReg.DefMI->eraseFromParent();
336 DeleteReg.DefMI = nullptr;
337}
338
339Rematerializer::Rematerializer(MachineFunction &MF,
340 SmallVectorImpl<RegionBoundaries> &Regions,
341 LiveIntervals &LIS)
342 : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS),
343 TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) {
344#ifdef EXPENSIVE_CHECKS
345 // Check that regions are valid.
346 DenseSet<MachineInstr *> SeenMIs;
347 for (const auto &[RegionBegin, RegionEnd] : Regions) {
348 assert(RegionBegin != RegionEnd && "empty region");
349 for (auto MI = RegionBegin; MI != RegionEnd; ++MI) {
350 bool IsNewMI = SeenMIs.insert(&*MI).second;
351 assert(IsNewMI && "overlapping regions");
352 assert(!MI->isTerminator() && "terminator in region");
353 }
354 if (RegionEnd != RegionBegin->getParent()->end()) {
355 bool IsNewMI = SeenMIs.insert(&*RegionEnd).second;
356 assert(IsNewMI && "overlapping regions (upper bound)");
357 }
358 }
359#endif
360}
361
362bool Rematerializer::analyze() {
363 Regs.clear();
364 UnrematableOprds.clear();
365 Origins.clear();
366 Rematerializations.clear();
367 RegionMBB.clear();
368 RegToIdx.clear();
369 LISUpdates.clear();
370 if (Regions.empty())
371 return false;
372
373 /// Maps all MIs to their parent region. Region terminators are considered
374 /// part of the region they terminate.
375 DenseMap<MachineInstr *, unsigned> MIRegion;
376
377 // Initialize MI to containing region mapping.
378 RegionMBB.reserve(N: Regions.size());
379 for (unsigned I = 0, E = Regions.size(); I < E; ++I) {
380 RegionBoundaries Region = Regions[I];
381 assert(Region.first != Region.second && "empty cannot be region");
382 for (auto MI = Region.first; MI != Region.second; ++MI) {
383 assert(!MIRegion.contains(&*MI) && "regions should not intersect");
384 MIRegion.insert(KV: {&*MI, I});
385 }
386 MachineBasicBlock &MBB = *Region.first->getParent();
387 RegionMBB.push_back(Elt: &MBB);
388
389 // A terminator instruction is considered part of the region it terminates.
390 if (Region.second != MBB.end()) {
391 MachineInstr *RegionTerm = &*Region.second;
392 assert(!MIRegion.contains(RegionTerm) && "regions should not intersect");
393 MIRegion.insert(KV: {RegionTerm, I});
394 }
395 }
396
397 const unsigned NumVirtRegs = MRI.getNumVirtRegs();
398 BitVector SeenRegs(NumVirtRegs);
399 for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) {
400 if (!SeenRegs[I])
401 addRegIfRematerializable(VirtRegIdx: I, MIRegion, SeenRegs);
402 }
403 assert(Regs.size() == UnrematableOprds.size());
404
405 LLVM_DEBUG({
406 for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I)
407 dbgs() << printDependencyDAG(I) << '\n';
408 });
409 return !Regs.empty();
410}
411
412void Rematerializer::addRegIfRematerializable(
413 unsigned VirtRegIdx, const DenseMap<MachineInstr *, unsigned> &MIRegion,
414 BitVector &SeenRegs) {
415 assert(!SeenRegs[VirtRegIdx] && "register already seen");
416 Register DefReg = Register::index2VirtReg(Index: VirtRegIdx);
417 SeenRegs.set(VirtRegIdx);
418
419 MachineOperand *MO = MRI.getOneDef(Reg: DefReg);
420 if (!MO)
421 return;
422 MachineInstr &DefMI = *MO->getParent();
423 if (!isMIRematerializable(MI: DefMI))
424 return;
425 auto DefRegion = MIRegion.find(Val: &DefMI);
426 if (DefRegion == MIRegion.end())
427 return;
428
429 Reg RematReg;
430 RematReg.DefMI = &DefMI;
431 RematReg.DefRegion = DefRegion->second;
432 unsigned SubIdx = DefMI.getOperand(i: 0).getSubReg();
433 RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx)
434 : MRI.getMaxLaneMaskForVReg(Reg: DefReg);
435
436 // Collect the candidate's direct users, both rematerializable and
437 // unrematerializable. MIs outside provided regions cannot be tracked so the
438 // registers they use are not safely rematerializable.
439 for (MachineInstr &UseMI : MRI.use_nodbg_instructions(Reg: DefReg)) {
440 if (auto UseRegion = MIRegion.find(Val: &UseMI); UseRegion != MIRegion.end())
441 RematReg.addUser(MI: &UseMI, Region: UseRegion->second);
442 else
443 return;
444 }
445 if (RematReg.Uses.empty())
446 return;
447
448 // Collect the candidate's dependencies. If the same register is used
449 // multiple times we just need to consider it once.
450 SmallDenseSet<Register, 4> AllDepRegs;
451 SmallVector<unsigned, 2> UnrematDeps;
452 for (const auto &[MOIdx, MO] : enumerate(First: RematReg.DefMI->operands())) {
453 Register DepReg = getRegDependency(MO);
454 if (!DepReg || !AllDepRegs.insert(V: DepReg).second)
455 continue;
456 unsigned DepRegIdx = DepReg.virtRegIndex();
457 if (!SeenRegs[DepRegIdx])
458 addRegIfRematerializable(VirtRegIdx: DepRegIdx, MIRegion, SeenRegs);
459 if (auto DepIt = RegToIdx.find(Val: DepReg); DepIt != RegToIdx.end())
460 RematReg.Dependencies.push_back(Elt: Reg::Dependency(MOIdx, DepIt->second));
461 else
462 UnrematDeps.push_back(Elt: MOIdx);
463 }
464
465 // The register is rematerializable.
466 RegToIdx.insert(KV: {DefReg, Regs.size()});
467 Regs.push_back(Elt: RematReg);
468 UnrematableOprds.push_back(Elt: UnrematDeps);
469}
470
471bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const {
472 if (!TII.isReMaterializable(MI))
473 return false;
474
475 assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual");
476 assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def");
477
478 for (const MachineOperand &MO : MI.all_uses()) {
479 // We can't remat physreg uses, unless it is a constant or an ignorable
480 // use (e.g. implicit exec use on VALU instructions)
481 if (MO.getReg().isPhysical()) {
482 if (MRI.isConstantPhysReg(PhysReg: MO.getReg()) || TII.isIgnorableUse(MO))
483 continue;
484 return false;
485 }
486 }
487
488 return true;
489}
490
491RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const {
492 if (!MI.getNumOperands() || !MI.getOperand(i: 0).isReg() ||
493 MI.getOperand(i: 0).readsReg())
494 return NoReg;
495 Register Reg = MI.getOperand(i: 0).getReg();
496 auto UserRegIt = RegToIdx.find(Val: Reg);
497 if (UserRegIt == RegToIdx.end())
498 return NoReg;
499 return UserRegIt->second;
500}
501
502RegisterIdx Rematerializer::rematerializeReg(
503 RegisterIdx RegIdx, unsigned UseRegion,
504 MachineBasicBlock::iterator InsertPos,
505 SmallVectorImpl<Reg::Dependency> &&Dependencies) {
506 RegisterIdx NewRegIdx = Regs.size();
507
508 Reg &NewReg = Regs.emplace_back();
509 Reg &FromReg = Regs[RegIdx];
510 NewReg.Mask = FromReg.Mask;
511 NewReg.DefRegion = UseRegion;
512 NewReg.Dependencies = std::move(Dependencies);
513
514 // Track rematerialization link between registers. Origins are always
515 // registers that existed originally, and rematerializations are always
516 // attached to them.
517 const RegisterIdx OriginIdx = getOriginOrSelf(RegIdx);
518 Origins.push_back(Elt: OriginIdx);
519 Rematerializations[OriginIdx].insert(V: NewRegIdx);
520
521 // Use the TII to rematerialize the defining instruction with a new defined
522 // register.
523 Register NewDefReg = MRI.cloneVirtualRegister(VReg: FromReg.getDefReg());
524 TII.reMaterialize(MBB&: *RegionMBB[UseRegion], MI: InsertPos, DestReg: NewDefReg, SubIdx: 0,
525 Orig: *FromReg.DefMI);
526 NewReg.DefMI = &*std::prev(x: InsertPos);
527 RegToIdx.insert(KV: {NewDefReg, NewRegIdx});
528 postRematerialization(ModelRegIdx: RegIdx, RematRegIdx: NewRegIdx, InsertPos);
529
530 noteRegCreated(RegIdx: NewRegIdx);
531 LLVM_DEBUG(dbgs() << "** Rematerialized " << printID(RegIdx) << " as "
532 << printRematReg(NewRegIdx) << '\n');
533 return NewRegIdx;
534}
535
536void Rematerializer::recreateReg(
537 RegisterIdx RegIdx, unsigned DefRegion,
538 MachineBasicBlock::iterator InsertPos, Register DefReg,
539 SmallVectorImpl<Reg::Dependency> &&Dependencies) {
540 assert(RegToIdx.contains(DefReg) && "unknown defined register");
541 assert(RegToIdx.at(DefReg) == RegIdx && "incorrect defined register");
542 assert(!getReg(RegIdx).DefMI && "register is still alive");
543
544 Reg &OriginReg = Regs[RegIdx];
545 OriginReg.DefRegion = DefRegion;
546 OriginReg.Dependencies = std::move(Dependencies);
547
548 // Re-establish the link between origin and rematerialization if necessary.
549 const bool RecreateOriginalReg = isOriginalRegister(RegIdx);
550 if (!RecreateOriginalReg)
551 Rematerializations[getOriginOf(RematRegIdx: RegIdx)].insert(V: RegIdx);
552
553 // Rematerialize from one of the existing rematerializations or from the
554 // origin. We expect at least one to exist, otherwise it would mean the value
555 // held by the original register is no longer available anywhere in the MF.
556 RegisterIdx ModelRegIdx;
557 if (RecreateOriginalReg) {
558 assert(Rematerializations.contains(RegIdx) && "expected remats");
559 ModelRegIdx = *Rematerializations.at(Val: RegIdx).begin();
560 } else {
561 assert(getReg(getOriginOf(RegIdx)).DefMI && "expected alive origin");
562 ModelRegIdx = getOriginOf(RematRegIdx: RegIdx);
563 }
564 const MachineInstr &ModelDefMI = *getReg(RegIdx: ModelRegIdx).DefMI;
565
566 TII.reMaterialize(MBB&: *RegionMBB[DefRegion], MI: InsertPos, DestReg: DefReg, SubIdx: 0, Orig: ModelDefMI);
567 OriginReg.DefMI = &*std::prev(x: InsertPos);
568 postRematerialization(ModelRegIdx, RematRegIdx: RegIdx, InsertPos);
569 LLVM_DEBUG(dbgs() << "** Recreated " << printID(RegIdx) << " as "
570 << printRematReg(RegIdx) << '\n');
571}
572
573void Rematerializer::postRematerialization(
574 RegisterIdx ModelRegIdx, RegisterIdx RematRegIdx,
575 MachineBasicBlock::iterator InsertPos) {
576
577 // The start of the new register's region may have changed.
578 Reg &ModelReg = Regs[ModelRegIdx], &RematReg = Regs[RematRegIdx];
579 LIS.InsertMachineInstrInMaps(MI&: *RematReg.DefMI);
580 MachineBasicBlock::iterator &RegionBegin = Regions[RematReg.DefRegion].first;
581 if (RegionBegin == std::next(x: MachineBasicBlock::iterator(RematReg.DefMI)))
582 RegionBegin = RematReg.DefMI;
583
584 // Replace dependencies as needed in the rematerialized MI. All dependencies
585 // of the latter gain a new user.
586 auto ZipedDeps = zip_equal(t&: ModelReg.Dependencies, u&: RematReg.Dependencies);
587 for (const auto &[OldDep, NewDep] : ZipedDeps) {
588 assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch");
589 LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": "
590 << printID(OldDep.RegIdx) << " -> "
591 << printID(NewDep.RegIdx) << '\n');
592
593 Reg &NewDepReg = Regs[NewDep.RegIdx];
594 if (OldDep.RegIdx != NewDep.RegIdx) {
595 Register OldDefReg = ModelReg.DefMI->getOperand(i: OldDep.MOIdx).getReg();
596 RematReg.DefMI->substituteRegister(FromReg: OldDefReg, ToReg: NewDepReg.getDefReg(), SubIdx: 0,
597 RegInfo: TRI);
598 LISUpdates.insert(V: OldDep.RegIdx);
599 }
600 NewDepReg.addUser(MI: RematReg.DefMI, Region: RematReg.DefRegion);
601 LISUpdates.insert(V: NewDep.RegIdx);
602 }
603}
604
605std::pair<MachineInstr *, MachineInstr *>
606Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion,
607 const LiveIntervals &LIS) const {
608 auto It = Uses.find(Val: UseRegion);
609 if (It == Uses.end())
610 return {nullptr, nullptr};
611 const RegionUsers &RegionUsers = It->getSecond();
612 assert(!RegionUsers.empty() && "empty userset in region");
613
614 auto User = RegionUsers.begin(), UserEnd = RegionUsers.end();
615 MachineInstr *FirstMI = *User, *LastMI = FirstMI;
616 SlotIndex FirstIndex = LIS.getInstructionIndex(Instr: *FirstMI),
617 LastIndex = FirstIndex;
618
619 while (++User != UserEnd) {
620 SlotIndex UserIndex = LIS.getInstructionIndex(Instr: **User);
621 if (UserIndex < FirstIndex) {
622 FirstIndex = UserIndex;
623 FirstMI = *User;
624 } else if (UserIndex > LastIndex) {
625 LastIndex = UserIndex;
626 LastMI = *User;
627 }
628 }
629
630 return {FirstMI, LastMI};
631}
632
633void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) {
634 Uses[Region].insert(V: MI);
635}
636
637void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers,
638 unsigned Region) {
639 Uses[Region].insert_range(R: NewUsers);
640}
641
642void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) {
643 RegionUsers &RUsers = Uses.at(Val: Region);
644 assert(RUsers.contains(MI) && "user not in region");
645 if (RUsers.size() == 1)
646 Uses.erase(Val: Region);
647 else
648 RUsers.erase(V: MI);
649}
650
651Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const {
652 return Printable([&, RootIdx](raw_ostream &OS) {
653 DenseMap<RegisterIdx, unsigned> RegDepths;
654 std::function<void(RegisterIdx, unsigned)> WalkTree =
655 [&](RegisterIdx RegIdx, unsigned Depth) -> void {
656 unsigned MaxDepth = std::max(a: RegDepths.lookup_or(Val: RegIdx, Default&: Depth), b: Depth);
657 RegDepths.emplace_or_assign(Key: RegIdx, Args&: MaxDepth);
658 for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies)
659 WalkTree(Dep.RegIdx, Depth + 1);
660 };
661 WalkTree(RootIdx, 0);
662
663 // Sort in decreasing depth order to print root at the bottom.
664 SmallVector<std::pair<RegisterIdx, unsigned>> Regs(RegDepths.begin(),
665 RegDepths.end());
666 sort(C&: Regs, Comp: [](const auto &LHS, const auto &RHS) {
667 return LHS.second > RHS.second;
668 });
669
670 OS << printID(RegIdx: RootIdx) << " has " << Regs.size() - 1 << " dependencies\n";
671 for (const auto &[RegIdx, Depth] : Regs) {
672 OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' '
673 << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n';
674 }
675 OS << printRegUsers(RegIdx: RootIdx);
676 });
677}
678
679Printable Rematerializer::printID(RegisterIdx RegIdx) const {
680 return Printable([&, RegIdx](raw_ostream &OS) {
681 const Reg &PrintReg = getReg(RegIdx);
682 OS << '(' << RegIdx << '/';
683 if (!PrintReg.DefMI) {
684 OS << "<dead>";
685 } else {
686 OS << printReg(Reg: PrintReg.getDefReg(), TRI: &TRI,
687 SubIdx: PrintReg.DefMI->getOperand(i: 0).getSubReg(), MRI: &MRI);
688 }
689 OS << ")[" << PrintReg.DefRegion << "]";
690 });
691}
692
693Printable Rematerializer::printRematReg(RegisterIdx RegIdx,
694 bool SkipRegions) const {
695 return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) {
696 const Reg &PrintReg = getReg(RegIdx);
697 if (!SkipRegions) {
698 OS << printID(RegIdx) << " [" << PrintReg.DefRegion;
699 if (!PrintReg.Uses.empty()) {
700 assert(PrintReg.DefMI && "dead register cannot have uses");
701 const LiveInterval &LI = LIS.getInterval(Reg: PrintReg.getDefReg());
702 // First display all regions in which the register is live-through and
703 // not used.
704 bool First = true;
705 for (const auto [I, Bounds] : enumerate(First&: Regions)) {
706 if (Bounds.first == Bounds.second)
707 continue;
708 if (!PrintReg.Uses.contains(Val: I) &&
709 LI.liveAt(index: LIS.getInstructionIndex(Instr: *Bounds.first)) &&
710 LI.liveAt(index: LIS.getInstructionIndex(Instr: *std::prev(x: Bounds.second))
711 .getRegSlot())) {
712 OS << (First ? " - " : ",") << I;
713 First = false;
714 }
715 }
716 OS << (First ? " --> " : " -> ");
717
718 // Then display regions in which the register is used.
719 auto It = PrintReg.Uses.begin();
720 OS << It->first;
721 while (++It != PrintReg.Uses.end())
722 OS << "," << It->first;
723 }
724 OS << "] ";
725 }
726 OS << printID(RegIdx) << ' ';
727 PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
728 /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
729 OS << " @ ";
730 LIS.getInstructionIndex(Instr: *PrintReg.DefMI).print(os&: OS);
731 });
732}
733
734Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const {
735 return Printable([&, RegIdx](raw_ostream &OS) {
736 for (const auto &[UseRegion, Users] : getReg(RegIdx).Uses) {
737 for (MachineInstr *MI : Users)
738 OS << " User " << printUser(MI, UseRegion) << '\n';
739 }
740 });
741}
742
743Printable Rematerializer::printUser(const MachineInstr *MI,
744 std::optional<unsigned> UseRegion) const {
745 return Printable([&, MI, UseRegion](raw_ostream &OS) {
746 RegisterIdx RegIdx = getDefRegIdx(MI: *MI);
747 if (RegIdx != NoReg) {
748 OS << printID(RegIdx);
749 } else {
750 OS << "(-/-)[";
751 if (UseRegion)
752 OS << *UseRegion;
753 else
754 OS << '?';
755 OS << ']';
756 }
757 OS << ' ';
758 MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false,
759 /*SkipDebugLoc=*/false, /*AddNewLine=*/false);
760 OS << " @ ";
761 LIS.getInstructionIndex(Instr: *MI).print(os&: OS);
762 });
763}
764
765Rollbacker::RollbackInfo::RollbackInfo(const Rematerializer &Remater,
766 RegisterIdx RegIdx) {
767 const Rematerializer::Reg &Reg = Remater.getReg(RegIdx);
768 DefReg = Reg.getDefReg();
769 DefRegion = Reg.DefRegion;
770 Dependencies = Reg.Dependencies;
771
772 InsertPos = std::next(x: Reg.DefMI->getIterator());
773 if (InsertPos != Reg.DefMI->getParent()->end())
774 NextRegIdx = Remater.getDefRegIdx(MI: *InsertPos);
775 else
776 NextRegIdx = Rematerializer::NoReg;
777}
778
779void Rollbacker::rematerializerNoteRegCreated(const Rematerializer &Remater,
780 RegisterIdx RegIdx) {
781 if (RollingBack)
782 return;
783 Rematerializations[Remater.getOriginOf(RematRegIdx: RegIdx)].insert(V: RegIdx);
784}
785
786void Rollbacker::rematerializerNoteRegDeleted(const Rematerializer &Remater,
787 RegisterIdx RegIdx) {
788 if (RollingBack || Remater.isRematerializedRegister(RegIdx))
789 return;
790 DeadRegs.try_emplace(Key: RegIdx, Args: Remater, Args&: RegIdx);
791}
792
793void Rollbacker::rollback(Rematerializer &Remater) {
794 RollingBack = true;
795
796 // Re-create deleted registers.
797 for (auto &[RegIdx, Info] : DeadRegs) {
798 assert(!Remater.getReg(RegIdx).isAlive() && "register should be dead");
799
800 // The MI that was originally just after the MI defining the register we
801 // are trying to re-create may have been deleted. In such cases, we can
802 // re-create at that MI's own insert position (and apply the same logic
803 // recursively).
804 MachineBasicBlock::iterator InsertPos = Info.InsertPos;
805 RegisterIdx NextRegIdx = Info.NextRegIdx;
806 while (NextRegIdx != Rematerializer::NoReg) {
807 const auto *NextRegRollback = DeadRegs.find(Key: NextRegIdx);
808 if (NextRegRollback == DeadRegs.end())
809 break;
810 InsertPos = NextRegRollback->second.InsertPos;
811 NextRegIdx = NextRegRollback->second.NextRegIdx;
812 }
813 Remater.recreateReg(RegIdx, DefRegion: Info.DefRegion, InsertPos, DefReg: Info.DefReg,
814 Dependencies: std::move(Info.Dependencies));
815 }
816
817 // Rollback rematerializations.
818 for (const auto &[RegIdx, RematsOf] : Rematerializations) {
819 for (RegisterIdx RematRegIdx : RematsOf) {
820 // It is possible that rematerializations were deleted. Their users would
821 // have been transfered to some other rematerialization so we can safely
822 // ignore them. Original registers that were deleted were just re-created
823 // so we do not need to check for that.
824 if (Remater.getReg(RegIdx: RematRegIdx).isAlive())
825 Remater.transferAllUsers(FromRegIdx: RematRegIdx, ToRegIdx: RegIdx);
826 }
827 }
828
829 Remater.updateLiveIntervals();
830 DeadRegs.clear();
831 Rematerializations.clear();
832 RollingBack = false;
833}
834