1//===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/SandboxIR/Context.h"
10#include "llvm/IR/InlineAsm.h"
11#include "llvm/SandboxIR/Function.h"
12#include "llvm/SandboxIR/Instruction.h"
13#include "llvm/SandboxIR/Module.h"
14
15namespace llvm::sandboxir {
16
17std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
18 std::unique_ptr<Value> Erased;
19 auto It = LLVMValueToValueMap.find(Val: V);
20 if (It != LLVMValueToValueMap.end()) {
21 auto *Val = It->second.release();
22 Erased = std::unique_ptr<Value>(Val);
23 LLVMValueToValueMap.erase(I: It);
24 }
25 return Erased;
26}
27
28std::unique_ptr<Value> Context::detach(Value *V) {
29 assert(V->getSubclassID() != Value::ClassID::Constant &&
30 "Can't detach a constant!");
31 assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
32 return detachLLVMValue(V: V->Val);
33}
34
35Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
36 assert(VPtr->getSubclassID() != Value::ClassID::User &&
37 "Can't register a user!");
38
39 Value *V = VPtr.get();
40 [[maybe_unused]] auto Pair =
41 LLVMValueToValueMap.insert(KV: {VPtr->Val, std::move(VPtr)});
42 assert(Pair.second && "Already exists!");
43
44 // Track creation of instructions.
45 // Please note that we don't allow the creation of detached instructions,
46 // meaning that the instructions need to be inserted into a block upon
47 // creation. This is why the tracker class combines creation and insertion.
48 if (auto *I = dyn_cast<Instruction>(Val: V)) {
49 getTracker().emplaceIfTracking<CreateAndInsertInst>(Args: I);
50 runCreateInstrCallbacks(I);
51 }
52
53 return V;
54}
55
56Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
57 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMV);
58 auto It = Pair.first;
59 if (!Pair.second)
60 return It->second.get();
61
62 // Instruction
63 if (auto *LLVMI = dyn_cast<llvm::Instruction>(Val: LLVMV)) {
64 switch (LLVMI->getOpcode()) {
65 case llvm::Instruction::VAArg: {
66 auto *LLVMVAArg = cast<llvm::VAArgInst>(Val: LLVMV);
67 It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
68 return It->second.get();
69 }
70 case llvm::Instruction::Freeze: {
71 auto *LLVMFreeze = cast<llvm::FreezeInst>(Val: LLVMV);
72 It->second =
73 std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
74 return It->second.get();
75 }
76 case llvm::Instruction::Fence: {
77 auto *LLVMFence = cast<llvm::FenceInst>(Val: LLVMV);
78 It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
79 return It->second.get();
80 }
81 case llvm::Instruction::Select: {
82 auto *LLVMSel = cast<llvm::SelectInst>(Val: LLVMV);
83 It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
84 return It->second.get();
85 }
86 case llvm::Instruction::ExtractElement: {
87 auto *LLVMIns = cast<llvm::ExtractElementInst>(Val: LLVMV);
88 It->second = std::unique_ptr<ExtractElementInst>(
89 new ExtractElementInst(LLVMIns, *this));
90 return It->second.get();
91 }
92 case llvm::Instruction::InsertElement: {
93 auto *LLVMIns = cast<llvm::InsertElementInst>(Val: LLVMV);
94 It->second = std::unique_ptr<InsertElementInst>(
95 new InsertElementInst(LLVMIns, *this));
96 return It->second.get();
97 }
98 case llvm::Instruction::ShuffleVector: {
99 auto *LLVMIns = cast<llvm::ShuffleVectorInst>(Val: LLVMV);
100 It->second = std::unique_ptr<ShuffleVectorInst>(
101 new ShuffleVectorInst(LLVMIns, *this));
102 return It->second.get();
103 }
104 case llvm::Instruction::ExtractValue: {
105 auto *LLVMIns = cast<llvm::ExtractValueInst>(Val: LLVMV);
106 It->second = std::unique_ptr<ExtractValueInst>(
107 new ExtractValueInst(LLVMIns, *this));
108 return It->second.get();
109 }
110 case llvm::Instruction::InsertValue: {
111 auto *LLVMIns = cast<llvm::InsertValueInst>(Val: LLVMV);
112 It->second =
113 std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
114 return It->second.get();
115 }
116 case llvm::Instruction::Br: {
117 auto *LLVMBr = cast<llvm::BranchInst>(Val: LLVMV);
118 It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
119 return It->second.get();
120 }
121 case llvm::Instruction::Load: {
122 auto *LLVMLd = cast<llvm::LoadInst>(Val: LLVMV);
123 It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
124 return It->second.get();
125 }
126 case llvm::Instruction::Store: {
127 auto *LLVMSt = cast<llvm::StoreInst>(Val: LLVMV);
128 It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
129 return It->second.get();
130 }
131 case llvm::Instruction::Ret: {
132 auto *LLVMRet = cast<llvm::ReturnInst>(Val: LLVMV);
133 It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
134 return It->second.get();
135 }
136 case llvm::Instruction::Call: {
137 auto *LLVMCall = cast<llvm::CallInst>(Val: LLVMV);
138 It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
139 return It->second.get();
140 }
141 case llvm::Instruction::Invoke: {
142 auto *LLVMInvoke = cast<llvm::InvokeInst>(Val: LLVMV);
143 It->second =
144 std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
145 return It->second.get();
146 }
147 case llvm::Instruction::CallBr: {
148 auto *LLVMCallBr = cast<llvm::CallBrInst>(Val: LLVMV);
149 It->second =
150 std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
151 return It->second.get();
152 }
153 case llvm::Instruction::LandingPad: {
154 auto *LLVMLPad = cast<llvm::LandingPadInst>(Val: LLVMV);
155 It->second =
156 std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
157 return It->second.get();
158 }
159 case llvm::Instruction::CatchPad: {
160 auto *LLVMCPI = cast<llvm::CatchPadInst>(Val: LLVMV);
161 It->second =
162 std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
163 return It->second.get();
164 }
165 case llvm::Instruction::CleanupPad: {
166 auto *LLVMCPI = cast<llvm::CleanupPadInst>(Val: LLVMV);
167 It->second =
168 std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
169 return It->second.get();
170 }
171 case llvm::Instruction::CatchRet: {
172 auto *LLVMCRI = cast<llvm::CatchReturnInst>(Val: LLVMV);
173 It->second =
174 std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
175 return It->second.get();
176 }
177 case llvm::Instruction::CleanupRet: {
178 auto *LLVMCRI = cast<llvm::CleanupReturnInst>(Val: LLVMV);
179 It->second = std::unique_ptr<CleanupReturnInst>(
180 new CleanupReturnInst(LLVMCRI, *this));
181 return It->second.get();
182 }
183 case llvm::Instruction::GetElementPtr: {
184 auto *LLVMGEP = cast<llvm::GetElementPtrInst>(Val: LLVMV);
185 It->second = std::unique_ptr<GetElementPtrInst>(
186 new GetElementPtrInst(LLVMGEP, *this));
187 return It->second.get();
188 }
189 case llvm::Instruction::CatchSwitch: {
190 auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(Val: LLVMV);
191 It->second = std::unique_ptr<CatchSwitchInst>(
192 new CatchSwitchInst(LLVMCatchSwitchInst, *this));
193 return It->second.get();
194 }
195 case llvm::Instruction::Resume: {
196 auto *LLVMResumeInst = cast<llvm::ResumeInst>(Val: LLVMV);
197 It->second =
198 std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
199 return It->second.get();
200 }
201 case llvm::Instruction::Switch: {
202 auto *LLVMSwitchInst = cast<llvm::SwitchInst>(Val: LLVMV);
203 It->second =
204 std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
205 return It->second.get();
206 }
207 case llvm::Instruction::FNeg: {
208 auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(Val: LLVMV);
209 It->second = std::unique_ptr<UnaryOperator>(
210 new UnaryOperator(LLVMUnaryOperator, *this));
211 return It->second.get();
212 }
213 case llvm::Instruction::Add:
214 case llvm::Instruction::FAdd:
215 case llvm::Instruction::Sub:
216 case llvm::Instruction::FSub:
217 case llvm::Instruction::Mul:
218 case llvm::Instruction::FMul:
219 case llvm::Instruction::UDiv:
220 case llvm::Instruction::SDiv:
221 case llvm::Instruction::FDiv:
222 case llvm::Instruction::URem:
223 case llvm::Instruction::SRem:
224 case llvm::Instruction::FRem:
225 case llvm::Instruction::Shl:
226 case llvm::Instruction::LShr:
227 case llvm::Instruction::AShr:
228 case llvm::Instruction::And:
229 case llvm::Instruction::Or:
230 case llvm::Instruction::Xor: {
231 auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(Val: LLVMV);
232 It->second = std::unique_ptr<BinaryOperator>(
233 new BinaryOperator(LLVMBinaryOperator, *this));
234 return It->second.get();
235 }
236 case llvm::Instruction::AtomicRMW: {
237 auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(Val: LLVMV);
238 It->second = std::unique_ptr<AtomicRMWInst>(
239 new AtomicRMWInst(LLVMAtomicRMW, *this));
240 return It->second.get();
241 }
242 case llvm::Instruction::AtomicCmpXchg: {
243 auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(Val: LLVMV);
244 It->second = std::unique_ptr<AtomicCmpXchgInst>(
245 new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
246 return It->second.get();
247 }
248 case llvm::Instruction::Alloca: {
249 auto *LLVMAlloca = cast<llvm::AllocaInst>(Val: LLVMV);
250 It->second =
251 std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
252 return It->second.get();
253 }
254 case llvm::Instruction::ZExt:
255 case llvm::Instruction::SExt:
256 case llvm::Instruction::FPToUI:
257 case llvm::Instruction::FPToSI:
258 case llvm::Instruction::FPExt:
259 case llvm::Instruction::PtrToInt:
260 case llvm::Instruction::IntToPtr:
261 case llvm::Instruction::SIToFP:
262 case llvm::Instruction::UIToFP:
263 case llvm::Instruction::Trunc:
264 case llvm::Instruction::FPTrunc:
265 case llvm::Instruction::BitCast:
266 case llvm::Instruction::AddrSpaceCast: {
267 auto *LLVMCast = cast<llvm::CastInst>(Val: LLVMV);
268 It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
269 return It->second.get();
270 }
271 case llvm::Instruction::PHI: {
272 auto *LLVMPhi = cast<llvm::PHINode>(Val: LLVMV);
273 It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
274 return It->second.get();
275 }
276 case llvm::Instruction::ICmp: {
277 auto *LLVMICmp = cast<llvm::ICmpInst>(Val: LLVMV);
278 It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
279 return It->second.get();
280 }
281 case llvm::Instruction::FCmp: {
282 auto *LLVMFCmp = cast<llvm::FCmpInst>(Val: LLVMV);
283 It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
284 return It->second.get();
285 }
286 case llvm::Instruction::Unreachable: {
287 auto *LLVMUnreachable = cast<llvm::UnreachableInst>(Val: LLVMV);
288 It->second = std::unique_ptr<UnreachableInst>(
289 new UnreachableInst(LLVMUnreachable, *this));
290 return It->second.get();
291 }
292 default:
293 break;
294 }
295 It->second = std::unique_ptr<OpaqueInst>(
296 new OpaqueInst(cast<llvm::Instruction>(Val: LLVMV), *this));
297 return It->second.get();
298 }
299 // Constant
300 if (auto *LLVMC = dyn_cast<llvm::Constant>(Val: LLVMV)) {
301 switch (LLVMC->getValueID()) {
302 case llvm::Value::ConstantIntVal:
303 It->second = std::unique_ptr<ConstantInt>(
304 new ConstantInt(cast<llvm::ConstantInt>(Val: LLVMC), *this));
305 return It->second.get();
306 case llvm::Value::ConstantFPVal:
307 It->second = std::unique_ptr<ConstantFP>(
308 new ConstantFP(cast<llvm::ConstantFP>(Val: LLVMC), *this));
309 return It->second.get();
310 case llvm::Value::BlockAddressVal:
311 It->second = std::unique_ptr<BlockAddress>(
312 new BlockAddress(cast<llvm::BlockAddress>(Val: LLVMC), *this));
313 return It->second.get();
314 case llvm::Value::ConstantTokenNoneVal:
315 It->second = std::unique_ptr<ConstantTokenNone>(
316 new ConstantTokenNone(cast<llvm::ConstantTokenNone>(Val: LLVMC), *this));
317 return It->second.get();
318 case llvm::Value::ConstantAggregateZeroVal: {
319 auto *CAZ = cast<llvm::ConstantAggregateZero>(Val: LLVMC);
320 It->second = std::unique_ptr<ConstantAggregateZero>(
321 new ConstantAggregateZero(CAZ, *this));
322 auto *Ret = It->second.get();
323 // Must create sandboxir for elements.
324 auto EC = CAZ->getElementCount();
325 if (EC.isFixed()) {
326 for (auto ElmIdx : seq<unsigned>(Begin: 0, End: EC.getFixedValue()))
327 getOrCreateValueInternal(LLVMV: CAZ->getElementValue(Idx: ElmIdx), U: CAZ);
328 }
329 return Ret;
330 }
331 case llvm::Value::ConstantPointerNullVal:
332 It->second = std::unique_ptr<ConstantPointerNull>(new ConstantPointerNull(
333 cast<llvm::ConstantPointerNull>(Val: LLVMC), *this));
334 return It->second.get();
335 case llvm::Value::PoisonValueVal:
336 It->second = std::unique_ptr<PoisonValue>(
337 new PoisonValue(cast<llvm::PoisonValue>(Val: LLVMC), *this));
338 return It->second.get();
339 case llvm::Value::UndefValueVal:
340 It->second = std::unique_ptr<UndefValue>(
341 new UndefValue(cast<llvm::UndefValue>(Val: LLVMC), *this));
342 return It->second.get();
343 case llvm::Value::DSOLocalEquivalentVal: {
344 auto *DSOLE = cast<llvm::DSOLocalEquivalent>(Val: LLVMC);
345 It->second = std::unique_ptr<DSOLocalEquivalent>(
346 new DSOLocalEquivalent(DSOLE, *this));
347 auto *Ret = It->second.get();
348 getOrCreateValueInternal(LLVMV: DSOLE->getGlobalValue(), U: DSOLE);
349 return Ret;
350 }
351 case llvm::Value::ConstantArrayVal:
352 It->second = std::unique_ptr<ConstantArray>(
353 new ConstantArray(cast<llvm::ConstantArray>(Val: LLVMC), *this));
354 break;
355 case llvm::Value::ConstantStructVal:
356 It->second = std::unique_ptr<ConstantStruct>(
357 new ConstantStruct(cast<llvm::ConstantStruct>(Val: LLVMC), *this));
358 break;
359 case llvm::Value::ConstantVectorVal:
360 It->second = std::unique_ptr<ConstantVector>(
361 new ConstantVector(cast<llvm::ConstantVector>(Val: LLVMC), *this));
362 break;
363 case llvm::Value::ConstantDataArrayVal:
364 It->second = std::unique_ptr<ConstantDataArray>(
365 new ConstantDataArray(cast<llvm::ConstantDataArray>(Val: LLVMC), *this));
366 break;
367 case llvm::Value::ConstantDataVectorVal:
368 It->second = std::unique_ptr<ConstantDataVector>(
369 new ConstantDataVector(cast<llvm::ConstantDataVector>(Val: LLVMC), *this));
370 break;
371 case llvm::Value::FunctionVal:
372 It->second = std::unique_ptr<Function>(
373 new Function(cast<llvm::Function>(Val: LLVMC), *this));
374 break;
375 case llvm::Value::GlobalIFuncVal:
376 It->second = std::unique_ptr<GlobalIFunc>(
377 new GlobalIFunc(cast<llvm::GlobalIFunc>(Val: LLVMC), *this));
378 break;
379 case llvm::Value::GlobalVariableVal:
380 It->second = std::unique_ptr<GlobalVariable>(
381 new GlobalVariable(cast<llvm::GlobalVariable>(Val: LLVMC), *this));
382 break;
383 case llvm::Value::GlobalAliasVal:
384 It->second = std::unique_ptr<GlobalAlias>(
385 new GlobalAlias(cast<llvm::GlobalAlias>(Val: LLVMC), *this));
386 break;
387 case llvm::Value::NoCFIValueVal:
388 It->second = std::unique_ptr<NoCFIValue>(
389 new NoCFIValue(cast<llvm::NoCFIValue>(Val: LLVMC), *this));
390 break;
391 case llvm::Value::ConstantPtrAuthVal:
392 It->second = std::unique_ptr<ConstantPtrAuth>(
393 new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(Val: LLVMC), *this));
394 break;
395 case llvm::Value::ConstantExprVal:
396 It->second = std::unique_ptr<ConstantExpr>(
397 new ConstantExpr(cast<llvm::ConstantExpr>(Val: LLVMC), *this));
398 break;
399 default:
400 It->second = std::unique_ptr<Constant>(new Constant(LLVMC, *this));
401 break;
402 }
403 auto *NewC = It->second.get();
404 for (llvm::Value *COp : LLVMC->operands())
405 getOrCreateValueInternal(LLVMV: COp, U: LLVMC);
406 return NewC;
407 }
408 // Argument
409 if (auto *LLVMArg = dyn_cast<llvm::Argument>(Val: LLVMV)) {
410 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
411 return It->second.get();
412 }
413 // BasicBlock
414 if (auto *LLVMBB = dyn_cast<llvm::BasicBlock>(Val: LLVMV)) {
415 assert(isa<llvm::BlockAddress>(U) &&
416 "This won't create a SBBB, don't call this function directly!");
417 if (auto *SBBB = getValue(V: LLVMBB))
418 return SBBB;
419 return nullptr;
420 }
421 // Metadata
422 if (auto *LLVMMD = dyn_cast<llvm::MetadataAsValue>(Val: LLVMV)) {
423 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMMD, *this));
424 return It->second.get();
425 }
426 // InlineAsm
427 if (auto *LLVMAsm = dyn_cast<llvm::InlineAsm>(Val: LLVMV)) {
428 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMAsm, *this));
429 return It->second.get();
430 }
431 llvm_unreachable("Unhandled LLVMV type!");
432}
433
434Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
435 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMArg);
436 auto It = Pair.first;
437 if (Pair.second) {
438 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
439 return cast<Argument>(Val: It->second.get());
440 }
441 return cast<Argument>(Val: It->second.get());
442}
443
444Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
445 return cast<Constant>(Val: getOrCreateValueInternal(LLVMV: LLVMC, U: 0));
446}
447
448BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
449 assert(getValue(LLVMBB) == nullptr && "Already exists!");
450 auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
451 auto *BB = cast<BasicBlock>(Val: registerValue(VPtr: std::move(NewBBPtr)));
452 // Create SandboxIR for BB's body.
453 BB->buildBasicBlockFromLLVMIR(LLVMBB);
454 return BB;
455}
456
457VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
458 auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
459 return cast<VAArgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
460}
461
462FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
463 auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
464 return cast<FreezeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
465}
466
467FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
468 auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
469 return cast<FenceInst>(Val: registerValue(VPtr: std::move(NewPtr)));
470}
471
472SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
473 auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
474 return cast<SelectInst>(Val: registerValue(VPtr: std::move(NewPtr)));
475}
476
477ExtractElementInst *
478Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
479 auto NewPtr =
480 std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
481 return cast<ExtractElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
482}
483
484InsertElementInst *
485Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
486 auto NewPtr =
487 std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
488 return cast<InsertElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
489}
490
491ShuffleVectorInst *
492Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
493 auto NewPtr =
494 std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
495 return cast<ShuffleVectorInst>(Val: registerValue(VPtr: std::move(NewPtr)));
496}
497
498ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
499 auto NewPtr =
500 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
501 return cast<ExtractValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
502}
503
504InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
505 auto NewPtr =
506 std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
507 return cast<InsertValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
508}
509
510BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
511 auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
512 return cast<BranchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
513}
514
515LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
516 auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
517 return cast<LoadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
518}
519
520StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
521 auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
522 return cast<StoreInst>(Val: registerValue(VPtr: std::move(NewPtr)));
523}
524
525ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
526 auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
527 return cast<ReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
528}
529
530CallInst *Context::createCallInst(llvm::CallInst *I) {
531 auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
532 return cast<CallInst>(Val: registerValue(VPtr: std::move(NewPtr)));
533}
534
535InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
536 auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
537 return cast<InvokeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
538}
539
540CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
541 auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
542 return cast<CallBrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
543}
544
545UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
546 auto NewPtr =
547 std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
548 return cast<UnreachableInst>(Val: registerValue(VPtr: std::move(NewPtr)));
549}
550LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
551 auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
552 return cast<LandingPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
553}
554CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
555 auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
556 return cast<CatchPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
557}
558CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
559 auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
560 return cast<CleanupPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
561}
562CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
563 auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
564 return cast<CatchReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
565}
566CleanupReturnInst *
567Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
568 auto NewPtr =
569 std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
570 return cast<CleanupReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
571}
572GetElementPtrInst *
573Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
574 auto NewPtr =
575 std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
576 return cast<GetElementPtrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
577}
578CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
579 auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
580 return cast<CatchSwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
581}
582ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
583 auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
584 return cast<ResumeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
585}
586SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
587 auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
588 return cast<SwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
589}
590UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
591 auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
592 return cast<UnaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
593}
594BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
595 auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
596 return cast<BinaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
597}
598AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
599 auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
600 return cast<AtomicRMWInst>(Val: registerValue(VPtr: std::move(NewPtr)));
601}
602AtomicCmpXchgInst *
603Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
604 auto NewPtr =
605 std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
606 return cast<AtomicCmpXchgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
607}
608AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
609 auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
610 return cast<AllocaInst>(Val: registerValue(VPtr: std::move(NewPtr)));
611}
612CastInst *Context::createCastInst(llvm::CastInst *I) {
613 auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
614 return cast<CastInst>(Val: registerValue(VPtr: std::move(NewPtr)));
615}
616PHINode *Context::createPHINode(llvm::PHINode *I) {
617 auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
618 return cast<PHINode>(Val: registerValue(VPtr: std::move(NewPtr)));
619}
620ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
621 auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
622 return cast<ICmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
623}
624FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
625 auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
626 return cast<FCmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
627}
628Value *Context::getValue(llvm::Value *V) const {
629 auto It = LLVMValueToValueMap.find(Val: V);
630 if (It != LLVMValueToValueMap.end())
631 return It->second.get();
632 return nullptr;
633}
634
635Context::Context(LLVMContext &LLVMCtx)
636 : LLVMCtx(LLVMCtx), IRTracker(*this),
637 LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
638
639Context::~Context() {}
640
641void Context::clear() {
642 // TODO: Ideally we should clear only function-scope objects, and keep global
643 // objects, like Constants to avoid recreating them.
644 LLVMValueToValueMap.clear();
645}
646
647Module *Context::getModule(llvm::Module *LLVMM) const {
648 auto It = LLVMModuleToModuleMap.find(Val: LLVMM);
649 if (It != LLVMModuleToModuleMap.end())
650 return It->second.get();
651 return nullptr;
652}
653
654Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
655 auto Pair = LLVMModuleToModuleMap.try_emplace(Key: LLVMM);
656 auto It = Pair.first;
657 if (!Pair.second)
658 return It->second.get();
659 It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
660 return It->second.get();
661}
662
663Function *Context::createFunction(llvm::Function *F) {
664 // Create the module if needed before we create the new sandboxir::Function.
665 // Note: this won't fully populate the module. The only globals that will be
666 // available will be the ones being used within the function.
667 getOrCreateModule(LLVMM: F->getParent());
668
669 // There may be a function declaration already defined. Regardless destroy it.
670 if (Function *ExistingF = cast_or_null<Function>(Val: getValue(V: F)))
671 detach(V: ExistingF);
672
673 auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
674 auto *SBF = cast<Function>(Val: registerValue(VPtr: std::move(NewFPtr)));
675 // Create arguments.
676 for (auto &Arg : F->args())
677 getOrCreateArgument(LLVMArg: &Arg);
678 // Create BBs.
679 for (auto &BB : *F)
680 createBasicBlock(LLVMBB: &BB);
681 return SBF;
682}
683
684Module *Context::createModule(llvm::Module *LLVMM) {
685 auto *M = getOrCreateModule(LLVMM);
686 // Create the functions.
687 for (auto &LLVMF : *LLVMM)
688 createFunction(F: &LLVMF);
689 // Create globals.
690 for (auto &Global : LLVMM->globals())
691 getOrCreateValue(LLVMV: &Global);
692 // Create aliases.
693 for (auto &Alias : LLVMM->aliases())
694 getOrCreateValue(LLVMV: &Alias);
695 // Create ifuncs.
696 for (auto &IFunc : LLVMM->ifuncs())
697 getOrCreateValue(LLVMV: &IFunc);
698
699 return M;
700}
701
702void Context::runEraseInstrCallbacks(Instruction *I) {
703 for (const auto &CBEntry : EraseInstrCallbacks)
704 CBEntry.second(I);
705}
706
707void Context::runCreateInstrCallbacks(Instruction *I) {
708 for (auto &CBEntry : CreateInstrCallbacks)
709 CBEntry.second(I);
710}
711
712void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
713 for (auto &CBEntry : MoveInstrCallbacks)
714 CBEntry.second(I, WhereIt);
715}
716
717void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
718 for (auto &CBEntry : SetUseCallbacks)
719 CBEntry.second(U, NewSrc);
720}
721
// Upper bound on the number of callbacks of each kind, checked only by the
// assertions in the register*Callback() functions to catch accidental misuse.
// The value is arbitrary: only a handful of callbacks are expected to be
// registered at once, and it can be raised if more turn out to be needed.
static constexpr int MaxRegisteredCallbacks [[maybe_unused]] = 16;
726
727Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
728 assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
729 "EraseInstrCallbacks size limit exceeded");
730 CallbackID ID{NextCallbackID++};
731 EraseInstrCallbacks[ID] = CB;
732 return ID;
733}
734void Context::unregisterEraseInstrCallback(CallbackID ID) {
735 [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(Key: ID);
736 assert(Erased &&
737 "Callback ID not found in EraseInstrCallbacks during deregistration");
738}
739
740Context::CallbackID
741Context::registerCreateInstrCallback(CreateInstrCallback CB) {
742 assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
743 "CreateInstrCallbacks size limit exceeded");
744 CallbackID ID{NextCallbackID++};
745 CreateInstrCallbacks[ID] = CB;
746 return ID;
747}
748void Context::unregisterCreateInstrCallback(CallbackID ID) {
749 [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(Key: ID);
750 assert(Erased &&
751 "Callback ID not found in CreateInstrCallbacks during deregistration");
752}
753
754Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
755 assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
756 "MoveInstrCallbacks size limit exceeded");
757 CallbackID ID{NextCallbackID++};
758 MoveInstrCallbacks[ID] = CB;
759 return ID;
760}
761void Context::unregisterMoveInstrCallback(CallbackID ID) {
762 [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(Key: ID);
763 assert(Erased &&
764 "Callback ID not found in MoveInstrCallbacks during deregistration");
765}
766
767Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
768 assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
769 "SetUseCallbacks size limit exceeded");
770 CallbackID ID{NextCallbackID++};
771 SetUseCallbacks[ID] = CB;
772 return ID;
773}
774void Context::unregisterSetUseCallback(CallbackID ID) {
775 [[maybe_unused]] bool Erased = SetUseCallbacks.erase(Key: ID);
776 assert(Erased &&
777 "Callback ID not found in SetUseCallbacks during deregistration");
778}
779
780} // namespace llvm::sandboxir
781