1//===- Tracker.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/SandboxIR/Tracker.h"
10#include "llvm/ADT/STLExtras.h"
11#include "llvm/IR/BasicBlock.h"
12#include "llvm/IR/Instruction.h"
13#include "llvm/SandboxIR/SandboxIR.h"
14#include <sstream>
15
16using namespace llvm::sandboxir;
17
18IRChangeBase::IRChangeBase(Tracker &Parent) : Parent(Parent) {
19#ifndef NDEBUG
20 assert(!Parent.InMiddleOfCreatingChange &&
21 "We are in the middle of creating another change!");
22 if (Parent.isTracking())
23 Parent.InMiddleOfCreatingChange = true;
24#endif // NDEBUG
25}
26
27#ifndef NDEBUG
28unsigned IRChangeBase::getIdx() const {
29 auto It =
30 find_if(Parent.Changes, [this](auto &Ptr) { return Ptr.get() == this; });
31 return It - Parent.Changes.begin();
32}
33
34void UseSet::dump() const {
35 dump(dbgs());
36 dbgs() << "\n";
37}
38#endif // NDEBUG
39
40Tracker::~Tracker() {
41 assert(Changes.empty() && "You must accept or revert changes!");
42}
43
44EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr,
45 Tracker &Tracker)
46 : IRChangeBase(Tracker), ErasedIPtr(std::move(ErasedIPtr)) {
47 auto *I = cast<Instruction>(Val: this->ErasedIPtr.get());
48 auto LLVMInstrs = I->getLLVMInstrs();
49 // Iterate in reverse program order.
50 for (auto *LLVMI : reverse(C&: LLVMInstrs)) {
51 SmallVector<llvm::Value *> Operands;
52 Operands.reserve(N: LLVMI->getNumOperands());
53 for (auto [OpNum, Use] : enumerate(First: LLVMI->operands()))
54 Operands.push_back(Elt: Use.get());
55 InstrData.push_back(Elt: {.Operands: Operands, .LLVMI: LLVMI});
56 }
57 assert(is_sorted(InstrData,
58 [](const auto &D0, const auto &D1) {
59 return D0.LLVMI->comesBefore(D1.LLVMI);
60 }) &&
61 "Expected reverse program order!");
62 auto *BotLLVMI = cast<llvm::Instruction>(Val: I->Val);
63 if (BotLLVMI->getNextNode() != nullptr)
64 NextLLVMIOrBB = BotLLVMI->getNextNode();
65 else
66 NextLLVMIOrBB = BotLLVMI->getParent();
67}
68
69void EraseFromParent::accept() {
70 for (const auto &IData : InstrData)
71 IData.LLVMI->deleteValue();
72}
73
74void EraseFromParent::revert() {
75 // Place the bottom-most instruction first.
76 auto [Operands, BotLLVMI] = InstrData[0];
77 if (auto *NextLLVMI = NextLLVMIOrBB.dyn_cast<llvm::Instruction *>()) {
78 BotLLVMI->insertBefore(InsertPos: NextLLVMI);
79 } else {
80 auto *LLVMBB = NextLLVMIOrBB.get<llvm::BasicBlock *>();
81 BotLLVMI->insertInto(ParentBB: LLVMBB, It: LLVMBB->end());
82 }
83 for (auto [OpNum, Op] : enumerate(First&: Operands))
84 BotLLVMI->setOperand(i: OpNum, Val: Op);
85
86 // Go over the rest of the instructions and stack them on top.
87 for (auto [Operands, LLVMI] : drop_begin(RangeOrContainer&: InstrData)) {
88 LLVMI->insertBefore(InsertPos: BotLLVMI);
89 for (auto [OpNum, Op] : enumerate(First&: Operands))
90 LLVMI->setOperand(i: OpNum, Val: Op);
91 BotLLVMI = LLVMI;
92 }
93 Parent.getContext().registerValue(VPtr: std::move(ErasedIPtr));
94}
95
96#ifndef NDEBUG
97void EraseFromParent::dump() const {
98 dump(dbgs());
99 dbgs() << "\n";
100}
101#endif // NDEBUG
102
103RemoveFromParent::RemoveFromParent(Instruction *RemovedI, Tracker &Tracker)
104 : IRChangeBase(Tracker), RemovedI(RemovedI) {
105 if (auto *NextI = RemovedI->getNextNode())
106 NextInstrOrBB = NextI;
107 else
108 NextInstrOrBB = RemovedI->getParent();
109}
110
111void RemoveFromParent::revert() {
112 if (auto *NextI = NextInstrOrBB.dyn_cast<Instruction *>()) {
113 RemovedI->insertBefore(BeforeI: NextI);
114 } else {
115 auto *BB = NextInstrOrBB.get<BasicBlock *>();
116 RemovedI->insertInto(BB, WhereIt: BB->end());
117 }
118}
119
120#ifndef NDEBUG
121void RemoveFromParent::dump() const {
122 dump(dbgs());
123 dbgs() << "\n";
124}
125#endif
126
127MoveInstr::MoveInstr(Instruction *MovedI, Tracker &Tracker)
128 : IRChangeBase(Tracker), MovedI(MovedI) {
129 if (auto *NextI = MovedI->getNextNode())
130 NextInstrOrBB = NextI;
131 else
132 NextInstrOrBB = MovedI->getParent();
133}
134
135void MoveInstr::revert() {
136 if (auto *NextI = NextInstrOrBB.dyn_cast<Instruction *>()) {
137 MovedI->moveBefore(Before: NextI);
138 } else {
139 auto *BB = NextInstrOrBB.get<BasicBlock *>();
140 MovedI->moveBefore(BB&: *BB, WhereIt: BB->end());
141 }
142}
143
144#ifndef NDEBUG
145void MoveInstr::dump() const {
146 dump(dbgs());
147 dbgs() << "\n";
148}
149#endif
150
151void Tracker::track(std::unique_ptr<IRChangeBase> &&Change) {
152 assert(State == TrackerState::Record && "The tracker should be tracking!");
153 Changes.push_back(Elt: std::move(Change));
154
155#ifndef NDEBUG
156 InMiddleOfCreatingChange = false;
157#endif
158}
159
160void Tracker::save() { State = TrackerState::Record; }
161
162void Tracker::revert() {
163 assert(State == TrackerState::Record && "Forgot to save()!");
164 State = TrackerState::Disabled;
165 for (auto &Change : reverse(C&: Changes))
166 Change->revert();
167 Changes.clear();
168}
169
170void Tracker::accept() {
171 assert(State == TrackerState::Record && "Forgot to save()!");
172 State = TrackerState::Disabled;
173 for (auto &Change : Changes)
174 Change->accept();
175 Changes.clear();
176}
177
178#ifndef NDEBUG
179void Tracker::dump(raw_ostream &OS) const {
180 for (const auto &ChangePtr : Changes) {
181 ChangePtr->dump(OS);
182 OS << "\n";
183 }
184}
185void Tracker::dump() const {
186 dump(dbgs());
187 dbgs() << "\n";
188}
189#endif // NDEBUG
190