1 | //===- Tracker.cpp --------------------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #include "llvm/SandboxIR/Tracker.h" |
10 | #include "llvm/ADT/STLExtras.h" |
11 | #include "llvm/IR/BasicBlock.h" |
12 | #include "llvm/IR/Instruction.h" |
13 | #include "llvm/IR/StructuralHash.h" |
14 | #include "llvm/SandboxIR/Instruction.h" |
15 | |
16 | using namespace llvm::sandboxir; |
17 | |
18 | #ifndef NDEBUG |
19 | |
20 | std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const { |
21 | std::string Result; |
22 | raw_string_ostream SS(Result); |
23 | F.print(SS, /*AssemblyAnnotationWriter=*/nullptr); |
24 | return Result; |
25 | } |
26 | |
27 | IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const { |
28 | ContextSnapshot Result; |
29 | for (const auto &Entry : Ctx.LLVMModuleToModuleMap) |
30 | for (const auto &F : *Entry.first) { |
31 | FunctionSnapshot Snapshot; |
32 | Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true); |
33 | Snapshot.TextualIR = dumpIR(F); |
34 | Result[&F] = Snapshot; |
35 | } |
36 | return Result; |
37 | } |
38 | |
39 | bool IRSnapshotChecker::diff(const ContextSnapshot &Orig, |
40 | const ContextSnapshot &Curr) const { |
41 | bool DifferenceFound = false; |
42 | for (const auto &[F, OrigFS] : Orig) { |
43 | auto CurrFSIt = Curr.find(F); |
44 | if (CurrFSIt == Curr.end()) { |
45 | DifferenceFound = true; |
46 | dbgs() << "Function " << F->getName() << " not found in current IR.\n" ; |
47 | dbgs() << OrigFS.TextualIR << "\n" ; |
48 | continue; |
49 | } |
50 | const FunctionSnapshot &CurrFS = CurrFSIt->second; |
51 | if (OrigFS.Hash != CurrFS.Hash) { |
52 | DifferenceFound = true; |
53 | dbgs() << "Found IR difference in Function " << F->getName() << "\n" ; |
54 | dbgs() << "Original:\n" << OrigFS.TextualIR << "\n" ; |
55 | dbgs() << "Current:\n" << CurrFS.TextualIR << "\n" ; |
56 | } |
57 | } |
58 | // Check that Curr doesn't contain any new functions. |
59 | for (const auto &[F, CurrFS] : Curr) { |
60 | if (!Orig.contains(F)) { |
61 | DifferenceFound = true; |
62 | dbgs() << "Function " << F->getName() |
63 | << " found in current IR but not in original snapshot.\n" ; |
64 | dbgs() << CurrFS.TextualIR << "\n" ; |
65 | } |
66 | } |
67 | return DifferenceFound; |
68 | } |
69 | |
70 | void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); } |
71 | |
72 | void IRSnapshotChecker::expectNoDiff() { |
73 | ContextSnapshot CurrContextSnapshot = takeSnapshot(); |
74 | if (diff(OrigContextSnapshot, CurrContextSnapshot)) { |
75 | llvm_unreachable( |
76 | "Original and current IR differ! Probably a checkpointing bug." ); |
77 | } |
78 | } |
79 | |
80 | void UseSet::dump() const { |
81 | dump(dbgs()); |
82 | dbgs() << "\n" ; |
83 | } |
84 | |
85 | void UseSwap::dump() const { |
86 | dump(dbgs()); |
87 | dbgs() << "\n" ; |
88 | } |
89 | #endif // NDEBUG |
90 | |
91 | PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx) |
92 | : PHI(PHI), RemovedIdx(RemovedIdx) { |
93 | RemovedV = PHI->getIncomingValue(Idx: RemovedIdx); |
94 | RemovedBB = PHI->getIncomingBlock(Idx: RemovedIdx); |
95 | } |
96 | |
97 | void PHIRemoveIncoming::revert(Tracker &Tracker) { |
98 | // Special case: if the PHI is now empty, as we don't need to care about the |
99 | // order of the incoming values. |
100 | unsigned NumIncoming = PHI->getNumIncomingValues(); |
101 | if (NumIncoming == 0) { |
102 | PHI->addIncoming(V: RemovedV, BB: RemovedBB); |
103 | return; |
104 | } |
105 | // Shift all incoming values by one starting from the end until `Idx`. |
106 | // Start by adding a copy of the last incoming values. |
107 | unsigned LastIdx = NumIncoming - 1; |
108 | PHI->addIncoming(V: PHI->getIncomingValue(Idx: LastIdx), |
109 | BB: PHI->getIncomingBlock(Idx: LastIdx)); |
110 | for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) { |
111 | auto *PrevV = PHI->getIncomingValue(Idx: Idx - 1); |
112 | auto *PrevBB = PHI->getIncomingBlock(Idx: Idx - 1); |
113 | PHI->setIncomingValue(Idx, V: PrevV); |
114 | PHI->setIncomingBlock(Idx, BB: PrevBB); |
115 | } |
116 | PHI->setIncomingValue(Idx: RemovedIdx, V: RemovedV); |
117 | PHI->setIncomingBlock(Idx: RemovedIdx, BB: RemovedBB); |
118 | } |
119 | |
120 | #ifndef NDEBUG |
121 | void PHIRemoveIncoming::dump() const { |
122 | dump(dbgs()); |
123 | dbgs() << "\n" ; |
124 | } |
125 | #endif // NDEBUG |
126 | |
127 | PHIAddIncoming::PHIAddIncoming(PHINode *PHI) |
128 | : PHI(PHI), Idx(PHI->getNumIncomingValues()) {} |
129 | |
130 | void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); } |
131 | |
132 | #ifndef NDEBUG |
133 | void PHIAddIncoming::dump() const { |
134 | dump(dbgs()); |
135 | dbgs() << "\n" ; |
136 | } |
137 | #endif // NDEBUG |
138 | |
139 | Tracker::~Tracker() { |
140 | assert(Changes.empty() && "You must accept or revert changes!" ); |
141 | } |
142 | |
143 | EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr) |
144 | : ErasedIPtr(std::move(ErasedIPtr)) { |
145 | auto *I = cast<Instruction>(Val: this->ErasedIPtr.get()); |
146 | auto LLVMInstrs = I->getLLVMInstrs(); |
147 | // Iterate in reverse program order. |
148 | for (auto *LLVMI : reverse(C&: LLVMInstrs)) { |
149 | SmallVector<llvm::Value *> Operands; |
150 | Operands.reserve(N: LLVMI->getNumOperands()); |
151 | for (auto [OpNum, Use] : enumerate(First: LLVMI->operands())) |
152 | Operands.push_back(Elt: Use.get()); |
153 | InstrData.push_back(Elt: {.Operands: Operands, .LLVMI: LLVMI}); |
154 | } |
155 | assert(is_sorted(InstrData, |
156 | [](const auto &D0, const auto &D1) { |
157 | return D0.LLVMI->comesBefore(D1.LLVMI); |
158 | }) && |
159 | "Expected reverse program order!" ); |
160 | auto *BotLLVMI = cast<llvm::Instruction>(Val: I->Val); |
161 | if (BotLLVMI->getNextNode() != nullptr) |
162 | NextLLVMIOrBB = BotLLVMI->getNextNode(); |
163 | else |
164 | NextLLVMIOrBB = BotLLVMI->getParent(); |
165 | } |
166 | |
167 | void EraseFromParent::accept() { |
168 | for (const auto &IData : InstrData) |
169 | IData.LLVMI->deleteValue(); |
170 | } |
171 | |
172 | void EraseFromParent::revert(Tracker &Tracker) { |
173 | // Place the bottom-most instruction first. |
174 | auto [Operands, BotLLVMI] = InstrData[0]; |
175 | if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(Val&: NextLLVMIOrBB)) { |
176 | BotLLVMI->insertBefore(InsertPos: NextLLVMI->getIterator()); |
177 | } else { |
178 | auto *LLVMBB = cast<llvm::BasicBlock *>(Val&: NextLLVMIOrBB); |
179 | BotLLVMI->insertInto(ParentBB: LLVMBB, It: LLVMBB->end()); |
180 | } |
181 | for (auto [OpNum, Op] : enumerate(First&: Operands)) |
182 | BotLLVMI->setOperand(i: OpNum, Val: Op); |
183 | |
184 | // Go over the rest of the instructions and stack them on top. |
185 | for (auto [Operands, LLVMI] : drop_begin(RangeOrContainer&: InstrData)) { |
186 | LLVMI->insertBefore(InsertPos: BotLLVMI->getIterator()); |
187 | for (auto [OpNum, Op] : enumerate(First&: Operands)) |
188 | LLVMI->setOperand(i: OpNum, Val: Op); |
189 | BotLLVMI = LLVMI; |
190 | } |
191 | Tracker.getContext().registerValue(VPtr: std::move(ErasedIPtr)); |
192 | } |
193 | |
194 | #ifndef NDEBUG |
195 | void EraseFromParent::dump() const { |
196 | dump(dbgs()); |
197 | dbgs() << "\n" ; |
198 | } |
199 | #endif // NDEBUG |
200 | |
201 | RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) { |
202 | if (auto *NextI = RemovedI->getNextNode()) |
203 | NextInstrOrBB = NextI; |
204 | else |
205 | NextInstrOrBB = RemovedI->getParent(); |
206 | } |
207 | |
208 | void RemoveFromParent::revert(Tracker &Tracker) { |
209 | if (auto *NextI = dyn_cast<Instruction *>(Val&: NextInstrOrBB)) { |
210 | RemovedI->insertBefore(BeforeI: NextI); |
211 | } else { |
212 | auto *BB = cast<BasicBlock *>(Val&: NextInstrOrBB); |
213 | RemovedI->insertInto(BB, WhereIt: BB->end()); |
214 | } |
215 | } |
216 | |
217 | #ifndef NDEBUG |
218 | void RemoveFromParent::dump() const { |
219 | dump(dbgs()); |
220 | dbgs() << "\n" ; |
221 | } |
222 | #endif |
223 | |
224 | CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI) |
225 | : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {} |
226 | |
227 | void CatchSwitchAddHandler::revert(Tracker &Tracker) { |
228 | // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler() |
229 | // once it gets implemented. |
230 | auto *LLVMCSI = cast<llvm::CatchSwitchInst>(Val: CSI->Val); |
231 | LLVMCSI->removeHandler(HI: LLVMCSI->handler_begin() + HandlerIdx); |
232 | } |
233 | |
234 | SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) { |
235 | for (const auto &C : Switch->cases()) |
236 | Cases.push_back(Elt: {.Val: C.getCaseValue(), .Dest: C.getCaseSuccessor()}); |
237 | } |
238 | |
239 | void SwitchRemoveCase::revert(Tracker &Tracker) { |
240 | // SwitchInst::removeCase doesn't provide any guarantees about the order of |
241 | // cases after removal. In order to preserve the original ordering, we save |
242 | // all of them and, when reverting, clear them all then insert them in the |
243 | // desired order. This still relies on the fact that `addCase` will insert |
244 | // them at the end, but it is documented to invalidate `case_end()` so it's |
245 | // probably okay. |
246 | unsigned NumCases = Switch->getNumCases(); |
247 | for (unsigned I = 0; I < NumCases; ++I) |
248 | Switch->removeCase(It: Switch->case_begin()); |
249 | for (auto &Case : Cases) |
250 | Switch->addCase(OnVal: Case.Val, Dest: Case.Dest); |
251 | } |
252 | |
253 | #ifndef NDEBUG |
254 | void SwitchRemoveCase::dump() const { |
255 | dump(dbgs()); |
256 | dbgs() << "\n" ; |
257 | } |
258 | #endif // NDEBUG |
259 | |
260 | void SwitchAddCase::revert(Tracker &Tracker) { |
261 | auto It = Switch->findCaseValue(C: Val); |
262 | Switch->removeCase(It); |
263 | } |
264 | |
265 | #ifndef NDEBUG |
266 | void SwitchAddCase::dump() const { |
267 | dump(dbgs()); |
268 | dbgs() << "\n" ; |
269 | } |
270 | #endif // NDEBUG |
271 | |
272 | MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) { |
273 | if (auto *NextI = MovedI->getNextNode()) |
274 | NextInstrOrBB = NextI; |
275 | else |
276 | NextInstrOrBB = MovedI->getParent(); |
277 | } |
278 | |
279 | void MoveInstr::revert(Tracker &Tracker) { |
280 | if (auto *NextI = dyn_cast<Instruction *>(Val&: NextInstrOrBB)) { |
281 | MovedI->moveBefore(Before: NextI); |
282 | } else { |
283 | auto *BB = cast<BasicBlock *>(Val&: NextInstrOrBB); |
284 | MovedI->moveBefore(BB&: *BB, WhereIt: BB->end()); |
285 | } |
286 | } |
287 | |
288 | #ifndef NDEBUG |
289 | void MoveInstr::dump() const { |
290 | dump(dbgs()); |
291 | dbgs() << "\n" ; |
292 | } |
293 | #endif |
294 | |
295 | void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); } |
296 | |
297 | InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {} |
298 | |
299 | #ifndef NDEBUG |
300 | void InsertIntoBB::dump() const { |
301 | dump(dbgs()); |
302 | dbgs() << "\n" ; |
303 | } |
304 | #endif |
305 | |
306 | void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); } |
307 | |
308 | #ifndef NDEBUG |
309 | void CreateAndInsertInst::dump() const { |
310 | dump(dbgs()); |
311 | dbgs() << "\n" ; |
312 | } |
313 | #endif |
314 | |
315 | ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI) |
316 | : SVI(SVI), PrevMask(SVI->getShuffleMask()) {} |
317 | |
318 | void ShuffleVectorSetMask::revert(Tracker &Tracker) { |
319 | SVI->setShuffleMask(PrevMask); |
320 | } |
321 | |
322 | #ifndef NDEBUG |
323 | void ShuffleVectorSetMask::dump() const { |
324 | dump(dbgs()); |
325 | dbgs() << "\n" ; |
326 | } |
327 | #endif |
328 | |
329 | CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {} |
330 | |
331 | void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); } |
332 | #ifndef NDEBUG |
333 | void CmpSwapOperands::dump() const { |
334 | dump(dbgs()); |
335 | dbgs() << "\n" ; |
336 | } |
337 | #endif |
338 | |
339 | void Tracker::save() { |
340 | State = TrackerState::Record; |
341 | #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) |
342 | SnapshotChecker.save(); |
343 | #endif |
344 | } |
345 | |
346 | void Tracker::revert() { |
347 | assert(State == TrackerState::Record && "Forgot to save()!" ); |
348 | State = TrackerState::Reverting; |
349 | for (auto &Change : reverse(C&: Changes)) |
350 | Change->revert(Tracker&: *this); |
351 | Changes.clear(); |
352 | #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) |
353 | SnapshotChecker.expectNoDiff(); |
354 | #endif |
355 | State = TrackerState::Disabled; |
356 | } |
357 | |
358 | void Tracker::accept() { |
359 | assert(State == TrackerState::Record && "Forgot to save()!" ); |
360 | State = TrackerState::Disabled; |
361 | for (auto &Change : Changes) |
362 | Change->accept(); |
363 | Changes.clear(); |
364 | } |
365 | |
366 | #ifndef NDEBUG |
367 | void Tracker::dump(raw_ostream &OS) const { |
368 | for (auto [Idx, ChangePtr] : enumerate(Changes)) { |
369 | OS << Idx << ". " ; |
370 | ChangePtr->dump(OS); |
371 | OS << "\n" ; |
372 | } |
373 | } |
374 | void Tracker::dump() const { |
375 | dump(dbgs()); |
376 | dbgs() << "\n" ; |
377 | } |
378 | #endif // NDEBUG |
379 | |