1//===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/SandboxIR/Context.h"
10#include "llvm/IR/InlineAsm.h"
11#include "llvm/SandboxIR/Function.h"
12#include "llvm/SandboxIR/Instruction.h"
13#include "llvm/SandboxIR/Module.h"
14
15namespace llvm::sandboxir {
16
17std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
18 std::unique_ptr<Value> Erased;
19 auto It = LLVMValueToValueMap.find(Val: V);
20 if (It != LLVMValueToValueMap.end()) {
21 auto *Val = It->second.release();
22 Erased = std::unique_ptr<Value>(Val);
23 LLVMValueToValueMap.erase(I: It);
24 }
25 return Erased;
26}
27
28std::unique_ptr<Value> Context::detach(Value *V) {
29 assert(V->getSubclassID() != Value::ClassID::Constant &&
30 "Can't detach a constant!");
31 assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
32 return detachLLVMValue(V: V->Val);
33}
34
35Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
36 assert(VPtr->getSubclassID() != Value::ClassID::User &&
37 "Can't register a user!");
38
39 Value *V = VPtr.get();
40 [[maybe_unused]] auto Pair =
41 LLVMValueToValueMap.insert(KV: {VPtr->Val, std::move(VPtr)});
42 assert(Pair.second && "Already exists!");
43
44 // Track creation of instructions.
45 // Please note that we don't allow the creation of detached instructions,
46 // meaning that the instructions need to be inserted into a block upon
47 // creation. This is why the tracker class combines creation and insertion.
48 if (auto *I = dyn_cast<Instruction>(Val: V)) {
49 getTracker().emplaceIfTracking<CreateAndInsertInst>(Args: I);
50 runCreateInstrCallbacks(I);
51 }
52
53 return V;
54}
55
56Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
57 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMV);
58 auto It = Pair.first;
59 if (!Pair.second)
60 return It->second.get();
61
62 // Instruction
63 if (auto *LLVMI = dyn_cast<llvm::Instruction>(Val: LLVMV)) {
64 switch (LLVMI->getOpcode()) {
65 case llvm::Instruction::VAArg: {
66 auto *LLVMVAArg = cast<llvm::VAArgInst>(Val: LLVMV);
67 It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
68 return It->second.get();
69 }
70 case llvm::Instruction::Freeze: {
71 auto *LLVMFreeze = cast<llvm::FreezeInst>(Val: LLVMV);
72 It->second =
73 std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
74 return It->second.get();
75 }
76 case llvm::Instruction::Fence: {
77 auto *LLVMFence = cast<llvm::FenceInst>(Val: LLVMV);
78 It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
79 return It->second.get();
80 }
81 case llvm::Instruction::Select: {
82 auto *LLVMSel = cast<llvm::SelectInst>(Val: LLVMV);
83 It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
84 return It->second.get();
85 }
86 case llvm::Instruction::ExtractElement: {
87 auto *LLVMIns = cast<llvm::ExtractElementInst>(Val: LLVMV);
88 It->second = std::unique_ptr<ExtractElementInst>(
89 new ExtractElementInst(LLVMIns, *this));
90 return It->second.get();
91 }
92 case llvm::Instruction::InsertElement: {
93 auto *LLVMIns = cast<llvm::InsertElementInst>(Val: LLVMV);
94 It->second = std::unique_ptr<InsertElementInst>(
95 new InsertElementInst(LLVMIns, *this));
96 return It->second.get();
97 }
98 case llvm::Instruction::ShuffleVector: {
99 auto *LLVMIns = cast<llvm::ShuffleVectorInst>(Val: LLVMV);
100 It->second = std::unique_ptr<ShuffleVectorInst>(
101 new ShuffleVectorInst(LLVMIns, *this));
102 return It->second.get();
103 }
104 case llvm::Instruction::ExtractValue: {
105 auto *LLVMIns = cast<llvm::ExtractValueInst>(Val: LLVMV);
106 It->second = std::unique_ptr<ExtractValueInst>(
107 new ExtractValueInst(LLVMIns, *this));
108 return It->second.get();
109 }
110 case llvm::Instruction::InsertValue: {
111 auto *LLVMIns = cast<llvm::InsertValueInst>(Val: LLVMV);
112 It->second =
113 std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
114 return It->second.get();
115 }
116 case llvm::Instruction::UncondBr:
117 case llvm::Instruction::CondBr: {
118 auto *LLVMBr = cast<llvm::BranchInst>(Val: LLVMV);
119 It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
120 return It->second.get();
121 }
122 case llvm::Instruction::Load: {
123 auto *LLVMLd = cast<llvm::LoadInst>(Val: LLVMV);
124 It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
125 return It->second.get();
126 }
127 case llvm::Instruction::Store: {
128 auto *LLVMSt = cast<llvm::StoreInst>(Val: LLVMV);
129 It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
130 return It->second.get();
131 }
132 case llvm::Instruction::Ret: {
133 auto *LLVMRet = cast<llvm::ReturnInst>(Val: LLVMV);
134 It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
135 return It->second.get();
136 }
137 case llvm::Instruction::Call: {
138 auto *LLVMCall = cast<llvm::CallInst>(Val: LLVMV);
139 It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
140 return It->second.get();
141 }
142 case llvm::Instruction::Invoke: {
143 auto *LLVMInvoke = cast<llvm::InvokeInst>(Val: LLVMV);
144 It->second =
145 std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
146 return It->second.get();
147 }
148 case llvm::Instruction::CallBr: {
149 auto *LLVMCallBr = cast<llvm::CallBrInst>(Val: LLVMV);
150 It->second =
151 std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
152 return It->second.get();
153 }
154 case llvm::Instruction::LandingPad: {
155 auto *LLVMLPad = cast<llvm::LandingPadInst>(Val: LLVMV);
156 It->second =
157 std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
158 return It->second.get();
159 }
160 case llvm::Instruction::CatchPad: {
161 auto *LLVMCPI = cast<llvm::CatchPadInst>(Val: LLVMV);
162 It->second =
163 std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
164 return It->second.get();
165 }
166 case llvm::Instruction::CleanupPad: {
167 auto *LLVMCPI = cast<llvm::CleanupPadInst>(Val: LLVMV);
168 It->second =
169 std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
170 return It->second.get();
171 }
172 case llvm::Instruction::CatchRet: {
173 auto *LLVMCRI = cast<llvm::CatchReturnInst>(Val: LLVMV);
174 It->second =
175 std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
176 return It->second.get();
177 }
178 case llvm::Instruction::CleanupRet: {
179 auto *LLVMCRI = cast<llvm::CleanupReturnInst>(Val: LLVMV);
180 It->second = std::unique_ptr<CleanupReturnInst>(
181 new CleanupReturnInst(LLVMCRI, *this));
182 return It->second.get();
183 }
184 case llvm::Instruction::GetElementPtr: {
185 auto *LLVMGEP = cast<llvm::GetElementPtrInst>(Val: LLVMV);
186 It->second = std::unique_ptr<GetElementPtrInst>(
187 new GetElementPtrInst(LLVMGEP, *this));
188 return It->second.get();
189 }
190 case llvm::Instruction::CatchSwitch: {
191 auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(Val: LLVMV);
192 It->second = std::unique_ptr<CatchSwitchInst>(
193 new CatchSwitchInst(LLVMCatchSwitchInst, *this));
194 return It->second.get();
195 }
196 case llvm::Instruction::Resume: {
197 auto *LLVMResumeInst = cast<llvm::ResumeInst>(Val: LLVMV);
198 It->second =
199 std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
200 return It->second.get();
201 }
202 case llvm::Instruction::Switch: {
203 auto *LLVMSwitchInst = cast<llvm::SwitchInst>(Val: LLVMV);
204 It->second =
205 std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
206 return It->second.get();
207 }
208 case llvm::Instruction::FNeg: {
209 auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(Val: LLVMV);
210 It->second = std::unique_ptr<UnaryOperator>(
211 new UnaryOperator(LLVMUnaryOperator, *this));
212 return It->second.get();
213 }
214 case llvm::Instruction::Add:
215 case llvm::Instruction::FAdd:
216 case llvm::Instruction::Sub:
217 case llvm::Instruction::FSub:
218 case llvm::Instruction::Mul:
219 case llvm::Instruction::FMul:
220 case llvm::Instruction::UDiv:
221 case llvm::Instruction::SDiv:
222 case llvm::Instruction::FDiv:
223 case llvm::Instruction::URem:
224 case llvm::Instruction::SRem:
225 case llvm::Instruction::FRem:
226 case llvm::Instruction::Shl:
227 case llvm::Instruction::LShr:
228 case llvm::Instruction::AShr:
229 case llvm::Instruction::And:
230 case llvm::Instruction::Or:
231 case llvm::Instruction::Xor: {
232 auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(Val: LLVMV);
233 It->second = std::unique_ptr<BinaryOperator>(
234 new BinaryOperator(LLVMBinaryOperator, *this));
235 return It->second.get();
236 }
237 case llvm::Instruction::AtomicRMW: {
238 auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(Val: LLVMV);
239 It->second = std::unique_ptr<AtomicRMWInst>(
240 new AtomicRMWInst(LLVMAtomicRMW, *this));
241 return It->second.get();
242 }
243 case llvm::Instruction::AtomicCmpXchg: {
244 auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(Val: LLVMV);
245 It->second = std::unique_ptr<AtomicCmpXchgInst>(
246 new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
247 return It->second.get();
248 }
249 case llvm::Instruction::Alloca: {
250 auto *LLVMAlloca = cast<llvm::AllocaInst>(Val: LLVMV);
251 It->second =
252 std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
253 return It->second.get();
254 }
255 case llvm::Instruction::ZExt:
256 case llvm::Instruction::SExt:
257 case llvm::Instruction::FPToUI:
258 case llvm::Instruction::FPToSI:
259 case llvm::Instruction::FPExt:
260 case llvm::Instruction::PtrToAddr:
261 case llvm::Instruction::PtrToInt:
262 case llvm::Instruction::IntToPtr:
263 case llvm::Instruction::SIToFP:
264 case llvm::Instruction::UIToFP:
265 case llvm::Instruction::Trunc:
266 case llvm::Instruction::FPTrunc:
267 case llvm::Instruction::BitCast:
268 case llvm::Instruction::AddrSpaceCast: {
269 auto *LLVMCast = cast<llvm::CastInst>(Val: LLVMV);
270 It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
271 return It->second.get();
272 }
273 case llvm::Instruction::PHI: {
274 auto *LLVMPhi = cast<llvm::PHINode>(Val: LLVMV);
275 It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
276 return It->second.get();
277 }
278 case llvm::Instruction::ICmp: {
279 auto *LLVMICmp = cast<llvm::ICmpInst>(Val: LLVMV);
280 It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
281 return It->second.get();
282 }
283 case llvm::Instruction::FCmp: {
284 auto *LLVMFCmp = cast<llvm::FCmpInst>(Val: LLVMV);
285 It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
286 return It->second.get();
287 }
288 case llvm::Instruction::Unreachable: {
289 auto *LLVMUnreachable = cast<llvm::UnreachableInst>(Val: LLVMV);
290 It->second = std::unique_ptr<UnreachableInst>(
291 new UnreachableInst(LLVMUnreachable, *this));
292 return It->second.get();
293 }
294 default:
295 break;
296 }
297 It->second = std::unique_ptr<OpaqueInst>(
298 new OpaqueInst(cast<llvm::Instruction>(Val: LLVMV), *this));
299 return It->second.get();
300 }
301 // Constant
302 if (auto *LLVMC = dyn_cast<llvm::Constant>(Val: LLVMV)) {
303 switch (LLVMC->getValueID()) {
304 case llvm::Value::ConstantIntVal:
305 It->second = std::unique_ptr<ConstantInt>(
306 new ConstantInt(cast<llvm::ConstantInt>(Val: LLVMC), *this));
307 return It->second.get();
308 case llvm::Value::ConstantFPVal:
309 It->second = std::unique_ptr<ConstantFP>(
310 new ConstantFP(cast<llvm::ConstantFP>(Val: LLVMC), *this));
311 return It->second.get();
312 case llvm::Value::BlockAddressVal:
313 It->second = std::unique_ptr<BlockAddress>(
314 new BlockAddress(cast<llvm::BlockAddress>(Val: LLVMC), *this));
315 return It->second.get();
316 case llvm::Value::ConstantTokenNoneVal:
317 It->second = std::unique_ptr<ConstantTokenNone>(
318 new ConstantTokenNone(cast<llvm::ConstantTokenNone>(Val: LLVMC), *this));
319 return It->second.get();
320 case llvm::Value::ConstantAggregateZeroVal: {
321 auto *CAZ = cast<llvm::ConstantAggregateZero>(Val: LLVMC);
322 It->second = std::unique_ptr<ConstantAggregateZero>(
323 new ConstantAggregateZero(CAZ, *this));
324 auto *Ret = It->second.get();
325 // Must create sandboxir for elements.
326 auto EC = CAZ->getElementCount();
327 if (EC.isFixed()) {
328 for (auto ElmIdx : seq<unsigned>(Begin: 0, End: EC.getFixedValue()))
329 getOrCreateValueInternal(LLVMV: CAZ->getElementValue(Idx: ElmIdx), U: CAZ);
330 }
331 return Ret;
332 }
333 case llvm::Value::ConstantPointerNullVal:
334 It->second = std::unique_ptr<ConstantPointerNull>(new ConstantPointerNull(
335 cast<llvm::ConstantPointerNull>(Val: LLVMC), *this));
336 return It->second.get();
337 case llvm::Value::PoisonValueVal:
338 It->second = std::unique_ptr<PoisonValue>(
339 new PoisonValue(cast<llvm::PoisonValue>(Val: LLVMC), *this));
340 return It->second.get();
341 case llvm::Value::UndefValueVal:
342 It->second = std::unique_ptr<UndefValue>(
343 new UndefValue(cast<llvm::UndefValue>(Val: LLVMC), *this));
344 return It->second.get();
345 case llvm::Value::DSOLocalEquivalentVal: {
346 auto *DSOLE = cast<llvm::DSOLocalEquivalent>(Val: LLVMC);
347 It->second = std::unique_ptr<DSOLocalEquivalent>(
348 new DSOLocalEquivalent(DSOLE, *this));
349 auto *Ret = It->second.get();
350 getOrCreateValueInternal(LLVMV: DSOLE->getGlobalValue(), U: DSOLE);
351 return Ret;
352 }
353 case llvm::Value::ConstantArrayVal:
354 It->second = std::unique_ptr<ConstantArray>(
355 new ConstantArray(cast<llvm::ConstantArray>(Val: LLVMC), *this));
356 break;
357 case llvm::Value::ConstantStructVal:
358 It->second = std::unique_ptr<ConstantStruct>(
359 new ConstantStruct(cast<llvm::ConstantStruct>(Val: LLVMC), *this));
360 break;
361 case llvm::Value::ConstantVectorVal:
362 It->second = std::unique_ptr<ConstantVector>(
363 new ConstantVector(cast<llvm::ConstantVector>(Val: LLVMC), *this));
364 break;
365 case llvm::Value::ConstantDataArrayVal:
366 It->second = std::unique_ptr<ConstantDataArray>(
367 new ConstantDataArray(cast<llvm::ConstantDataArray>(Val: LLVMC), *this));
368 break;
369 case llvm::Value::ConstantDataVectorVal:
370 It->second = std::unique_ptr<ConstantDataVector>(
371 new ConstantDataVector(cast<llvm::ConstantDataVector>(Val: LLVMC), *this));
372 break;
373 case llvm::Value::FunctionVal:
374 It->second = std::unique_ptr<Function>(
375 new Function(cast<llvm::Function>(Val: LLVMC), *this));
376 break;
377 case llvm::Value::GlobalIFuncVal:
378 It->second = std::unique_ptr<GlobalIFunc>(
379 new GlobalIFunc(cast<llvm::GlobalIFunc>(Val: LLVMC), *this));
380 break;
381 case llvm::Value::GlobalVariableVal:
382 It->second = std::unique_ptr<GlobalVariable>(
383 new GlobalVariable(cast<llvm::GlobalVariable>(Val: LLVMC), *this));
384 break;
385 case llvm::Value::GlobalAliasVal:
386 It->second = std::unique_ptr<GlobalAlias>(
387 new GlobalAlias(cast<llvm::GlobalAlias>(Val: LLVMC), *this));
388 break;
389 case llvm::Value::NoCFIValueVal:
390 It->second = std::unique_ptr<NoCFIValue>(
391 new NoCFIValue(cast<llvm::NoCFIValue>(Val: LLVMC), *this));
392 break;
393 case llvm::Value::ConstantPtrAuthVal:
394 It->second = std::unique_ptr<ConstantPtrAuth>(
395 new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(Val: LLVMC), *this));
396 break;
397 case llvm::Value::ConstantExprVal:
398 It->second = std::unique_ptr<ConstantExpr>(
399 new ConstantExpr(cast<llvm::ConstantExpr>(Val: LLVMC), *this));
400 break;
401 default:
402 It->second = std::unique_ptr<Constant>(new Constant(LLVMC, *this));
403 break;
404 }
405 auto *NewC = It->second.get();
406 for (llvm::Value *COp : LLVMC->operands())
407 getOrCreateValueInternal(LLVMV: COp, U: LLVMC);
408 return NewC;
409 }
410 // Argument
411 if (auto *LLVMArg = dyn_cast<llvm::Argument>(Val: LLVMV)) {
412 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
413 return It->second.get();
414 }
415 // BasicBlock
416 if (auto *LLVMBB = dyn_cast<llvm::BasicBlock>(Val: LLVMV)) {
417 assert(isa<llvm::BlockAddress>(U) &&
418 "This won't create a SBBB, don't call this function directly!");
419 if (auto *SBBB = getValue(V: LLVMBB))
420 return SBBB;
421 return nullptr;
422 }
423 // Metadata
424 if (auto *LLVMMD = dyn_cast<llvm::MetadataAsValue>(Val: LLVMV)) {
425 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMMD, *this));
426 return It->second.get();
427 }
428 // InlineAsm
429 if (auto *LLVMAsm = dyn_cast<llvm::InlineAsm>(Val: LLVMV)) {
430 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMAsm, *this));
431 return It->second.get();
432 }
433 llvm_unreachable("Unhandled LLVMV type!");
434}
435
436Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
437 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMArg);
438 auto It = Pair.first;
439 if (Pair.second) {
440 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
441 return cast<Argument>(Val: It->second.get());
442 }
443 return cast<Argument>(Val: It->second.get());
444}
445
446Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
447 return cast<Constant>(Val: getOrCreateValueInternal(LLVMV: LLVMC, U: nullptr));
448}
449
450BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
451 assert(getValue(LLVMBB) == nullptr && "Already exists!");
452 auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
453 auto *BB = cast<BasicBlock>(Val: registerValue(VPtr: std::move(NewBBPtr)));
454 // Create SandboxIR for BB's body.
455 BB->buildBasicBlockFromLLVMIR(LLVMBB);
456 return BB;
457}
458
459VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
460 auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
461 return cast<VAArgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
462}
463
464FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
465 auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
466 return cast<FreezeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
467}
468
469FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
470 auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
471 return cast<FenceInst>(Val: registerValue(VPtr: std::move(NewPtr)));
472}
473
474SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
475 auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
476 return cast<SelectInst>(Val: registerValue(VPtr: std::move(NewPtr)));
477}
478
479ExtractElementInst *
480Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
481 auto NewPtr =
482 std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
483 return cast<ExtractElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
484}
485
486InsertElementInst *
487Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
488 auto NewPtr =
489 std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
490 return cast<InsertElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
491}
492
493ShuffleVectorInst *
494Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
495 auto NewPtr =
496 std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
497 return cast<ShuffleVectorInst>(Val: registerValue(VPtr: std::move(NewPtr)));
498}
499
500ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
501 auto NewPtr =
502 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
503 return cast<ExtractValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
504}
505
506InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
507 auto NewPtr =
508 std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
509 return cast<InsertValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
510}
511
512BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
513 auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
514 return cast<BranchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
515}
516
517LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
518 auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
519 return cast<LoadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
520}
521
522StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
523 auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
524 return cast<StoreInst>(Val: registerValue(VPtr: std::move(NewPtr)));
525}
526
527ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
528 auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
529 return cast<ReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
530}
531
532CallInst *Context::createCallInst(llvm::CallInst *I) {
533 auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
534 return cast<CallInst>(Val: registerValue(VPtr: std::move(NewPtr)));
535}
536
537InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
538 auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
539 return cast<InvokeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
540}
541
542CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
543 auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
544 return cast<CallBrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
545}
546
547UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
548 auto NewPtr =
549 std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
550 return cast<UnreachableInst>(Val: registerValue(VPtr: std::move(NewPtr)));
551}
552LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
553 auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
554 return cast<LandingPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
555}
556CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
557 auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
558 return cast<CatchPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
559}
560CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
561 auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
562 return cast<CleanupPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
563}
564CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
565 auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
566 return cast<CatchReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
567}
568CleanupReturnInst *
569Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
570 auto NewPtr =
571 std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
572 return cast<CleanupReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
573}
574GetElementPtrInst *
575Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
576 auto NewPtr =
577 std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
578 return cast<GetElementPtrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
579}
580CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
581 auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
582 return cast<CatchSwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
583}
584ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
585 auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
586 return cast<ResumeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
587}
588SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
589 auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
590 return cast<SwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
591}
592UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
593 auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
594 return cast<UnaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
595}
596BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
597 auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
598 return cast<BinaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
599}
600AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
601 auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
602 return cast<AtomicRMWInst>(Val: registerValue(VPtr: std::move(NewPtr)));
603}
604AtomicCmpXchgInst *
605Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
606 auto NewPtr =
607 std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
608 return cast<AtomicCmpXchgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
609}
610AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
611 auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
612 return cast<AllocaInst>(Val: registerValue(VPtr: std::move(NewPtr)));
613}
614CastInst *Context::createCastInst(llvm::CastInst *I) {
615 auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
616 return cast<CastInst>(Val: registerValue(VPtr: std::move(NewPtr)));
617}
618PHINode *Context::createPHINode(llvm::PHINode *I) {
619 auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
620 return cast<PHINode>(Val: registerValue(VPtr: std::move(NewPtr)));
621}
622ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
623 auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
624 return cast<ICmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
625}
626FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
627 auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
628 return cast<FCmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
629}
630Value *Context::getValue(llvm::Value *V) const {
631 auto It = LLVMValueToValueMap.find(Val: V);
632 if (It != LLVMValueToValueMap.end())
633 return It->second.get();
634 return nullptr;
635}
636
/// Constructs a sandbox IR context on top of \p LLVMCtx. The tracker member
/// is initialized with a reference to this context, and the LLVM IR builder
/// member is initialized with \p LLVMCtx and a ConstantFolder.
Context::Context(LLVMContext &LLVMCtx)
    : LLVMCtx(LLVMCtx), IRTracker(*this),
      LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
640
// The destructor is defaulted here, in the implementation file.
Context::~Context() = default;
642
/// Drops every entry of the LLVM-value-to-sandbox-value map, which destroys
/// all sandbox IR values owned through it. The module map is not touched.
void Context::clear() {
  // TODO: Ideally we should clear only function-scope objects, and keep global
  // objects, like Constants to avoid recreating them.
  LLVMValueToValueMap.clear();
}
648
649Module *Context::getModule(llvm::Module *LLVMM) const {
650 auto It = LLVMModuleToModuleMap.find(Val: LLVMM);
651 if (It != LLVMModuleToModuleMap.end())
652 return It->second.get();
653 return nullptr;
654}
655
656Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
657 auto Pair = LLVMModuleToModuleMap.try_emplace(Key: LLVMM);
658 auto It = Pair.first;
659 if (!Pair.second)
660 return It->second.get();
661 It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
662 return It->second.get();
663}
664
665Function *Context::createFunction(llvm::Function *F) {
666 // Create the module if needed before we create the new sandboxir::Function.
667 // Note: this won't fully populate the module. The only globals that will be
668 // available will be the ones being used within the function.
669 getOrCreateModule(LLVMM: F->getParent());
670
671 // There may be a function declaration already defined. Regardless destroy it.
672 if (Function *ExistingF = cast_or_null<Function>(Val: getValue(V: F)))
673 detach(V: ExistingF);
674
675 auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
676 auto *SBF = cast<Function>(Val: registerValue(VPtr: std::move(NewFPtr)));
677 // Create arguments.
678 for (auto &Arg : F->args())
679 getOrCreateArgument(LLVMArg: &Arg);
680 // Create BBs.
681 for (auto &BB : *F)
682 createBasicBlock(LLVMBB: &BB);
683 return SBF;
684}
685
686Module *Context::createModule(llvm::Module *LLVMM) {
687 auto *M = getOrCreateModule(LLVMM);
688 // Create the functions.
689 for (auto &LLVMF : *LLVMM)
690 createFunction(F: &LLVMF);
691 // Create globals.
692 for (auto &Global : LLVMM->globals())
693 getOrCreateValue(LLVMV: &Global);
694 // Create aliases.
695 for (auto &Alias : LLVMM->aliases())
696 getOrCreateValue(LLVMV: &Alias);
697 // Create ifuncs.
698 for (auto &IFunc : LLVMM->ifuncs())
699 getOrCreateValue(LLVMV: &IFunc);
700
701 return M;
702}
703
704void Context::runEraseInstrCallbacks(Instruction *I) {
705 for (const auto &CBEntry : EraseInstrCallbacks)
706 CBEntry.second(I);
707}
708
709void Context::runCreateInstrCallbacks(Instruction *I) {
710 for (auto &CBEntry : CreateInstrCallbacks)
711 CBEntry.second(I);
712}
713
714void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
715 for (auto &CBEntry : MoveInstrCallbacks)
716 CBEntry.second(I, WhereIt);
717}
718
719void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
720 for (auto &CBEntry : SetUseCallbacks)
721 CBEntry.second(U, NewSrc);
722}
723
// An arbitrary limit, used to catch accidental misuse. We expect only a small
// number of callbacks to be registered at any one time; this number can be
// increased if more turn out to be needed. It is only referenced from
// assertions, hence [[maybe_unused]].
[[maybe_unused]] static constexpr int MaxRegisteredCallbacks = 16;
728
729Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
730 assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
731 "EraseInstrCallbacks size limit exceeded");
732 CallbackID ID{NextCallbackID++};
733 EraseInstrCallbacks[ID] = std::move(CB);
734 return ID;
735}
736void Context::unregisterEraseInstrCallback(CallbackID ID) {
737 [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(Key: ID);
738 assert(Erased &&
739 "Callback ID not found in EraseInstrCallbacks during deregistration");
740}
741
742Context::CallbackID
743Context::registerCreateInstrCallback(CreateInstrCallback CB) {
744 assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
745 "CreateInstrCallbacks size limit exceeded");
746 CallbackID ID{NextCallbackID++};
747 CreateInstrCallbacks[ID] = std::move(CB);
748 return ID;
749}
750void Context::unregisterCreateInstrCallback(CallbackID ID) {
751 [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(Key: ID);
752 assert(Erased &&
753 "Callback ID not found in CreateInstrCallbacks during deregistration");
754}
755
756Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
757 assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
758 "MoveInstrCallbacks size limit exceeded");
759 CallbackID ID{NextCallbackID++};
760 MoveInstrCallbacks[ID] = std::move(CB);
761 return ID;
762}
763void Context::unregisterMoveInstrCallback(CallbackID ID) {
764 [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(Key: ID);
765 assert(Erased &&
766 "Callback ID not found in MoveInstrCallbacks during deregistration");
767}
768
769Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
770 assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
771 "SetUseCallbacks size limit exceeded");
772 CallbackID ID{NextCallbackID++};
773 SetUseCallbacks[ID] = std::move(CB);
774 return ID;
775}
776void Context::unregisterSetUseCallback(CallbackID ID) {
777 [[maybe_unused]] bool Erased = SetUseCallbacks.erase(Key: ID);
778 assert(Erased &&
779 "Callback ID not found in SetUseCallbacks during deregistration");
780}
781
782} // namespace llvm::sandboxir
783