| 1 | //===- Tracker.cpp --------------------------------------------------------===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | |
| 9 | #include "llvm/SandboxIR/Tracker.h" |
| 10 | #include "llvm/ADT/STLExtras.h" |
| 11 | #include "llvm/IR/BasicBlock.h" |
| 12 | #include "llvm/IR/Instruction.h" |
| 13 | #include "llvm/IR/StructuralHash.h" |
| 14 | #include "llvm/SandboxIR/Instruction.h" |
| 15 | |
| 16 | using namespace llvm::sandboxir; |
| 17 | |
| 18 | #ifndef NDEBUG |
| 19 | |
| 20 | std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const { |
| 21 | std::string Result; |
| 22 | raw_string_ostream SS(Result); |
| 23 | F.print(SS, /*AssemblyAnnotationWriter=*/nullptr); |
| 24 | return Result; |
| 25 | } |
| 26 | |
| 27 | IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const { |
| 28 | ContextSnapshot Result; |
| 29 | for (const auto &Entry : Ctx.LLVMModuleToModuleMap) |
| 30 | for (const auto &F : *Entry.first) { |
| 31 | FunctionSnapshot Snapshot; |
| 32 | Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true); |
| 33 | Snapshot.TextualIR = dumpIR(F); |
| 34 | Result[&F] = Snapshot; |
| 35 | } |
| 36 | return Result; |
| 37 | } |
| 38 | |
| 39 | bool IRSnapshotChecker::diff(const ContextSnapshot &Orig, |
| 40 | const ContextSnapshot &Curr) const { |
| 41 | bool DifferenceFound = false; |
| 42 | for (const auto &[F, OrigFS] : Orig) { |
| 43 | auto CurrFSIt = Curr.find(F); |
| 44 | if (CurrFSIt == Curr.end()) { |
| 45 | DifferenceFound = true; |
| 46 | dbgs() << "Function " << F->getName() << " not found in current IR.\n" ; |
| 47 | dbgs() << OrigFS.TextualIR << "\n" ; |
| 48 | continue; |
| 49 | } |
| 50 | const FunctionSnapshot &CurrFS = CurrFSIt->second; |
| 51 | if (OrigFS.Hash != CurrFS.Hash) { |
| 52 | DifferenceFound = true; |
| 53 | dbgs() << "Found IR difference in Function " << F->getName() << "\n" ; |
| 54 | dbgs() << "Original:\n" << OrigFS.TextualIR << "\n" ; |
| 55 | dbgs() << "Current:\n" << CurrFS.TextualIR << "\n" ; |
| 56 | } |
| 57 | } |
| 58 | // Check that Curr doesn't contain any new functions. |
| 59 | for (const auto &[F, CurrFS] : Curr) { |
| 60 | if (!Orig.contains(F)) { |
| 61 | DifferenceFound = true; |
| 62 | dbgs() << "Function " << F->getName() |
| 63 | << " found in current IR but not in original snapshot.\n" ; |
| 64 | dbgs() << CurrFS.TextualIR << "\n" ; |
| 65 | } |
| 66 | } |
| 67 | return DifferenceFound; |
| 68 | } |
| 69 | |
| 70 | void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); } |
| 71 | |
| 72 | void IRSnapshotChecker::expectNoDiff() { |
| 73 | ContextSnapshot CurrContextSnapshot = takeSnapshot(); |
| 74 | if (diff(OrigContextSnapshot, CurrContextSnapshot)) { |
| 75 | llvm_unreachable( |
| 76 | "Original and current IR differ! Probably a checkpointing bug." ); |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | void UseSet::dump() const { |
| 81 | dump(dbgs()); |
| 82 | dbgs() << "\n" ; |
| 83 | } |
| 84 | |
| 85 | void UseSwap::dump() const { |
| 86 | dump(dbgs()); |
| 87 | dbgs() << "\n" ; |
| 88 | } |
| 89 | #endif // NDEBUG |
| 90 | |
| 91 | PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx) |
| 92 | : PHI(PHI), RemovedIdx(RemovedIdx) { |
| 93 | RemovedV = PHI->getIncomingValue(Idx: RemovedIdx); |
| 94 | RemovedBB = PHI->getIncomingBlock(Idx: RemovedIdx); |
| 95 | } |
| 96 | |
| 97 | void PHIRemoveIncoming::revert(Tracker &Tracker) { |
| 98 | // Special case: if the removed incoming value is the last. |
| 99 | unsigned NumIncoming = PHI->getNumIncomingValues(); |
| 100 | if (NumIncoming == RemovedIdx) { |
| 101 | PHI->addIncoming(V: RemovedV, BB: RemovedBB); |
| 102 | return; |
| 103 | } |
| 104 | // Move the incoming value currently at `RemovedIdx` to the end, restore the |
| 105 | // old incoming value back to `RemovedIdx`. |
| 106 | PHI->addIncoming(V: PHI->getIncomingValue(Idx: RemovedIdx), |
| 107 | BB: PHI->getIncomingBlock(Idx: RemovedIdx)); |
| 108 | PHI->setIncomingValue(Idx: RemovedIdx, V: RemovedV); |
| 109 | PHI->setIncomingBlock(Idx: RemovedIdx, BB: RemovedBB); |
| 110 | } |
| 111 | |
| 112 | #ifndef NDEBUG |
| 113 | void PHIRemoveIncoming::dump() const { |
| 114 | dump(dbgs()); |
| 115 | dbgs() << "\n" ; |
| 116 | } |
| 117 | #endif // NDEBUG |
| 118 | |
| 119 | PHIAddIncoming::PHIAddIncoming(PHINode *PHI) |
| 120 | : PHI(PHI), Idx(PHI->getNumIncomingValues()) {} |
| 121 | |
| 122 | void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); } |
| 123 | |
| 124 | #ifndef NDEBUG |
| 125 | void PHIAddIncoming::dump() const { |
| 126 | dump(dbgs()); |
| 127 | dbgs() << "\n" ; |
| 128 | } |
| 129 | #endif // NDEBUG |
| 130 | |
| 131 | Tracker::~Tracker() { |
| 132 | assert(Changes.empty() && "You must accept or revert changes!" ); |
| 133 | } |
| 134 | |
| 135 | EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr) |
| 136 | : ErasedIPtr(std::move(ErasedIPtr)) { |
| 137 | auto *I = cast<Instruction>(Val: this->ErasedIPtr.get()); |
| 138 | auto LLVMInstrs = I->getLLVMInstrs(); |
| 139 | // Iterate in reverse program order. |
| 140 | for (auto *LLVMI : reverse(C&: LLVMInstrs)) { |
| 141 | SmallVector<llvm::Value *> Operands; |
| 142 | Operands.reserve(N: LLVMI->getNumOperands()); |
| 143 | for (auto [OpNum, Use] : enumerate(First: LLVMI->operands())) |
| 144 | Operands.push_back(Elt: Use.get()); |
| 145 | InstrData.push_back(Elt: {.Operands: Operands, .LLVMI: LLVMI}); |
| 146 | } |
| 147 | assert(is_sorted(InstrData, |
| 148 | [](const auto &D0, const auto &D1) { |
| 149 | return D0.LLVMI->comesBefore(D1.LLVMI); |
| 150 | }) && |
| 151 | "Expected reverse program order!" ); |
| 152 | auto *BotLLVMI = cast<llvm::Instruction>(Val: I->Val); |
| 153 | if (BotLLVMI->getNextNode() != nullptr) |
| 154 | NextLLVMIOrBB = BotLLVMI->getNextNode(); |
| 155 | else |
| 156 | NextLLVMIOrBB = BotLLVMI->getParent(); |
| 157 | } |
| 158 | |
| 159 | void EraseFromParent::accept() { |
| 160 | for (const auto &IData : InstrData) |
| 161 | IData.LLVMI->deleteValue(); |
| 162 | } |
| 163 | |
| 164 | void EraseFromParent::revert(Tracker &Tracker) { |
| 165 | // Place the bottom-most instruction first. |
| 166 | auto [Operands, BotLLVMI] = InstrData[0]; |
| 167 | if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(Val&: NextLLVMIOrBB)) { |
| 168 | BotLLVMI->insertBefore(InsertPos: NextLLVMI->getIterator()); |
| 169 | } else { |
| 170 | auto *LLVMBB = cast<llvm::BasicBlock *>(Val&: NextLLVMIOrBB); |
| 171 | BotLLVMI->insertInto(ParentBB: LLVMBB, It: LLVMBB->end()); |
| 172 | } |
| 173 | for (auto [OpNum, Op] : enumerate(First&: Operands)) |
| 174 | BotLLVMI->setOperand(i: OpNum, Val: Op); |
| 175 | |
| 176 | // Go over the rest of the instructions and stack them on top. |
| 177 | for (auto [Operands, LLVMI] : drop_begin(RangeOrContainer&: InstrData)) { |
| 178 | LLVMI->insertBefore(InsertPos: BotLLVMI->getIterator()); |
| 179 | for (auto [OpNum, Op] : enumerate(First&: Operands)) |
| 180 | LLVMI->setOperand(i: OpNum, Val: Op); |
| 181 | BotLLVMI = LLVMI; |
| 182 | } |
| 183 | Tracker.getContext().registerValue(VPtr: std::move(ErasedIPtr)); |
| 184 | } |
| 185 | |
| 186 | #ifndef NDEBUG |
| 187 | void EraseFromParent::dump() const { |
| 188 | dump(dbgs()); |
| 189 | dbgs() << "\n" ; |
| 190 | } |
| 191 | #endif // NDEBUG |
| 192 | |
| 193 | RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) { |
| 194 | if (auto *NextI = RemovedI->getNextNode()) |
| 195 | NextInstrOrBB = NextI; |
| 196 | else |
| 197 | NextInstrOrBB = RemovedI->getParent(); |
| 198 | } |
| 199 | |
| 200 | void RemoveFromParent::revert(Tracker &Tracker) { |
| 201 | if (auto *NextI = dyn_cast<Instruction *>(Val&: NextInstrOrBB)) { |
| 202 | RemovedI->insertBefore(BeforeI: NextI); |
| 203 | } else { |
| 204 | auto *BB = cast<BasicBlock *>(Val&: NextInstrOrBB); |
| 205 | RemovedI->insertInto(BB, WhereIt: BB->end()); |
| 206 | } |
| 207 | } |
| 208 | |
| 209 | #ifndef NDEBUG |
| 210 | void RemoveFromParent::dump() const { |
| 211 | dump(dbgs()); |
| 212 | dbgs() << "\n" ; |
| 213 | } |
| 214 | #endif |
| 215 | |
| 216 | CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI) |
| 217 | : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {} |
| 218 | |
| 219 | void CatchSwitchAddHandler::revert(Tracker &Tracker) { |
| 220 | // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler() |
| 221 | // once it gets implemented. |
| 222 | auto *LLVMCSI = cast<llvm::CatchSwitchInst>(Val: CSI->Val); |
| 223 | LLVMCSI->removeHandler(HI: LLVMCSI->handler_begin() + HandlerIdx); |
| 224 | } |
| 225 | |
| 226 | SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) { |
| 227 | for (const auto &C : Switch->cases()) |
| 228 | Cases.push_back(Elt: {.Val: C.getCaseValue(), .Dest: C.getCaseSuccessor()}); |
| 229 | } |
| 230 | |
| 231 | void SwitchRemoveCase::revert(Tracker &Tracker) { |
| 232 | // SwitchInst::removeCase doesn't provide any guarantees about the order of |
| 233 | // cases after removal. In order to preserve the original ordering, we save |
| 234 | // all of them and, when reverting, clear them all then insert them in the |
| 235 | // desired order. This still relies on the fact that `addCase` will insert |
| 236 | // them at the end, but it is documented to invalidate `case_end()` so it's |
| 237 | // probably okay. |
| 238 | unsigned NumCases = Switch->getNumCases(); |
| 239 | for (unsigned I = 0; I < NumCases; ++I) |
| 240 | Switch->removeCase(It: Switch->case_begin()); |
| 241 | for (auto &Case : Cases) |
| 242 | Switch->addCase(OnVal: Case.Val, Dest: Case.Dest); |
| 243 | } |
| 244 | |
| 245 | #ifndef NDEBUG |
| 246 | void SwitchRemoveCase::dump() const { |
| 247 | dump(dbgs()); |
| 248 | dbgs() << "\n" ; |
| 249 | } |
| 250 | #endif // NDEBUG |
| 251 | |
| 252 | void SwitchAddCase::revert(Tracker &Tracker) { |
| 253 | auto It = Switch->findCaseValue(C: Val); |
| 254 | Switch->removeCase(It); |
| 255 | } |
| 256 | |
| 257 | #ifndef NDEBUG |
| 258 | void SwitchAddCase::dump() const { |
| 259 | dump(dbgs()); |
| 260 | dbgs() << "\n" ; |
| 261 | } |
| 262 | #endif // NDEBUG |
| 263 | |
| 264 | MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) { |
| 265 | if (auto *NextI = MovedI->getNextNode()) |
| 266 | NextInstrOrBB = NextI; |
| 267 | else |
| 268 | NextInstrOrBB = MovedI->getParent(); |
| 269 | } |
| 270 | |
| 271 | void MoveInstr::revert(Tracker &Tracker) { |
| 272 | if (auto *NextI = dyn_cast<Instruction *>(Val&: NextInstrOrBB)) { |
| 273 | MovedI->moveBefore(Before: NextI); |
| 274 | } else { |
| 275 | auto *BB = cast<BasicBlock *>(Val&: NextInstrOrBB); |
| 276 | MovedI->moveBefore(BB&: *BB, WhereIt: BB->end()); |
| 277 | } |
| 278 | } |
| 279 | |
| 280 | #ifndef NDEBUG |
| 281 | void MoveInstr::dump() const { |
| 282 | dump(dbgs()); |
| 283 | dbgs() << "\n" ; |
| 284 | } |
| 285 | #endif |
| 286 | |
| 287 | void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); } |
| 288 | |
| 289 | InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {} |
| 290 | |
| 291 | #ifndef NDEBUG |
| 292 | void InsertIntoBB::dump() const { |
| 293 | dump(dbgs()); |
| 294 | dbgs() << "\n" ; |
| 295 | } |
| 296 | #endif |
| 297 | |
| 298 | void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); } |
| 299 | |
| 300 | #ifndef NDEBUG |
| 301 | void CreateAndInsertInst::dump() const { |
| 302 | dump(dbgs()); |
| 303 | dbgs() << "\n" ; |
| 304 | } |
| 305 | #endif |
| 306 | |
| 307 | ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI) |
| 308 | : SVI(SVI), PrevMask(SVI->getShuffleMask()) {} |
| 309 | |
| 310 | void ShuffleVectorSetMask::revert(Tracker &Tracker) { |
| 311 | SVI->setShuffleMask(PrevMask); |
| 312 | } |
| 313 | |
| 314 | #ifndef NDEBUG |
| 315 | void ShuffleVectorSetMask::dump() const { |
| 316 | dump(dbgs()); |
| 317 | dbgs() << "\n" ; |
| 318 | } |
| 319 | #endif |
| 320 | |
| 321 | CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {} |
| 322 | |
| 323 | void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); } |
| 324 | #ifndef NDEBUG |
| 325 | void CmpSwapOperands::dump() const { |
| 326 | dump(dbgs()); |
| 327 | dbgs() << "\n" ; |
| 328 | } |
| 329 | #endif |
| 330 | |
| 331 | void Tracker::save() { |
| 332 | State = TrackerState::Record; |
| 333 | // Record the last index in `Changes` that we will revert. |
| 334 | Snapshots.push_back(Elt: Changes.size()); |
| 335 | #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) |
| 336 | SnapshotChecker.emplace_back(Ctx); |
| 337 | SnapshotChecker.back().save(); |
| 338 | #endif |
| 339 | } |
| 340 | |
| 341 | void Tracker::revert(bool RevertAll) { |
| 342 | assert(State == TrackerState::Record && "Forgot to save()!" ); |
| 343 | State = TrackerState::Reverting; |
| 344 | unsigned UntilChangeIdx = RevertAll ? 0 : Snapshots.back(); |
| 345 | const unsigned ToRevert = Changes.size() - UntilChangeIdx; |
| 346 | unsigned CntReverts = 0; |
| 347 | for (auto &Change : reverse(C&: Changes)) { |
| 348 | // Stop reverting if we reach the index of the last snapshot. |
| 349 | if (CntReverts++ == ToRevert) |
| 350 | break; |
| 351 | Change->revert(Tracker&: *this); |
| 352 | #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) |
| 353 | // There may be multiple changes between snapshots, so use the snapshot |
| 354 | // checker only if this change has an associated snapshot. |
| 355 | unsigned ChangeIdx = Changes.size() - CntReverts; |
| 356 | bool ChangeHasSnapshot = ChangeIdx == Snapshots.back(); |
| 357 | if (ChangeHasSnapshot) { |
| 358 | SnapshotChecker.back().expectNoDiff(); |
| 359 | SnapshotChecker.pop_back(); |
| 360 | } |
| 361 | #endif |
| 362 | } |
| 363 | Changes.erase(CS: Changes.end() - ToRevert, CE: Changes.end()); |
| 364 | if (RevertAll) |
| 365 | Snapshots.clear(); |
| 366 | else |
| 367 | Snapshots.pop_back(); |
| 368 | State = Snapshots.empty() ? TrackerState::Disabled : TrackerState::Record; |
| 369 | } |
| 370 | |
| 371 | void Tracker::accept(bool AcceptAll) { |
| 372 | assert(State == TrackerState::Record && "Forgot to save()!" ); |
| 373 | if (!AcceptAll && Snapshots.size() > 1) { |
| 374 | // Just remove the last stacked checkpoint. |
| 375 | Snapshots.pop_back(); |
| 376 | return; |
| 377 | } |
| 378 | State = TrackerState::Disabled; |
| 379 | for (auto &Change : Changes) |
| 380 | Change->accept(); |
| 381 | Changes.clear(); |
| 382 | Snapshots.clear(); |
| 383 | #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) |
| 384 | SnapshotChecker.clear(); |
| 385 | #endif |
| 386 | } |
| 387 | |
| 388 | #ifndef NDEBUG |
| 389 | void Tracker::dump(raw_ostream &OS) const { |
| 390 | unsigned SnapshotCnt = 0; |
| 391 | for (auto [Idx, ChangePtr] : enumerate(Changes)) { |
| 392 | OS << Idx << ". " ; |
| 393 | ChangePtr->dump(OS); |
| 394 | if (find(Snapshots, Idx) != Snapshots.end()) |
| 395 | OS << " [Snapshot " << SnapshotCnt++ << "]" ; |
| 396 | OS << "\n" ; |
| 397 | } |
| 398 | } |
| 399 | void Tracker::dump() const { |
| 400 | dump(dbgs()); |
| 401 | dbgs() << "\n" ; |
| 402 | } |
| 403 | #endif // NDEBUG |
| 404 | |