1//===- Tracker.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/SandboxIR/Tracker.h"
10#include "llvm/ADT/STLExtras.h"
11#include "llvm/IR/BasicBlock.h"
12#include "llvm/IR/Instruction.h"
13#include "llvm/IR/StructuralHash.h"
14#include "llvm/SandboxIR/Instruction.h"
15
16using namespace llvm::sandboxir;
17
18#ifndef NDEBUG
19
20std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const {
21 std::string Result;
22 raw_string_ostream SS(Result);
23 F.print(SS, /*AssemblyAnnotationWriter=*/nullptr);
24 return Result;
25}
26
27IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const {
28 ContextSnapshot Result;
29 for (const auto &Entry : Ctx.LLVMModuleToModuleMap)
30 for (const auto &F : *Entry.first) {
31 FunctionSnapshot Snapshot;
32 Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true);
33 Snapshot.TextualIR = dumpIR(F);
34 Result[&F] = Snapshot;
35 }
36 return Result;
37}
38
39bool IRSnapshotChecker::diff(const ContextSnapshot &Orig,
40 const ContextSnapshot &Curr) const {
41 bool DifferenceFound = false;
42 for (const auto &[F, OrigFS] : Orig) {
43 auto CurrFSIt = Curr.find(F);
44 if (CurrFSIt == Curr.end()) {
45 DifferenceFound = true;
46 dbgs() << "Function " << F->getName() << " not found in current IR.\n";
47 dbgs() << OrigFS.TextualIR << "\n";
48 continue;
49 }
50 const FunctionSnapshot &CurrFS = CurrFSIt->second;
51 if (OrigFS.Hash != CurrFS.Hash) {
52 DifferenceFound = true;
53 dbgs() << "Found IR difference in Function " << F->getName() << "\n";
54 dbgs() << "Original:\n" << OrigFS.TextualIR << "\n";
55 dbgs() << "Current:\n" << CurrFS.TextualIR << "\n";
56 }
57 }
58 // Check that Curr doesn't contain any new functions.
59 for (const auto &[F, CurrFS] : Curr) {
60 if (!Orig.contains(F)) {
61 DifferenceFound = true;
62 dbgs() << "Function " << F->getName()
63 << " found in current IR but not in original snapshot.\n";
64 dbgs() << CurrFS.TextualIR << "\n";
65 }
66 }
67 return DifferenceFound;
68}
69
70void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); }
71
72void IRSnapshotChecker::expectNoDiff() {
73 ContextSnapshot CurrContextSnapshot = takeSnapshot();
74 if (diff(OrigContextSnapshot, CurrContextSnapshot)) {
75 llvm_unreachable(
76 "Original and current IR differ! Probably a checkpointing bug.");
77 }
78}
79
80void UseSet::dump() const {
81 dump(dbgs());
82 dbgs() << "\n";
83}
84
85void UseSwap::dump() const {
86 dump(dbgs());
87 dbgs() << "\n";
88}
89#endif // NDEBUG
90
91PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx)
92 : PHI(PHI), RemovedIdx(RemovedIdx) {
93 RemovedV = PHI->getIncomingValue(Idx: RemovedIdx);
94 RemovedBB = PHI->getIncomingBlock(Idx: RemovedIdx);
95}
96
97void PHIRemoveIncoming::revert(Tracker &Tracker) {
98 // Special case: if the removed incoming value is the last.
99 unsigned NumIncoming = PHI->getNumIncomingValues();
100 if (NumIncoming == RemovedIdx) {
101 PHI->addIncoming(V: RemovedV, BB: RemovedBB);
102 return;
103 }
104 // Move the incoming value currently at `RemovedIdx` to the end, restore the
105 // old incoming value back to `RemovedIdx`.
106 PHI->addIncoming(V: PHI->getIncomingValue(Idx: RemovedIdx),
107 BB: PHI->getIncomingBlock(Idx: RemovedIdx));
108 PHI->setIncomingValue(Idx: RemovedIdx, V: RemovedV);
109 PHI->setIncomingBlock(Idx: RemovedIdx, BB: RemovedBB);
110}
111
112#ifndef NDEBUG
113void PHIRemoveIncoming::dump() const {
114 dump(dbgs());
115 dbgs() << "\n";
116}
117#endif // NDEBUG
118
119PHIAddIncoming::PHIAddIncoming(PHINode *PHI)
120 : PHI(PHI), Idx(PHI->getNumIncomingValues()) {}
121
122void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); }
123
124#ifndef NDEBUG
125void PHIAddIncoming::dump() const {
126 dump(dbgs());
127 dbgs() << "\n";
128}
129#endif // NDEBUG
130
131Tracker::~Tracker() {
132 assert(Changes.empty() && "You must accept or revert changes!");
133}
134
135EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr)
136 : ErasedIPtr(std::move(ErasedIPtr)) {
137 auto *I = cast<Instruction>(Val: this->ErasedIPtr.get());
138 auto LLVMInstrs = I->getLLVMInstrs();
139 // Iterate in reverse program order.
140 for (auto *LLVMI : reverse(C&: LLVMInstrs)) {
141 SmallVector<llvm::Value *> Operands;
142 Operands.reserve(N: LLVMI->getNumOperands());
143 for (auto [OpNum, Use] : enumerate(First: LLVMI->operands()))
144 Operands.push_back(Elt: Use.get());
145 InstrData.push_back(Elt: {.Operands: Operands, .LLVMI: LLVMI});
146 }
147 assert(is_sorted(InstrData,
148 [](const auto &D0, const auto &D1) {
149 return D0.LLVMI->comesBefore(D1.LLVMI);
150 }) &&
151 "Expected reverse program order!");
152 auto *BotLLVMI = cast<llvm::Instruction>(Val: I->Val);
153 if (BotLLVMI->getNextNode() != nullptr)
154 NextLLVMIOrBB = BotLLVMI->getNextNode();
155 else
156 NextLLVMIOrBB = BotLLVMI->getParent();
157}
158
159void EraseFromParent::accept() {
160 for (const auto &IData : InstrData)
161 IData.LLVMI->deleteValue();
162}
163
164void EraseFromParent::revert(Tracker &Tracker) {
165 // Place the bottom-most instruction first.
166 auto [Operands, BotLLVMI] = InstrData[0];
167 if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(Val&: NextLLVMIOrBB)) {
168 BotLLVMI->insertBefore(InsertPos: NextLLVMI->getIterator());
169 } else {
170 auto *LLVMBB = cast<llvm::BasicBlock *>(Val&: NextLLVMIOrBB);
171 BotLLVMI->insertInto(ParentBB: LLVMBB, It: LLVMBB->end());
172 }
173 for (auto [OpNum, Op] : enumerate(First&: Operands))
174 BotLLVMI->setOperand(i: OpNum, Val: Op);
175
176 // Go over the rest of the instructions and stack them on top.
177 for (auto [Operands, LLVMI] : drop_begin(RangeOrContainer&: InstrData)) {
178 LLVMI->insertBefore(InsertPos: BotLLVMI->getIterator());
179 for (auto [OpNum, Op] : enumerate(First&: Operands))
180 LLVMI->setOperand(i: OpNum, Val: Op);
181 BotLLVMI = LLVMI;
182 }
183 Tracker.getContext().registerValue(VPtr: std::move(ErasedIPtr));
184}
185
186#ifndef NDEBUG
187void EraseFromParent::dump() const {
188 dump(dbgs());
189 dbgs() << "\n";
190}
191#endif // NDEBUG
192
193RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) {
194 if (auto *NextI = RemovedI->getNextNode())
195 NextInstrOrBB = NextI;
196 else
197 NextInstrOrBB = RemovedI->getParent();
198}
199
200void RemoveFromParent::revert(Tracker &Tracker) {
201 if (auto *NextI = dyn_cast<Instruction *>(Val&: NextInstrOrBB)) {
202 RemovedI->insertBefore(BeforeI: NextI);
203 } else {
204 auto *BB = cast<BasicBlock *>(Val&: NextInstrOrBB);
205 RemovedI->insertInto(BB, WhereIt: BB->end());
206 }
207}
208
209#ifndef NDEBUG
210void RemoveFromParent::dump() const {
211 dump(dbgs());
212 dbgs() << "\n";
213}
214#endif
215
216CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI)
217 : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {}
218
219void CatchSwitchAddHandler::revert(Tracker &Tracker) {
220 // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler()
221 // once it gets implemented.
222 auto *LLVMCSI = cast<llvm::CatchSwitchInst>(Val: CSI->Val);
223 LLVMCSI->removeHandler(HI: LLVMCSI->handler_begin() + HandlerIdx);
224}
225
226SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) {
227 for (const auto &C : Switch->cases())
228 Cases.push_back(Elt: {.Val: C.getCaseValue(), .Dest: C.getCaseSuccessor()});
229}
230
231void SwitchRemoveCase::revert(Tracker &Tracker) {
232 // SwitchInst::removeCase doesn't provide any guarantees about the order of
233 // cases after removal. In order to preserve the original ordering, we save
234 // all of them and, when reverting, clear them all then insert them in the
235 // desired order. This still relies on the fact that `addCase` will insert
236 // them at the end, but it is documented to invalidate `case_end()` so it's
237 // probably okay.
238 unsigned NumCases = Switch->getNumCases();
239 for (unsigned I = 0; I < NumCases; ++I)
240 Switch->removeCase(It: Switch->case_begin());
241 for (auto &Case : Cases)
242 Switch->addCase(OnVal: Case.Val, Dest: Case.Dest);
243}
244
245#ifndef NDEBUG
246void SwitchRemoveCase::dump() const {
247 dump(dbgs());
248 dbgs() << "\n";
249}
250#endif // NDEBUG
251
252void SwitchAddCase::revert(Tracker &Tracker) {
253 auto It = Switch->findCaseValue(C: Val);
254 Switch->removeCase(It);
255}
256
257#ifndef NDEBUG
258void SwitchAddCase::dump() const {
259 dump(dbgs());
260 dbgs() << "\n";
261}
262#endif // NDEBUG
263
264MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) {
265 if (auto *NextI = MovedI->getNextNode())
266 NextInstrOrBB = NextI;
267 else
268 NextInstrOrBB = MovedI->getParent();
269}
270
271void MoveInstr::revert(Tracker &Tracker) {
272 if (auto *NextI = dyn_cast<Instruction *>(Val&: NextInstrOrBB)) {
273 MovedI->moveBefore(Before: NextI);
274 } else {
275 auto *BB = cast<BasicBlock *>(Val&: NextInstrOrBB);
276 MovedI->moveBefore(BB&: *BB, WhereIt: BB->end());
277 }
278}
279
280#ifndef NDEBUG
281void MoveInstr::dump() const {
282 dump(dbgs());
283 dbgs() << "\n";
284}
285#endif
286
287void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); }
288
289InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {}
290
291#ifndef NDEBUG
292void InsertIntoBB::dump() const {
293 dump(dbgs());
294 dbgs() << "\n";
295}
296#endif
297
298void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); }
299
300#ifndef NDEBUG
301void CreateAndInsertInst::dump() const {
302 dump(dbgs());
303 dbgs() << "\n";
304}
305#endif
306
307ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI)
308 : SVI(SVI), PrevMask(SVI->getShuffleMask()) {}
309
310void ShuffleVectorSetMask::revert(Tracker &Tracker) {
311 SVI->setShuffleMask(PrevMask);
312}
313
314#ifndef NDEBUG
315void ShuffleVectorSetMask::dump() const {
316 dump(dbgs());
317 dbgs() << "\n";
318}
319#endif
320
321CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
322
323void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
324#ifndef NDEBUG
325void CmpSwapOperands::dump() const {
326 dump(dbgs());
327 dbgs() << "\n";
328}
329#endif
330
331void Tracker::save() {
332 State = TrackerState::Record;
333#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
334 SnapshotChecker.save();
335#endif
336}
337
338void Tracker::revert() {
339 assert(State == TrackerState::Record && "Forgot to save()!");
340 State = TrackerState::Reverting;
341 for (auto &Change : reverse(C&: Changes))
342 Change->revert(Tracker&: *this);
343 Changes.clear();
344#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
345 SnapshotChecker.expectNoDiff();
346#endif
347 State = TrackerState::Disabled;
348}
349
350void Tracker::accept() {
351 assert(State == TrackerState::Record && "Forgot to save()!");
352 State = TrackerState::Disabled;
353 for (auto &Change : Changes)
354 Change->accept();
355 Changes.clear();
356}
357
358#ifndef NDEBUG
359void Tracker::dump(raw_ostream &OS) const {
360 for (auto [Idx, ChangePtr] : enumerate(Changes)) {
361 OS << Idx << ". ";
362 ChangePtr->dump(OS);
363 OS << "\n";
364 }
365}
366void Tracker::dump() const {
367 dump(dbgs());
368 dbgs() << "\n";
369}
370#endif // NDEBUG
371