1//===- Region.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/SandboxIR/Region.h"
10#include "llvm/SandboxIR/Function.h"
11
12namespace llvm::sandboxir {
13
14InstructionCost ScoreBoard::getCost(Instruction *I) const {
15 auto *LLVMI = cast<llvm::Instruction>(Val: I->Val);
16 SmallVector<const llvm::Value *> Operands(LLVMI->operands());
17 return TTI.getInstructionCost(U: LLVMI, Operands, CostKind);
18}
19
20void ScoreBoard::remove(Instruction *I) {
21 auto Cost = getCost(I);
22 if (Rgn.contains(I))
23 // If `I` is one the newly added ones, then we should adjust `AfterCost`
24 AfterCost -= Cost;
25 else
26 // If `I` is one of the original instructions (outside the region) then it
27 // is part of the original code, so adjust `BeforeCost`.
28 BeforeCost += Cost;
29}
30
31#ifndef NDEBUG
32void ScoreBoard::dump() const { dump(dbgs()); }
33#endif
34
35Region::Region(Context &Ctx, TargetTransformInfo &TTI)
36 : Ctx(Ctx), Scoreboard(*this, TTI) {
37 LLVMContext &LLVMCtx = Ctx.LLVMCtx;
38 auto *RegionStrMD = MDString::get(Context&: LLVMCtx, Str: RegionStr);
39 RegionMDN = MDNode::getDistinct(Context&: LLVMCtx, MDs: {RegionStrMD});
40
41 CreateInstCB = Ctx.registerCreateInstrCallback(
42 CB: [this](Instruction *NewInst) { add(I: NewInst); });
43 EraseInstCB = Ctx.registerEraseInstrCallback(CB: [this](Instruction *ErasedInst) {
44 remove(I: ErasedInst);
45 removeFromAux(I: ErasedInst);
46 });
47}
48
49Region::~Region() {
50 Ctx.unregisterCreateInstrCallback(ID: CreateInstCB);
51 Ctx.unregisterEraseInstrCallback(ID: EraseInstCB);
52}
53
54void Region::addImpl(Instruction *I, bool IgnoreCost) {
55 Insts.insert(X: I);
56 // TODO: Consider tagging instructions lazily.
57 cast<llvm::Instruction>(Val: I->Val)->setMetadata(Kind: MDKind, Node: RegionMDN);
58 if (!IgnoreCost)
59 // Keep track of the instruction cost.
60 Scoreboard.add(I);
61}
62
63void Region::setAux(ArrayRef<Instruction *> Aux) {
64 this->Aux = SmallVector<Instruction *>(Aux);
65 auto &LLVMCtx = Ctx.LLVMCtx;
66 for (auto [Idx, I] : enumerate(First&: Aux)) {
67 llvm::ConstantInt *IdxC =
68 llvm::ConstantInt::get(Ty: llvm::Type::getInt32Ty(C&: LLVMCtx), V: Idx, IsSigned: false);
69 assert(cast<llvm::Instruction>(I->Val)->getMetadata(AuxMDKind) == nullptr &&
70 "Instruction already in Aux!");
71 cast<llvm::Instruction>(Val: I->Val)->setMetadata(
72 Kind: AuxMDKind, Node: MDNode::get(Context&: LLVMCtx, MDs: ConstantAsMetadata::get(C: IdxC)));
73 // Aux instrs should always be in a region.
74 addImpl(I, /*DontTrackCost=*/IgnoreCost: true);
75 }
76}
77
78void Region::setAux(unsigned Idx, Instruction *I) {
79 assert((Idx >= Aux.size() || Aux[Idx] == nullptr) &&
80 "There is already an Instruction at Idx in Aux!");
81 unsigned ExpectedSz = Idx + 1;
82 if (Aux.size() < ExpectedSz) {
83 auto SzBefore = Aux.size();
84 Aux.resize(N: ExpectedSz);
85 // Initialize the gap with nullptr.
86 for (unsigned Idx = SzBefore; Idx + 1 < ExpectedSz; ++Idx)
87 Aux[Idx] = nullptr;
88 }
89 Aux[Idx] = I;
90 // Aux instrs should always be in a region.
91 addImpl(I, /*DontTrackCost=*/IgnoreCost: true);
92}
93
94void Region::dropAuxMetadata(Instruction *I) {
95 auto *LLVMI = cast<llvm::Instruction>(Val: I->Val);
96 LLVMI->setMetadata(Kind: AuxMDKind, Node: nullptr);
97}
98
99void Region::removeFromAux(Instruction *I) {
100 auto It = find(Range&: Aux, Val: I);
101 if (It == Aux.end())
102 return;
103 dropAuxMetadata(I);
104 Aux.erase(CI: It);
105}
106
107void Region::clearAux() {
108 for (unsigned Idx : seq<unsigned>(Begin: 0, End: Aux.size()))
109 dropAuxMetadata(I: Aux[Idx]);
110 Aux.clear();
111}
112
113void Region::remove(Instruction *I) {
114 // Keep track of the instruction cost. This need to be done *before* we remove
115 // `I` from the region.
116 Scoreboard.remove(I);
117
118 Insts.remove(X: I);
119 cast<llvm::Instruction>(Val: I->Val)->setMetadata(Kind: MDKind, Node: nullptr);
120}
121
122#ifndef NDEBUG
123bool Region::operator==(const Region &Other) const {
124 if (Insts.size() != Other.Insts.size())
125 return false;
126 if (!std::is_permutation(Insts.begin(), Insts.end(), Other.Insts.begin()))
127 return false;
128 return true;
129}
130
131void Region::dump(raw_ostream &OS) const {
132 for (auto *I : Insts)
133 OS << *I << "\n";
134 if (!Aux.empty()) {
135 OS << "\nAux:\n";
136 for (auto *I : Aux) {
137 if (I == nullptr)
138 OS << "NULL\n";
139 else
140 OS << *I << "\n";
141 }
142 }
143}
144
145void Region::dump() const {
146 dump(dbgs());
147 dbgs() << "\n";
148}
149#endif // NDEBUG
150
151SmallVector<std::unique_ptr<Region>>
152Region::createRegionsFromMD(Function &F, TargetTransformInfo &TTI) {
153 SmallVector<std::unique_ptr<Region>> Regions;
154 DenseMap<MDNode *, Region *> MDNToRegion;
155 auto &Ctx = F.getContext();
156 for (BasicBlock &BB : F) {
157 for (Instruction &Inst : BB) {
158 auto *LLVMI = cast<llvm::Instruction>(Val: Inst.Val);
159 Region *R = nullptr;
160 if (auto *MDN = LLVMI->getMetadata(Kind: MDKind)) {
161 auto [It, Inserted] = MDNToRegion.try_emplace(Key: MDN);
162 if (Inserted) {
163 Regions.push_back(Elt: std::make_unique<Region>(args&: Ctx, args&: TTI));
164 R = Regions.back().get();
165 It->second = R;
166 } else {
167 R = It->second;
168 }
169 R->addImpl(I: &Inst, /*IgnoreCost=*/true);
170 }
171 if (auto *AuxMDN = LLVMI->getMetadata(Kind: AuxMDKind)) {
172 llvm::Constant *IdxC =
173 dyn_cast<ConstantAsMetadata>(Val: AuxMDN->getOperand(I: 0))->getValue();
174 auto Idx = cast<llvm::ConstantInt>(Val: IdxC)->getSExtValue();
175 if (R == nullptr) {
176 errs() << "No region specified for Aux: '" << *LLVMI << "'\n";
177 exit(status: 1);
178 }
179 R->setAux(Idx, I: &Inst);
180 }
181 }
182 }
183#ifndef NDEBUG
184 // Check that there are no gaps in the Aux vector.
185 for (auto &RPtr : Regions)
186 for (auto *I : RPtr->getAux())
187 assert(I != nullptr && "Gap in Aux!");
188#endif
189 return Regions;
190}
191
192} // namespace llvm::sandboxir
193