| 1 | //=====-- Rematerializer.cpp - MIR rematerialization support ----*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //==-----------------------------------------------------------------------===// |
| 8 | // |
| 9 | /// \file |
| 10 | /// Implements helpers for target-independent rematerialization at the MIR |
| 11 | /// level. |
| 12 | // |
| 13 | //===----------------------------------------------------------------------===// |
| 14 | |
| 15 | #include "llvm/CodeGen/Rematerializer.h" |
| 16 | #include "llvm/ADT/MapVector.h" |
| 17 | #include "llvm/ADT/STLExtras.h" |
| 18 | #include "llvm/ADT/SetVector.h" |
| 19 | #include "llvm/CodeGen/LiveIntervals.h" |
| 20 | #include "llvm/CodeGen/MachineBasicBlock.h" |
| 21 | #include "llvm/CodeGen/MachineOperand.h" |
| 22 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
| 23 | #include "llvm/CodeGen/Register.h" |
| 24 | #include "llvm/CodeGen/TargetOpcodes.h" |
| 25 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
| 26 | #include "llvm/Support/Debug.h" |
| 27 | #include <optional> |
| 28 | |
| 29 | #define DEBUG_TYPE "rematerializer" |
| 30 | |
| 31 | using namespace llvm; |
| 32 | using RegisterIdx = Rematerializer::RegisterIdx; |
| 33 | |
| 34 | // Pin the vtable to this file. |
| 35 | void Rematerializer::Listener::anchor() {} |
| 36 | |
| 37 | /// Checks whether the value in \p LI at \p UseIdx is identical to \p OVNI (this |
| 38 | /// implies it is also live there). When \p LI has sub-ranges, checks that |
| 39 | /// all sub-ranges intersecting with \p Mask are also live at \p UseIdx. |
| 40 | static bool isIdenticalAtUse(const VNInfo &OVNI, LaneBitmask Mask, |
| 41 | SlotIndex UseIdx, const LiveInterval &LI) { |
| 42 | if (&OVNI != LI.getVNInfoAt(Idx: UseIdx)) |
| 43 | return false; |
| 44 | |
| 45 | if (LI.hasSubRanges()) { |
| 46 | // Check that intersecting subranges are live at user. |
| 47 | for (const LiveInterval::SubRange &SR : LI.subranges()) { |
| 48 | if ((SR.LaneMask & Mask).none()) |
| 49 | continue; |
| 50 | if (!SR.liveAt(index: UseIdx)) |
| 51 | return false; |
| 52 | |
| 53 | // Early exit if all used lanes are checked. No need to continue. |
| 54 | Mask &= ~SR.LaneMask; |
| 55 | if (Mask.none()) |
| 56 | break; |
| 57 | } |
| 58 | } |
| 59 | return true; |
| 60 | } |
| 61 | |
| 62 | /// If \p MO is a virtual read register, returns it. Otherwise returns the |
| 63 | /// sentinel register. |
| 64 | static Register getRegDependency(const MachineOperand &MO) { |
| 65 | if (!MO.isReg() || !MO.readsReg()) |
| 66 | return Register(); |
| 67 | Register Reg = MO.getReg(); |
| 68 | if (Reg.isPhysical()) { |
| 69 | // By the requirements on trivially rematerializable instructions, a |
| 70 | // physical register use is either constant or ignorable. |
| 71 | return Register(); |
| 72 | } |
| 73 | return Reg; |
| 74 | } |
| 75 | |
| 76 | RegisterIdx Rematerializer::rematerializeToRegion(RegisterIdx RootIdx, |
| 77 | unsigned UseRegion, |
| 78 | DependencyReuseInfo &DRI) { |
| 79 | MachineInstr *FirstMI = |
| 80 | getReg(RegIdx: RootIdx).getRegionUseBounds(UseRegion, LIS).first; |
| 81 | // If there are no users in the region, rematerialize the register at the very |
| 82 | // end of the region. |
| 83 | MachineBasicBlock::iterator InsertPos = |
| 84 | FirstMI ? FirstMI : Regions[UseRegion].second; |
| 85 | RegisterIdx NewRegIdx = |
| 86 | rematerializeToPos(RootIdx, UseRegion, InsertPos, DRI); |
| 87 | transferRegionUsers(FromRegIdx: RootIdx, ToRegIdx: NewRegIdx, UseRegion); |
| 88 | return NewRegIdx; |
| 89 | } |
| 90 | |
| 91 | RegisterIdx |
| 92 | Rematerializer::rematerializeToPos(RegisterIdx RootIdx, unsigned UseRegion, |
| 93 | MachineBasicBlock::iterator InsertPos, |
| 94 | DependencyReuseInfo &DRI) { |
| 95 | assert(!DRI.DependencyMap.contains(RootIdx)); |
| 96 | LLVM_DEBUG(dbgs() << "Rematerializing " << printID(RootIdx) << '\n'); |
| 97 | |
| 98 | // Traverse the root's dependency DAG depth-first to find the set of |
| 99 | // registers we must rematerialize along with it and a legal order to |
| 100 | // rematerialize them in. |
| 101 | SmallVector<RegisterIdx, 4> DepDAG{RootIdx}; |
| 102 | SmallSetVector<RegisterIdx, 8> RematOrder; |
| 103 | RematOrder.insert(X: RootIdx); |
| 104 | do { |
| 105 | RegisterIdx RegIdx = DepDAG.pop_back_val(); |
| 106 | for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies) { |
| 107 | // The dependency may already have a rematerialization ready to use. |
| 108 | if (DRI.DependencyMap.contains(Val: Dep.RegIdx)) |
| 109 | continue; |
| 110 | // We may have already seen the dependency in the dependency DAG. |
| 111 | if (RematOrder.contains(key: Dep.RegIdx)) |
| 112 | continue; |
| 113 | DepDAG.push_back(Elt: Dep.RegIdx); |
| 114 | RematOrder.insert(X: Dep.RegIdx); |
| 115 | } |
| 116 | } while (!DepDAG.empty()); |
| 117 | |
| 118 | // Rematerialize all necessary registers in the root's dependency DAG. At each |
| 119 | // rematerialization, dependencies should already be available. |
| 120 | RegisterIdx LastNewIdx; |
| 121 | for (RegisterIdx RegIdx : reverse(C&: RematOrder)) { |
| 122 | assert(!DRI.DependencyMap.contains(RegIdx) && "useless remat" ); |
| 123 | SmallVector<Reg::Dependency, 2> Dependencies; |
| 124 | for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies) |
| 125 | Dependencies.emplace_back(Args: Dep.MOIdx, Args&: DRI.DependencyMap.at(Val: Dep.RegIdx)); |
| 126 | LastNewIdx = |
| 127 | rematerializeReg(RegIdx, UseRegion, InsertPos, Dependencies: std::move(Dependencies)); |
| 128 | DRI.DependencyMap.insert(KV: {RegIdx, LastNewIdx}); |
| 129 | } |
| 130 | |
| 131 | return LastNewIdx; |
| 132 | } |
| 133 | |
| 134 | void Rematerializer::rollbackRematsOf(RegisterIdx RootIdx) { |
| 135 | auto Remats = Rematerializations.find(Val: RootIdx); |
| 136 | if (Remats == Rematerializations.end()) |
| 137 | return; |
| 138 | |
| 139 | LLVM_DEBUG(dbgs() << "Rolling back rematerializations of " << printID(RootIdx) |
| 140 | << '\n'); |
| 141 | |
| 142 | reviveRegIfDead(RootIdx); |
| 143 | // All of the rematerialization's users must use the revived register. |
| 144 | for (RegisterIdx RematRegIdx : Remats->getSecond()) { |
| 145 | for (const auto &[UseRegion, RegionUsers] : Regs[RematRegIdx].Uses) |
| 146 | transferRegionUsers(FromRegIdx: RematRegIdx, ToRegIdx: RootIdx, UseRegion); |
| 147 | } |
| 148 | Rematerializations.erase(Val: RootIdx); |
| 149 | |
| 150 | LLVM_DEBUG(dbgs() << "** Rolled back rematerializations of " |
| 151 | << printID(RootIdx) << '\n'); |
| 152 | } |
| 153 | |
| 154 | void Rematerializer::rollback(RegisterIdx RematIdx) { |
| 155 | assert(getReg(RematIdx).DefMI && !Revivable.contains(RematIdx) && |
| 156 | "cannot rollback dead register" ); |
| 157 | const RegisterIdx OriginRegIdx = getOriginOf(RematRegIdx: RematIdx); |
| 158 | reviveRegIfDead(RootIdx: OriginRegIdx); |
| 159 | for (const auto &[UseRegion, RegionUsers] : Regs[RematIdx].Uses) |
| 160 | transferRegionUsers(FromRegIdx: RematIdx, ToRegIdx: OriginRegIdx, UseRegion); |
| 161 | } |
| 162 | |
| 163 | void Rematerializer::reviveRegIfDead(RegisterIdx RootIdx) { |
| 164 | if (getReg(RegIdx: RootIdx).isAlive()) |
| 165 | return; |
| 166 | assert(Revivable.contains(RootIdx) && "not revivable" ); |
| 167 | |
| 168 | // Traverse the root's dependency DAG depth-first to find the set of |
| 169 | // registers we must revive and a legal order to revive them in. |
| 170 | SmallVector<RegisterIdx, 4> DepDAG{RootIdx}; |
| 171 | SmallSetVector<RegisterIdx, 8> ReviveOrder; |
| 172 | ReviveOrder.insert(X: RootIdx); |
| 173 | do { |
| 174 | // All dependencies of a revived register need to be alive too. |
| 175 | const Reg &ReviveReg = getReg(RegIdx: DepDAG.pop_back_val()); |
| 176 | for (const Reg::Dependency &Dep : ReviveReg.Dependencies) { |
| 177 | // We may have already seen the dependency in the dependency DAG. |
| 178 | if (ReviveOrder.contains(key: Dep.RegIdx)) |
| 179 | continue; |
| 180 | |
| 181 | // Dead dependencies need to be revived. |
| 182 | Reg &DepReg = Regs[Dep.RegIdx]; |
| 183 | if (!DepReg.isAlive()) { |
| 184 | assert(Revivable.contains(Dep.RegIdx) && "not revivable" ); |
| 185 | ReviveOrder.insert(X: Dep.RegIdx); |
| 186 | DepDAG.push_back(Elt: Dep.RegIdx); |
| 187 | } |
| 188 | |
| 189 | // All dependencies get a new user (the revived register). |
| 190 | DepReg.addUser(MI: ReviveReg.DefMI, Region: ReviveReg.DefRegion); |
| 191 | LISUpdates.insert(V: Dep.RegIdx); |
| 192 | } |
| 193 | } while (!DepDAG.empty()); |
| 194 | |
| 195 | for (RegisterIdx RegIdx : reverse(C&: ReviveOrder)) { |
| 196 | // Pick any rematerialization to retrieve the original opcode from. |
| 197 | Reg &ReviveReg = Regs[RegIdx]; |
| 198 | assert(Rematerializations.contains(RegIdx) && "no remats" ); |
| 199 | RegisterIdx RematIdx = *Rematerializations.at(Val: RegIdx).begin(); |
| 200 | ReviveReg.DefMI->setDesc(getReg(RegIdx: RematIdx).DefMI->getDesc()); |
| 201 | for (const auto &[MOIdx, Reg] : Revivable.at(Val: RegIdx)) |
| 202 | ReviveReg.DefMI->getOperand(i: MOIdx).setReg(Reg); |
| 203 | Revivable.erase(Val: RegIdx); |
| 204 | LISUpdates.insert(V: RegIdx); |
| 205 | |
| 206 | LLVM_DEBUG({ |
| 207 | dbgs() << "** Revived " << printID(RegIdx) << " @ " ; |
| 208 | LIS.getInstructionIndex(*ReviveReg.DefMI).print(dbgs()); |
| 209 | dbgs() << '\n'; |
| 210 | }); |
| 211 | } |
| 212 | } |
| 213 | |
| 214 | void Rematerializer::transferUser(RegisterIdx FromRegIdx, RegisterIdx ToRegIdx, |
| 215 | unsigned UserRegion, MachineInstr &UserMI) { |
| 216 | transferUserImpl(FromRegIdx, ToRegIdx, UserMI); |
| 217 | Regs[FromRegIdx].eraseUser(MI: &UserMI, Region: UserRegion); |
| 218 | Regs[ToRegIdx].addUser(MI: &UserMI, Region: UserRegion); |
| 219 | deleteRegIfUnused(RootIdx: FromRegIdx); |
| 220 | } |
| 221 | |
| 222 | void Rematerializer::transferRegionUsers(RegisterIdx FromRegIdx, |
| 223 | RegisterIdx ToRegIdx, |
| 224 | unsigned UseRegion) { |
| 225 | auto &FromRegUsers = Regs[FromRegIdx].Uses; |
| 226 | auto UsesIt = FromRegUsers.find(Val: UseRegion); |
| 227 | if (UsesIt == FromRegUsers.end()) |
| 228 | return; |
| 229 | |
| 230 | const SmallDenseSet<MachineInstr *, 4> &RegionUsers = UsesIt->getSecond(); |
| 231 | for (MachineInstr *UserMI : RegionUsers) |
| 232 | transferUserImpl(FromRegIdx, ToRegIdx, UserMI&: *UserMI); |
| 233 | Regs[ToRegIdx].addUsers(NewUsers: RegionUsers, Region: UseRegion); |
| 234 | FromRegUsers.erase(Val: UseRegion); |
| 235 | deleteRegIfUnused(RootIdx: FromRegIdx); |
| 236 | } |
| 237 | |
| 238 | void Rematerializer::transferUserImpl(RegisterIdx FromRegIdx, |
| 239 | RegisterIdx ToRegIdx, |
| 240 | MachineInstr &UserMI) { |
| 241 | assert(FromRegIdx != ToRegIdx && "identical registers" ); |
| 242 | assert(getOriginOrSelf(FromRegIdx) == getOriginOrSelf(ToRegIdx) && |
| 243 | "unrelated registers" ); |
| 244 | |
| 245 | LLVM_DEBUG(dbgs() << "User transfer from " << printID(FromRegIdx) << " to " |
| 246 | << printID(ToRegIdx) << ": " << printUser(&UserMI) << '\n'); |
| 247 | |
| 248 | UserMI.substituteRegister(FromReg: getReg(RegIdx: FromRegIdx).getDefReg(), |
| 249 | ToReg: getReg(RegIdx: ToRegIdx).getDefReg(), SubIdx: 0, RegInfo: TRI); |
| 250 | LISUpdates.insert(V: FromRegIdx); |
| 251 | LISUpdates.insert(V: ToRegIdx); |
| 252 | |
| 253 | // If the user is rematerializable, we must change its dependency to the |
| 254 | // new register. |
| 255 | if (RegisterIdx UserRegIdx = getDefRegIdx(MI: UserMI); UserRegIdx != NoReg) { |
| 256 | // Look for the user's dependency that matches the register. |
| 257 | for (Reg::Dependency &Dep : Regs[UserRegIdx].Dependencies) { |
| 258 | if (Dep.RegIdx == FromRegIdx) { |
| 259 | Dep.RegIdx = ToRegIdx; |
| 260 | return; |
| 261 | } |
| 262 | } |
| 263 | llvm_unreachable("broken dependency" ); |
| 264 | } |
| 265 | } |
| 266 | |
| 267 | void Rematerializer::updateLiveIntervals() { |
| 268 | DenseSet<Register> SeenUnrematRegs; |
| 269 | for (RegisterIdx RegIdx : LISUpdates) { |
| 270 | const Reg &UpdateReg = getReg(RegIdx); |
| 271 | assert((UpdateReg.DefMI || Revivable.contains(RegIdx)) && "dead reg" ); |
| 272 | |
| 273 | Register DefReg = UpdateReg.getDefReg(); |
| 274 | if (LIS.hasInterval(Reg: DefReg)) |
| 275 | LIS.removeInterval(Reg: DefReg); |
| 276 | LIS.createAndComputeVirtRegInterval(Reg: DefReg); |
| 277 | |
| 278 | LLVM_DEBUG({ |
| 279 | dbgs() << "Re-computed interval for " << printID(RegIdx) << ": " ; |
| 280 | LIS.getInterval(DefReg).print(dbgs()); |
| 281 | dbgs() << '\n' << printRegUsers(RegIdx); |
| 282 | }); |
| 283 | |
| 284 | // Update intervals for unrematerializable operands. |
| 285 | for (unsigned MOIdx : getUnrematableOprds(RegIdx)) { |
| 286 | Register UnrematReg = UpdateReg.DefMI->getOperand(i: MOIdx).getReg(); |
| 287 | if (!SeenUnrematRegs.insert(V: UnrematReg).second) |
| 288 | continue; |
| 289 | LIS.removeInterval(Reg: UnrematReg); |
| 290 | LIS.createAndComputeVirtRegInterval(Reg: UnrematReg); |
| 291 | LLVM_DEBUG( |
| 292 | dbgs() << " Re-computed interval for register " |
| 293 | << printReg(UnrematReg, &TRI, |
| 294 | UpdateReg.DefMI->getOperand(MOIdx).getSubReg(), |
| 295 | &MRI) |
| 296 | << '\n'); |
| 297 | } |
| 298 | } |
| 299 | LISUpdates.clear(); |
| 300 | } |
| 301 | |
| 302 | void Rematerializer::commitRematerializations() { |
| 303 | for (auto &[RegIdx, _] : Revivable) |
| 304 | deleteReg(RegIdx); |
| 305 | Revivable.clear(); |
| 306 | } |
| 307 | |
| 308 | bool Rematerializer::isMOIdenticalAtUses(MachineOperand &MO, |
| 309 | ArrayRef<SlotIndex> Uses) const { |
| 310 | if (Uses.empty()) |
| 311 | return true; |
| 312 | Register Reg = MO.getReg(); |
| 313 | unsigned SubIdx = MO.getSubReg(); |
| 314 | LaneBitmask Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx) |
| 315 | : MRI.getMaxLaneMaskForVReg(Reg); |
| 316 | const LiveInterval &LI = LIS.getInterval(Reg); |
| 317 | const VNInfo *DefVN = |
| 318 | LI.getVNInfoAt(Idx: LIS.getInstructionIndex(Instr: *MO.getParent()).getRegSlot(EC: true)); |
| 319 | for (SlotIndex Use : Uses) { |
| 320 | if (!isIdenticalAtUse(OVNI: *DefVN, Mask, UseIdx: Use, LI)) |
| 321 | return false; |
| 322 | } |
| 323 | return true; |
| 324 | } |
| 325 | |
| 326 | RegisterIdx Rematerializer::findRematInRegion(RegisterIdx RegIdx, |
| 327 | unsigned Region, |
| 328 | SlotIndex Before) const { |
| 329 | auto It = Rematerializations.find(Val: getOriginOrSelf(RegIdx)); |
| 330 | if (It == Rematerializations.end()) |
| 331 | return NoReg; |
| 332 | const RematsOf &Remats = It->getSecond(); |
| 333 | |
| 334 | SlotIndex BestSlot; |
| 335 | RegisterIdx BestRegIdx = NoReg; |
| 336 | for (RegisterIdx RematRegIdx : Remats) { |
| 337 | const Reg &RematReg = getReg(RegIdx: RematRegIdx); |
| 338 | if (RematReg.DefRegion != Region || RematReg.Uses.empty()) |
| 339 | continue; |
| 340 | SlotIndex RematRegSlot = |
| 341 | LIS.getInstructionIndex(Instr: *RematReg.DefMI).getRegSlot(); |
| 342 | if (RematRegSlot < Before && |
| 343 | (BestRegIdx == NoReg || RematRegSlot > BestSlot)) { |
| 344 | BestSlot = RematRegSlot; |
| 345 | BestRegIdx = RematRegIdx; |
| 346 | } |
| 347 | } |
| 348 | return BestRegIdx; |
| 349 | } |
| 350 | |
| 351 | void Rematerializer::deleteRegIfUnused(RegisterIdx RootIdx) { |
| 352 | if (!getReg(RegIdx: RootIdx).Uses.empty()) |
| 353 | return; |
| 354 | |
| 355 | // Traverse the root's dependency DAG depth-first to find the set of registers |
| 356 | // we can delete and a legal order to delete them in. |
| 357 | SmallVector<RegisterIdx, 4> DepDAG{RootIdx}; |
| 358 | SmallSetVector<RegisterIdx, 8> DeleteOrder; |
| 359 | DeleteOrder.insert(X: RootIdx); |
| 360 | do { |
| 361 | // A deleted register's dependencies may be deletable too. |
| 362 | const Reg &DeleteReg = getReg(RegIdx: DepDAG.pop_back_val()); |
| 363 | for (const Reg::Dependency &Dep : DeleteReg.Dependencies) { |
| 364 | // All dependencies loose a user (the delete register). |
| 365 | Reg &DepReg = Regs[Dep.RegIdx]; |
| 366 | DepReg.eraseUser(MI: DeleteReg.DefMI, Region: DeleteReg.DefRegion); |
| 367 | if (DepReg.Uses.empty()) { |
| 368 | DeleteOrder.insert(X: Dep.RegIdx); |
| 369 | DepDAG.push_back(Elt: Dep.RegIdx); |
| 370 | } |
| 371 | } |
| 372 | } while (!DepDAG.empty()); |
| 373 | |
| 374 | for (RegisterIdx RegIdx : reverse(C&: DeleteOrder)) { |
| 375 | Reg &DeleteReg = Regs[RegIdx]; |
| 376 | LIS.removeInterval(Reg: DeleteReg.getDefReg()); |
| 377 | LISUpdates.erase(V: RegIdx); |
| 378 | const bool IsRematerializedReg = isRematerializedRegister(RegIdx); |
| 379 | if (SupportRollback && !IsRematerializedReg) { |
| 380 | // Replace all read registers with the null one to prevent them from |
| 381 | // showing up in use-lists, which is disallowed for debug instructions in |
| 382 | // live interval calculations. Store mappings between operand indices and |
| 383 | // original registers for potential rollback. |
| 384 | DenseMap<unsigned, Register> &RegMap = |
| 385 | Revivable.try_emplace(Key: RegIdx).first->getSecond(); |
| 386 | for (auto [Idx, MO] : enumerate(First: DeleteReg.DefMI->operands())) { |
| 387 | if (MO.isReg() && MO.readsReg()) { |
| 388 | RegMap.insert(KV: {Idx, MO.getReg()}); |
| 389 | MO.setReg(Register()); |
| 390 | } |
| 391 | } |
| 392 | DeleteReg.DefMI->setDesc(TII.get(Opcode: TargetOpcode::DBG_VALUE)); |
| 393 | } else { |
| 394 | deleteReg(RegIdx); |
| 395 | } |
| 396 | if (IsRematerializedReg) { |
| 397 | // Delete rematerialized register from its origin's rematerializations. |
| 398 | RematsOf &OriginRemats = Rematerializations.at(Val: getOriginOf(RematRegIdx: RegIdx)); |
| 399 | assert(OriginRemats.contains(RegIdx) && "broken remat<->origin link" ); |
| 400 | OriginRemats.erase(V: RegIdx); |
| 401 | if (OriginRemats.empty()) |
| 402 | Rematerializations.erase(Val: RegIdx); |
| 403 | } |
| 404 | LLVM_DEBUG(dbgs() << "** Deleted " << printID(RegIdx) << "\n" ); |
| 405 | } |
| 406 | } |
| 407 | |
| 408 | void Rematerializer::deleteReg(RegisterIdx RegIdx) { |
| 409 | noteRegDeleted(RegIdx); |
| 410 | |
| 411 | Reg &DeleteReg = Regs[RegIdx]; |
| 412 | assert(DeleteReg.DefMI && "register was already deleted" ); |
| 413 | // It is not possible for the deleted instruction to be the upper region |
| 414 | // boundary since we don't ever consider them rematerializable. |
| 415 | MachineBasicBlock::iterator &RegionBegin = Regions[DeleteReg.DefRegion].first; |
| 416 | if (RegionBegin == DeleteReg.DefMI) |
| 417 | RegionBegin = std::next(x: MachineBasicBlock::iterator(DeleteReg.DefMI)); |
| 418 | LIS.RemoveMachineInstrFromMaps(MI&: *DeleteReg.DefMI); |
| 419 | DeleteReg.DefMI->eraseFromParent(); |
| 420 | DeleteReg.DefMI = nullptr; |
| 421 | } |
| 422 | |
| 423 | Rematerializer::Rematerializer(MachineFunction &MF, |
| 424 | SmallVectorImpl<RegionBoundaries> &Regions, |
| 425 | LiveIntervals &LIS) |
| 426 | : Regions(Regions), MRI(MF.getRegInfo()), LIS(LIS), |
| 427 | TII(*MF.getSubtarget().getInstrInfo()), TRI(TII.getRegisterInfo()) { |
| 428 | #ifdef EXPENSIVE_CHECKS |
| 429 | // Check that regions are valid. |
| 430 | DenseSet<MachineInstr *> SeenMIs; |
| 431 | for (const auto &[RegionBegin, RegionEnd] : Regions) { |
| 432 | assert(RegionBegin != RegionEnd && "empty region" ); |
| 433 | for (auto MI = RegionBegin; MI != RegionEnd; ++MI) { |
| 434 | bool IsNewMI = SeenMIs.insert(&*MI).second; |
| 435 | assert(IsNewMI && "overlapping regions" ); |
| 436 | assert(!MI->isTerminator() && "terminator in region" ); |
| 437 | } |
| 438 | if (RegionEnd != RegionBegin->getParent()->end()) { |
| 439 | bool IsNewMI = SeenMIs.insert(&*RegionEnd).second; |
| 440 | assert(IsNewMI && "overlapping regions (upper bound)" ); |
| 441 | } |
| 442 | } |
| 443 | #endif |
| 444 | } |
| 445 | |
| 446 | bool Rematerializer::analyze(bool SupportRollback) { |
| 447 | Regs.clear(); |
| 448 | UnrematableOprds.clear(); |
| 449 | Origins.clear(); |
| 450 | Rematerializations.clear(); |
| 451 | RegionMBB.clear(); |
| 452 | RegToIdx.clear(); |
| 453 | LISUpdates.clear(); |
| 454 | Revivable.clear(); |
| 455 | this->SupportRollback = SupportRollback; |
| 456 | if (Regions.empty()) |
| 457 | return false; |
| 458 | |
| 459 | /// Maps all MIs to their parent region. Region terminators are considered |
| 460 | /// part of the region they terminate. |
| 461 | DenseMap<MachineInstr *, unsigned> MIRegion; |
| 462 | |
| 463 | // Initialize MI to containing region mapping. |
| 464 | RegionMBB.reserve(N: Regions.size()); |
| 465 | for (unsigned I = 0, E = Regions.size(); I < E; ++I) { |
| 466 | RegionBoundaries Region = Regions[I]; |
| 467 | assert(Region.first != Region.second && "empty cannot be region" ); |
| 468 | for (auto MI = Region.first; MI != Region.second; ++MI) { |
| 469 | assert(!MIRegion.contains(&*MI) && "regions should not intersect" ); |
| 470 | MIRegion.insert(KV: {&*MI, I}); |
| 471 | } |
| 472 | MachineBasicBlock &MBB = *Region.first->getParent(); |
| 473 | RegionMBB.push_back(Elt: &MBB); |
| 474 | |
| 475 | // A terminator instruction is considered part of the region it terminates. |
| 476 | if (Region.second != MBB.end()) { |
| 477 | MachineInstr *RegionTerm = &*Region.second; |
| 478 | assert(!MIRegion.contains(RegionTerm) && "regions should not intersect" ); |
| 479 | MIRegion.insert(KV: {RegionTerm, I}); |
| 480 | } |
| 481 | } |
| 482 | |
| 483 | const unsigned NumVirtRegs = MRI.getNumVirtRegs(); |
| 484 | BitVector SeenRegs(NumVirtRegs); |
| 485 | for (unsigned I = 0, E = NumVirtRegs; I != E; ++I) { |
| 486 | if (!SeenRegs[I]) |
| 487 | addRegIfRematerializable(VirtRegIdx: I, MIRegion, SeenRegs); |
| 488 | } |
| 489 | assert(Regs.size() == UnrematableOprds.size()); |
| 490 | |
| 491 | LLVM_DEBUG({ |
| 492 | for (RegisterIdx I = 0, E = getNumRegs(); I < E; ++I) |
| 493 | dbgs() << printDependencyDAG(I) << '\n'; |
| 494 | }); |
| 495 | return !Regs.empty(); |
| 496 | } |
| 497 | |
| 498 | void Rematerializer::addRegIfRematerializable( |
| 499 | unsigned VirtRegIdx, const DenseMap<MachineInstr *, unsigned> &MIRegion, |
| 500 | BitVector &SeenRegs) { |
| 501 | assert(!SeenRegs[VirtRegIdx] && "register already seen" ); |
| 502 | Register DefReg = Register::index2VirtReg(Index: VirtRegIdx); |
| 503 | SeenRegs.set(VirtRegIdx); |
| 504 | |
| 505 | MachineOperand *MO = MRI.getOneDef(Reg: DefReg); |
| 506 | if (!MO) |
| 507 | return; |
| 508 | MachineInstr &DefMI = *MO->getParent(); |
| 509 | if (!isMIRematerializable(MI: DefMI)) |
| 510 | return; |
| 511 | auto DefRegion = MIRegion.find(Val: &DefMI); |
| 512 | if (DefRegion == MIRegion.end()) |
| 513 | return; |
| 514 | |
| 515 | Reg RematReg; |
| 516 | RematReg.DefMI = &DefMI; |
| 517 | RematReg.DefRegion = DefRegion->second; |
| 518 | unsigned SubIdx = DefMI.getOperand(i: 0).getSubReg(); |
| 519 | RematReg.Mask = SubIdx ? TRI.getSubRegIndexLaneMask(SubIdx) |
| 520 | : MRI.getMaxLaneMaskForVReg(Reg: DefReg); |
| 521 | |
| 522 | // Collect the candidate's direct users, both rematerializable and |
| 523 | // unrematerializable. MIs outside provided regions cannot be tracked so the |
| 524 | // registers they use are not safely rematerializable. |
| 525 | for (MachineInstr &UseMI : MRI.use_nodbg_instructions(Reg: DefReg)) { |
| 526 | if (auto UseRegion = MIRegion.find(Val: &UseMI); UseRegion != MIRegion.end()) |
| 527 | RematReg.addUser(MI: &UseMI, Region: UseRegion->second); |
| 528 | else |
| 529 | return; |
| 530 | } |
| 531 | if (RematReg.Uses.empty()) |
| 532 | return; |
| 533 | |
| 534 | // Collect the candidate's dependencies. If the same register is used |
| 535 | // multiple times we just need to consider it once. |
| 536 | SmallDenseSet<Register, 4> AllDepRegs; |
| 537 | SmallVector<unsigned, 2> UnrematDeps; |
| 538 | for (const auto &[MOIdx, MO] : enumerate(First: RematReg.DefMI->operands())) { |
| 539 | Register DepReg = getRegDependency(MO); |
| 540 | if (!DepReg || !AllDepRegs.insert(V: DepReg).second) |
| 541 | continue; |
| 542 | unsigned DepRegIdx = DepReg.virtRegIndex(); |
| 543 | if (!SeenRegs[DepRegIdx]) |
| 544 | addRegIfRematerializable(VirtRegIdx: DepRegIdx, MIRegion, SeenRegs); |
| 545 | if (auto DepIt = RegToIdx.find(Val: DepReg); DepIt != RegToIdx.end()) |
| 546 | RematReg.Dependencies.push_back(Elt: Reg::Dependency(MOIdx, DepIt->second)); |
| 547 | else |
| 548 | UnrematDeps.push_back(Elt: MOIdx); |
| 549 | } |
| 550 | |
| 551 | // The register is rematerializable. |
| 552 | RegToIdx.insert(KV: {DefReg, Regs.size()}); |
| 553 | Regs.push_back(Elt: RematReg); |
| 554 | UnrematableOprds.push_back(Elt: UnrematDeps); |
| 555 | } |
| 556 | |
| 557 | bool Rematerializer::isMIRematerializable(const MachineInstr &MI) const { |
| 558 | if (!TII.isReMaterializable(MI)) |
| 559 | return false; |
| 560 | |
| 561 | assert(MI.getOperand(0).getReg().isVirtual() && "should be virtual" ); |
| 562 | assert(MRI.hasOneDef(MI.getOperand(0).getReg()) && "should have single def" ); |
| 563 | |
| 564 | for (const MachineOperand &MO : MI.all_uses()) { |
| 565 | // We can't remat physreg uses, unless it is a constant or an ignorable |
| 566 | // use (e.g. implicit exec use on VALU instructions) |
| 567 | if (MO.getReg().isPhysical()) { |
| 568 | if (MRI.isConstantPhysReg(PhysReg: MO.getReg()) || TII.isIgnorableUse(MO)) |
| 569 | continue; |
| 570 | return false; |
| 571 | } |
| 572 | } |
| 573 | |
| 574 | return true; |
| 575 | } |
| 576 | |
| 577 | RegisterIdx Rematerializer::getDefRegIdx(const MachineInstr &MI) const { |
| 578 | if (!MI.getNumOperands() || !MI.getOperand(i: 0).isReg() || |
| 579 | MI.getOperand(i: 0).readsReg()) |
| 580 | return NoReg; |
| 581 | Register Reg = MI.getOperand(i: 0).getReg(); |
| 582 | auto UserRegIt = RegToIdx.find(Val: Reg); |
| 583 | if (UserRegIt == RegToIdx.end()) |
| 584 | return NoReg; |
| 585 | return UserRegIt->second; |
| 586 | } |
| 587 | |
| 588 | RegisterIdx Rematerializer::rematerializeReg( |
| 589 | RegisterIdx RegIdx, unsigned UseRegion, |
| 590 | MachineBasicBlock::iterator InsertPos, |
| 591 | SmallVectorImpl<Reg::Dependency> &&Dependencies) { |
| 592 | RegisterIdx NewRegIdx = Regs.size(); |
| 593 | |
| 594 | Reg &NewReg = Regs.emplace_back(); |
| 595 | Reg &FromReg = Regs[RegIdx]; |
| 596 | NewReg.Mask = FromReg.Mask; |
| 597 | NewReg.DefRegion = UseRegion; |
| 598 | NewReg.Dependencies = std::move(Dependencies); |
| 599 | |
| 600 | // Track rematerialization link between registers. Origins are always |
| 601 | // registers that existed originally, and rematerializations are always |
| 602 | // attached to them. |
| 603 | const RegisterIdx OriginIdx = getOriginOrSelf(RegIdx); |
| 604 | Origins.push_back(Elt: OriginIdx); |
| 605 | Rematerializations[OriginIdx].insert(V: NewRegIdx); |
| 606 | |
| 607 | // Use the TII to rematerialize the defining instruction with a new defined |
| 608 | // register. |
| 609 | Register NewDefReg = MRI.cloneVirtualRegister(VReg: FromReg.getDefReg()); |
| 610 | TII.reMaterialize(MBB&: *RegionMBB[UseRegion], MI: InsertPos, DestReg: NewDefReg, SubIdx: 0, |
| 611 | Orig: *FromReg.DefMI); |
| 612 | NewReg.DefMI = &*std::prev(x: InsertPos); |
| 613 | RegToIdx.insert(KV: {NewDefReg, NewRegIdx}); |
| 614 | |
| 615 | // Update the DAG. |
| 616 | RegionBoundaries &Bounds = Regions[UseRegion]; |
| 617 | if (Bounds.first == std::next(x: MachineBasicBlock::iterator(NewReg.DefMI))) |
| 618 | Bounds.first = NewReg.DefMI; |
| 619 | LIS.InsertMachineInstrInMaps(MI&: *NewReg.DefMI); |
| 620 | LISUpdates.insert(V: NewRegIdx); |
| 621 | |
| 622 | // Replace dependencies as needed in the rematerialized MI. All dependencies |
| 623 | // of the latter gain a new user. |
| 624 | auto ZipedDeps = zip_equal(t&: FromReg.Dependencies, u&: NewReg.Dependencies); |
| 625 | for (const auto &[OldDep, NewDep] : ZipedDeps) { |
| 626 | assert(OldDep.MOIdx == NewDep.MOIdx && "operand mismatch" ); |
| 627 | LLVM_DEBUG(dbgs() << " Operand #" << OldDep.MOIdx << ": " |
| 628 | << printID(OldDep.RegIdx) << " -> " |
| 629 | << printID(NewDep.RegIdx) << '\n'); |
| 630 | |
| 631 | Reg &NewDepReg = Regs[NewDep.RegIdx]; |
| 632 | if (OldDep.RegIdx != NewDep.RegIdx) { |
| 633 | Register OldDefReg = FromReg.DefMI->getOperand(i: OldDep.MOIdx).getReg(); |
| 634 | NewReg.DefMI->substituteRegister(FromReg: OldDefReg, ToReg: NewDepReg.getDefReg(), SubIdx: 0, |
| 635 | RegInfo: TRI); |
| 636 | LISUpdates.insert(V: OldDep.RegIdx); |
| 637 | } |
| 638 | NewDepReg.addUser(MI: NewReg.DefMI, Region: UseRegion); |
| 639 | LISUpdates.insert(V: NewDep.RegIdx); |
| 640 | } |
| 641 | |
| 642 | noteRegCreated(RegIdx: NewRegIdx); |
| 643 | |
| 644 | LLVM_DEBUG({ |
| 645 | dbgs() << "** Rematerialized " << printID(RegIdx) << " as " |
| 646 | << printRematReg(NewRegIdx) << '\n'; |
| 647 | }); |
| 648 | return NewRegIdx; |
| 649 | } |
| 650 | |
| 651 | std::pair<MachineInstr *, MachineInstr *> |
| 652 | Rematerializer::Reg::getRegionUseBounds(unsigned UseRegion, |
| 653 | const LiveIntervals &LIS) const { |
| 654 | auto It = Uses.find(Val: UseRegion); |
| 655 | if (It == Uses.end()) |
| 656 | return {nullptr, nullptr}; |
| 657 | const RegionUsers &RegionUsers = It->getSecond(); |
| 658 | assert(!RegionUsers.empty() && "empty userset in region" ); |
| 659 | |
| 660 | auto User = RegionUsers.begin(), UserEnd = RegionUsers.end(); |
| 661 | MachineInstr *FirstMI = *User, *LastMI = FirstMI; |
| 662 | SlotIndex FirstIndex = LIS.getInstructionIndex(Instr: *FirstMI), |
| 663 | LastIndex = FirstIndex; |
| 664 | |
| 665 | while (++User != UserEnd) { |
| 666 | SlotIndex UserIndex = LIS.getInstructionIndex(Instr: **User); |
| 667 | if (UserIndex < FirstIndex) { |
| 668 | FirstIndex = UserIndex; |
| 669 | FirstMI = *User; |
| 670 | } else if (UserIndex > LastIndex) { |
| 671 | LastIndex = UserIndex; |
| 672 | LastMI = *User; |
| 673 | } |
| 674 | } |
| 675 | |
| 676 | return {FirstMI, LastMI}; |
| 677 | } |
| 678 | |
| 679 | void Rematerializer::Reg::addUser(MachineInstr *MI, unsigned Region) { |
| 680 | Uses[Region].insert(V: MI); |
| 681 | } |
| 682 | |
| 683 | void Rematerializer::Reg::addUsers(const RegionUsers &NewUsers, |
| 684 | unsigned Region) { |
| 685 | Uses[Region].insert_range(R: NewUsers); |
| 686 | } |
| 687 | |
| 688 | void Rematerializer::Reg::eraseUser(MachineInstr *MI, unsigned Region) { |
| 689 | RegionUsers &RUsers = Uses.at(Val: Region); |
| 690 | assert(RUsers.contains(MI) && "user not in region" ); |
| 691 | if (RUsers.size() == 1) |
| 692 | Uses.erase(Val: Region); |
| 693 | else |
| 694 | RUsers.erase(V: MI); |
| 695 | } |
| 696 | |
| 697 | Printable Rematerializer::printDependencyDAG(RegisterIdx RootIdx) const { |
| 698 | return Printable([&, RootIdx](raw_ostream &OS) { |
| 699 | DenseMap<RegisterIdx, unsigned> RegDepths; |
| 700 | std::function<void(RegisterIdx, unsigned)> WalkTree = |
| 701 | [&](RegisterIdx RegIdx, unsigned Depth) -> void { |
| 702 | unsigned MaxDepth = std::max(a: RegDepths.lookup_or(Val: RegIdx, Default&: Depth), b: Depth); |
| 703 | RegDepths.emplace_or_assign(Key: RegIdx, Args&: MaxDepth); |
| 704 | for (const Reg::Dependency &Dep : getReg(RegIdx).Dependencies) |
| 705 | WalkTree(Dep.RegIdx, Depth + 1); |
| 706 | }; |
| 707 | WalkTree(RootIdx, 0); |
| 708 | |
| 709 | // Sort in decreasing depth order to print root at the bottom. |
| 710 | SmallVector<std::pair<RegisterIdx, unsigned>> Regs(RegDepths.begin(), |
| 711 | RegDepths.end()); |
| 712 | sort(C&: Regs, Comp: [](const auto &LHS, const auto &RHS) { |
| 713 | return LHS.second > RHS.second; |
| 714 | }); |
| 715 | |
| 716 | OS << printID(RegIdx: RootIdx) << " has " << Regs.size() - 1 << " dependencies\n" ; |
| 717 | for (const auto &[RegIdx, Depth] : Regs) { |
| 718 | OS << indent(Depth, 2) << (Depth ? '|' : '*') << ' ' |
| 719 | << printRematReg(RegIdx, /*SkipRegions=*/Depth) << '\n'; |
| 720 | } |
| 721 | OS << printRegUsers(RegIdx: RootIdx); |
| 722 | }); |
| 723 | } |
| 724 | |
| 725 | Printable Rematerializer::printID(RegisterIdx RegIdx) const { |
| 726 | return Printable([&, RegIdx](raw_ostream &OS) { |
| 727 | const Reg &PrintReg = getReg(RegIdx); |
| 728 | OS << '(' << RegIdx << '/'; |
| 729 | if (!PrintReg.DefMI) { |
| 730 | OS << "<dead>" ; |
| 731 | } else { |
| 732 | OS << printReg(Reg: PrintReg.getDefReg(), TRI: &TRI, |
| 733 | SubIdx: PrintReg.DefMI->getOperand(i: 0).getSubReg(), MRI: &MRI); |
| 734 | } |
| 735 | OS << ")[" << PrintReg.DefRegion << "]" ; |
| 736 | }); |
| 737 | } |
| 738 | |
| 739 | Printable Rematerializer::printRematReg(RegisterIdx RegIdx, |
| 740 | bool SkipRegions) const { |
| 741 | return Printable([&, RegIdx, SkipRegions](raw_ostream &OS) { |
| 742 | const Reg &PrintReg = getReg(RegIdx); |
| 743 | if (!SkipRegions) { |
| 744 | OS << printID(RegIdx) << " [" << PrintReg.DefRegion; |
| 745 | if (!PrintReg.Uses.empty()) { |
| 746 | assert(PrintReg.DefMI && "dead register cannot have uses" ); |
| 747 | const LiveInterval &LI = LIS.getInterval(Reg: PrintReg.getDefReg()); |
| 748 | // First display all regions in which the register is live-through and |
| 749 | // not used. |
| 750 | bool First = true; |
| 751 | for (const auto [I, Bounds] : enumerate(First&: Regions)) { |
| 752 | if (Bounds.first == Bounds.second) |
| 753 | continue; |
| 754 | if (!PrintReg.Uses.contains(Val: I) && |
| 755 | LI.liveAt(index: LIS.getInstructionIndex(Instr: *Bounds.first)) && |
| 756 | LI.liveAt(index: LIS.getInstructionIndex(Instr: *std::prev(x: Bounds.second)) |
| 757 | .getRegSlot())) { |
| 758 | OS << (First ? " - " : "," ) << I; |
| 759 | First = false; |
| 760 | } |
| 761 | } |
| 762 | OS << (First ? " --> " : " -> " ); |
| 763 | |
| 764 | // Then display regions in which the register is used. |
| 765 | auto It = PrintReg.Uses.begin(); |
| 766 | OS << It->first; |
| 767 | while (++It != PrintReg.Uses.end()) |
| 768 | OS << "," << It->first; |
| 769 | } |
| 770 | OS << "] " ; |
| 771 | } |
| 772 | OS << printID(RegIdx) << ' '; |
| 773 | PrintReg.DefMI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false, |
| 774 | /*SkipDebugLoc=*/false, /*AddNewLine=*/false); |
| 775 | OS << " @ " ; |
| 776 | LIS.getInstructionIndex(Instr: *PrintReg.DefMI).print(os&: OS); |
| 777 | }); |
| 778 | } |
| 779 | |
| 780 | Printable Rematerializer::printRegUsers(RegisterIdx RegIdx) const { |
| 781 | return Printable([&, RegIdx](raw_ostream &OS) { |
| 782 | for (const auto &[UseRegion, Users] : getReg(RegIdx).Uses) { |
| 783 | for (MachineInstr *MI : Users) |
| 784 | OS << " User " << printUser(MI, UseRegion) << '\n'; |
| 785 | } |
| 786 | }); |
| 787 | } |
| 788 | |
| 789 | Printable Rematerializer::printUser(const MachineInstr *MI, |
| 790 | std::optional<unsigned> UseRegion) const { |
| 791 | return Printable([&, MI, UseRegion](raw_ostream &OS) { |
| 792 | RegisterIdx RegIdx = getDefRegIdx(MI: *MI); |
| 793 | if (RegIdx != NoReg) { |
| 794 | OS << printID(RegIdx); |
| 795 | } else { |
| 796 | OS << "(-/-)[" ; |
| 797 | if (UseRegion) |
| 798 | OS << *UseRegion; |
| 799 | else |
| 800 | OS << '?'; |
| 801 | OS << ']'; |
| 802 | } |
| 803 | OS << ' '; |
| 804 | MI->print(OS, /*IsStandalone=*/true, /*SkipOpers=*/false, |
| 805 | /*SkipDebugLoc=*/false, /*AddNewLine=*/false); |
| 806 | OS << " @ " ; |
| 807 | LIS.getInstructionIndex(Instr: *MI).print(os&: OS); |
| 808 | }); |
| 809 | } |
| 810 | |