1//===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/SandboxIR/Context.h"
10#include "llvm/IR/InlineAsm.h"
11#include "llvm/SandboxIR/Function.h"
12#include "llvm/SandboxIR/Instruction.h"
13#include "llvm/SandboxIR/Module.h"
14
15namespace llvm::sandboxir {
16
17std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
18 std::unique_ptr<Value> Erased;
19 auto It = LLVMValueToValueMap.find(Val: V);
20 if (It != LLVMValueToValueMap.end()) {
21 auto *Val = It->second.release();
22 Erased = std::unique_ptr<Value>(Val);
23 LLVMValueToValueMap.erase(I: It);
24 }
25 return Erased;
26}
27
28std::unique_ptr<Value> Context::detach(Value *V) {
29 assert(V->getSubclassID() != Value::ClassID::Constant &&
30 "Can't detach a constant!");
31 assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
32 return detachLLVMValue(V: V->Val);
33}
34
35Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
36 assert(VPtr->getSubclassID() != Value::ClassID::User &&
37 "Can't register a user!");
38
39 Value *V = VPtr.get();
40 [[maybe_unused]] auto Pair =
41 LLVMValueToValueMap.insert(KV: {VPtr->Val, std::move(VPtr)});
42 assert(Pair.second && "Already exists!");
43
44 // Track creation of instructions.
45 // Please note that we don't allow the creation of detached instructions,
46 // meaning that the instructions need to be inserted into a block upon
47 // creation. This is why the tracker class combines creation and insertion.
48 if (auto *I = dyn_cast<Instruction>(Val: V)) {
49 getTracker().emplaceIfTracking<CreateAndInsertInst>(Args: I);
50 runCreateInstrCallbacks(I);
51 }
52
53 return V;
54}
55
56Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
57 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMV);
58 auto It = Pair.first;
59 if (!Pair.second)
60 return It->second.get();
61
62 // Instruction
63 if (auto *LLVMI = dyn_cast<llvm::Instruction>(Val: LLVMV)) {
64 switch (LLVMI->getOpcode()) {
65 case llvm::Instruction::VAArg: {
66 auto *LLVMVAArg = cast<llvm::VAArgInst>(Val: LLVMV);
67 It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
68 return It->second.get();
69 }
70 case llvm::Instruction::Freeze: {
71 auto *LLVMFreeze = cast<llvm::FreezeInst>(Val: LLVMV);
72 It->second =
73 std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
74 return It->second.get();
75 }
76 case llvm::Instruction::Fence: {
77 auto *LLVMFence = cast<llvm::FenceInst>(Val: LLVMV);
78 It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
79 return It->second.get();
80 }
81 case llvm::Instruction::Select: {
82 auto *LLVMSel = cast<llvm::SelectInst>(Val: LLVMV);
83 It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
84 return It->second.get();
85 }
86 case llvm::Instruction::ExtractElement: {
87 auto *LLVMIns = cast<llvm::ExtractElementInst>(Val: LLVMV);
88 It->second = std::unique_ptr<ExtractElementInst>(
89 new ExtractElementInst(LLVMIns, *this));
90 return It->second.get();
91 }
92 case llvm::Instruction::InsertElement: {
93 auto *LLVMIns = cast<llvm::InsertElementInst>(Val: LLVMV);
94 It->second = std::unique_ptr<InsertElementInst>(
95 new InsertElementInst(LLVMIns, *this));
96 return It->second.get();
97 }
98 case llvm::Instruction::ShuffleVector: {
99 auto *LLVMIns = cast<llvm::ShuffleVectorInst>(Val: LLVMV);
100 It->second = std::unique_ptr<ShuffleVectorInst>(
101 new ShuffleVectorInst(LLVMIns, *this));
102 return It->second.get();
103 }
104 case llvm::Instruction::ExtractValue: {
105 auto *LLVMIns = cast<llvm::ExtractValueInst>(Val: LLVMV);
106 It->second = std::unique_ptr<ExtractValueInst>(
107 new ExtractValueInst(LLVMIns, *this));
108 return It->second.get();
109 }
110 case llvm::Instruction::InsertValue: {
111 auto *LLVMIns = cast<llvm::InsertValueInst>(Val: LLVMV);
112 It->second =
113 std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
114 return It->second.get();
115 }
116 case llvm::Instruction::Br: {
117 auto *LLVMBr = cast<llvm::BranchInst>(Val: LLVMV);
118 It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
119 return It->second.get();
120 }
121 case llvm::Instruction::Load: {
122 auto *LLVMLd = cast<llvm::LoadInst>(Val: LLVMV);
123 It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
124 return It->second.get();
125 }
126 case llvm::Instruction::Store: {
127 auto *LLVMSt = cast<llvm::StoreInst>(Val: LLVMV);
128 It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
129 return It->second.get();
130 }
131 case llvm::Instruction::Ret: {
132 auto *LLVMRet = cast<llvm::ReturnInst>(Val: LLVMV);
133 It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
134 return It->second.get();
135 }
136 case llvm::Instruction::Call: {
137 auto *LLVMCall = cast<llvm::CallInst>(Val: LLVMV);
138 It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
139 return It->second.get();
140 }
141 case llvm::Instruction::Invoke: {
142 auto *LLVMInvoke = cast<llvm::InvokeInst>(Val: LLVMV);
143 It->second =
144 std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
145 return It->second.get();
146 }
147 case llvm::Instruction::CallBr: {
148 auto *LLVMCallBr = cast<llvm::CallBrInst>(Val: LLVMV);
149 It->second =
150 std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
151 return It->second.get();
152 }
153 case llvm::Instruction::LandingPad: {
154 auto *LLVMLPad = cast<llvm::LandingPadInst>(Val: LLVMV);
155 It->second =
156 std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
157 return It->second.get();
158 }
159 case llvm::Instruction::CatchPad: {
160 auto *LLVMCPI = cast<llvm::CatchPadInst>(Val: LLVMV);
161 It->second =
162 std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
163 return It->second.get();
164 }
165 case llvm::Instruction::CleanupPad: {
166 auto *LLVMCPI = cast<llvm::CleanupPadInst>(Val: LLVMV);
167 It->second =
168 std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
169 return It->second.get();
170 }
171 case llvm::Instruction::CatchRet: {
172 auto *LLVMCRI = cast<llvm::CatchReturnInst>(Val: LLVMV);
173 It->second =
174 std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
175 return It->second.get();
176 }
177 case llvm::Instruction::CleanupRet: {
178 auto *LLVMCRI = cast<llvm::CleanupReturnInst>(Val: LLVMV);
179 It->second = std::unique_ptr<CleanupReturnInst>(
180 new CleanupReturnInst(LLVMCRI, *this));
181 return It->second.get();
182 }
183 case llvm::Instruction::GetElementPtr: {
184 auto *LLVMGEP = cast<llvm::GetElementPtrInst>(Val: LLVMV);
185 It->second = std::unique_ptr<GetElementPtrInst>(
186 new GetElementPtrInst(LLVMGEP, *this));
187 return It->second.get();
188 }
189 case llvm::Instruction::CatchSwitch: {
190 auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(Val: LLVMV);
191 It->second = std::unique_ptr<CatchSwitchInst>(
192 new CatchSwitchInst(LLVMCatchSwitchInst, *this));
193 return It->second.get();
194 }
195 case llvm::Instruction::Resume: {
196 auto *LLVMResumeInst = cast<llvm::ResumeInst>(Val: LLVMV);
197 It->second =
198 std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
199 return It->second.get();
200 }
201 case llvm::Instruction::Switch: {
202 auto *LLVMSwitchInst = cast<llvm::SwitchInst>(Val: LLVMV);
203 It->second =
204 std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
205 return It->second.get();
206 }
207 case llvm::Instruction::FNeg: {
208 auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(Val: LLVMV);
209 It->second = std::unique_ptr<UnaryOperator>(
210 new UnaryOperator(LLVMUnaryOperator, *this));
211 return It->second.get();
212 }
213 case llvm::Instruction::Add:
214 case llvm::Instruction::FAdd:
215 case llvm::Instruction::Sub:
216 case llvm::Instruction::FSub:
217 case llvm::Instruction::Mul:
218 case llvm::Instruction::FMul:
219 case llvm::Instruction::UDiv:
220 case llvm::Instruction::SDiv:
221 case llvm::Instruction::FDiv:
222 case llvm::Instruction::URem:
223 case llvm::Instruction::SRem:
224 case llvm::Instruction::FRem:
225 case llvm::Instruction::Shl:
226 case llvm::Instruction::LShr:
227 case llvm::Instruction::AShr:
228 case llvm::Instruction::And:
229 case llvm::Instruction::Or:
230 case llvm::Instruction::Xor: {
231 auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(Val: LLVMV);
232 It->second = std::unique_ptr<BinaryOperator>(
233 new BinaryOperator(LLVMBinaryOperator, *this));
234 return It->second.get();
235 }
236 case llvm::Instruction::AtomicRMW: {
237 auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(Val: LLVMV);
238 It->second = std::unique_ptr<AtomicRMWInst>(
239 new AtomicRMWInst(LLVMAtomicRMW, *this));
240 return It->second.get();
241 }
242 case llvm::Instruction::AtomicCmpXchg: {
243 auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(Val: LLVMV);
244 It->second = std::unique_ptr<AtomicCmpXchgInst>(
245 new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
246 return It->second.get();
247 }
248 case llvm::Instruction::Alloca: {
249 auto *LLVMAlloca = cast<llvm::AllocaInst>(Val: LLVMV);
250 It->second =
251 std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
252 return It->second.get();
253 }
254 case llvm::Instruction::ZExt:
255 case llvm::Instruction::SExt:
256 case llvm::Instruction::FPToUI:
257 case llvm::Instruction::FPToSI:
258 case llvm::Instruction::FPExt:
259 case llvm::Instruction::PtrToAddr:
260 case llvm::Instruction::PtrToInt:
261 case llvm::Instruction::IntToPtr:
262 case llvm::Instruction::SIToFP:
263 case llvm::Instruction::UIToFP:
264 case llvm::Instruction::Trunc:
265 case llvm::Instruction::FPTrunc:
266 case llvm::Instruction::BitCast:
267 case llvm::Instruction::AddrSpaceCast: {
268 auto *LLVMCast = cast<llvm::CastInst>(Val: LLVMV);
269 It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
270 return It->second.get();
271 }
272 case llvm::Instruction::PHI: {
273 auto *LLVMPhi = cast<llvm::PHINode>(Val: LLVMV);
274 It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
275 return It->second.get();
276 }
277 case llvm::Instruction::ICmp: {
278 auto *LLVMICmp = cast<llvm::ICmpInst>(Val: LLVMV);
279 It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
280 return It->second.get();
281 }
282 case llvm::Instruction::FCmp: {
283 auto *LLVMFCmp = cast<llvm::FCmpInst>(Val: LLVMV);
284 It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
285 return It->second.get();
286 }
287 case llvm::Instruction::Unreachable: {
288 auto *LLVMUnreachable = cast<llvm::UnreachableInst>(Val: LLVMV);
289 It->second = std::unique_ptr<UnreachableInst>(
290 new UnreachableInst(LLVMUnreachable, *this));
291 return It->second.get();
292 }
293 default:
294 break;
295 }
296 It->second = std::unique_ptr<OpaqueInst>(
297 new OpaqueInst(cast<llvm::Instruction>(Val: LLVMV), *this));
298 return It->second.get();
299 }
300 // Constant
301 if (auto *LLVMC = dyn_cast<llvm::Constant>(Val: LLVMV)) {
302 switch (LLVMC->getValueID()) {
303 case llvm::Value::ConstantIntVal:
304 It->second = std::unique_ptr<ConstantInt>(
305 new ConstantInt(cast<llvm::ConstantInt>(Val: LLVMC), *this));
306 return It->second.get();
307 case llvm::Value::ConstantFPVal:
308 It->second = std::unique_ptr<ConstantFP>(
309 new ConstantFP(cast<llvm::ConstantFP>(Val: LLVMC), *this));
310 return It->second.get();
311 case llvm::Value::BlockAddressVal:
312 It->second = std::unique_ptr<BlockAddress>(
313 new BlockAddress(cast<llvm::BlockAddress>(Val: LLVMC), *this));
314 return It->second.get();
315 case llvm::Value::ConstantTokenNoneVal:
316 It->second = std::unique_ptr<ConstantTokenNone>(
317 new ConstantTokenNone(cast<llvm::ConstantTokenNone>(Val: LLVMC), *this));
318 return It->second.get();
319 case llvm::Value::ConstantAggregateZeroVal: {
320 auto *CAZ = cast<llvm::ConstantAggregateZero>(Val: LLVMC);
321 It->second = std::unique_ptr<ConstantAggregateZero>(
322 new ConstantAggregateZero(CAZ, *this));
323 auto *Ret = It->second.get();
324 // Must create sandboxir for elements.
325 auto EC = CAZ->getElementCount();
326 if (EC.isFixed()) {
327 for (auto ElmIdx : seq<unsigned>(Begin: 0, End: EC.getFixedValue()))
328 getOrCreateValueInternal(LLVMV: CAZ->getElementValue(Idx: ElmIdx), U: CAZ);
329 }
330 return Ret;
331 }
332 case llvm::Value::ConstantPointerNullVal:
333 It->second = std::unique_ptr<ConstantPointerNull>(new ConstantPointerNull(
334 cast<llvm::ConstantPointerNull>(Val: LLVMC), *this));
335 return It->second.get();
336 case llvm::Value::PoisonValueVal:
337 It->second = std::unique_ptr<PoisonValue>(
338 new PoisonValue(cast<llvm::PoisonValue>(Val: LLVMC), *this));
339 return It->second.get();
340 case llvm::Value::UndefValueVal:
341 It->second = std::unique_ptr<UndefValue>(
342 new UndefValue(cast<llvm::UndefValue>(Val: LLVMC), *this));
343 return It->second.get();
344 case llvm::Value::DSOLocalEquivalentVal: {
345 auto *DSOLE = cast<llvm::DSOLocalEquivalent>(Val: LLVMC);
346 It->second = std::unique_ptr<DSOLocalEquivalent>(
347 new DSOLocalEquivalent(DSOLE, *this));
348 auto *Ret = It->second.get();
349 getOrCreateValueInternal(LLVMV: DSOLE->getGlobalValue(), U: DSOLE);
350 return Ret;
351 }
352 case llvm::Value::ConstantArrayVal:
353 It->second = std::unique_ptr<ConstantArray>(
354 new ConstantArray(cast<llvm::ConstantArray>(Val: LLVMC), *this));
355 break;
356 case llvm::Value::ConstantStructVal:
357 It->second = std::unique_ptr<ConstantStruct>(
358 new ConstantStruct(cast<llvm::ConstantStruct>(Val: LLVMC), *this));
359 break;
360 case llvm::Value::ConstantVectorVal:
361 It->second = std::unique_ptr<ConstantVector>(
362 new ConstantVector(cast<llvm::ConstantVector>(Val: LLVMC), *this));
363 break;
364 case llvm::Value::ConstantDataArrayVal:
365 It->second = std::unique_ptr<ConstantDataArray>(
366 new ConstantDataArray(cast<llvm::ConstantDataArray>(Val: LLVMC), *this));
367 break;
368 case llvm::Value::ConstantDataVectorVal:
369 It->second = std::unique_ptr<ConstantDataVector>(
370 new ConstantDataVector(cast<llvm::ConstantDataVector>(Val: LLVMC), *this));
371 break;
372 case llvm::Value::FunctionVal:
373 It->second = std::unique_ptr<Function>(
374 new Function(cast<llvm::Function>(Val: LLVMC), *this));
375 break;
376 case llvm::Value::GlobalIFuncVal:
377 It->second = std::unique_ptr<GlobalIFunc>(
378 new GlobalIFunc(cast<llvm::GlobalIFunc>(Val: LLVMC), *this));
379 break;
380 case llvm::Value::GlobalVariableVal:
381 It->second = std::unique_ptr<GlobalVariable>(
382 new GlobalVariable(cast<llvm::GlobalVariable>(Val: LLVMC), *this));
383 break;
384 case llvm::Value::GlobalAliasVal:
385 It->second = std::unique_ptr<GlobalAlias>(
386 new GlobalAlias(cast<llvm::GlobalAlias>(Val: LLVMC), *this));
387 break;
388 case llvm::Value::NoCFIValueVal:
389 It->second = std::unique_ptr<NoCFIValue>(
390 new NoCFIValue(cast<llvm::NoCFIValue>(Val: LLVMC), *this));
391 break;
392 case llvm::Value::ConstantPtrAuthVal:
393 It->second = std::unique_ptr<ConstantPtrAuth>(
394 new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(Val: LLVMC), *this));
395 break;
396 case llvm::Value::ConstantExprVal:
397 It->second = std::unique_ptr<ConstantExpr>(
398 new ConstantExpr(cast<llvm::ConstantExpr>(Val: LLVMC), *this));
399 break;
400 default:
401 It->second = std::unique_ptr<Constant>(new Constant(LLVMC, *this));
402 break;
403 }
404 auto *NewC = It->second.get();
405 for (llvm::Value *COp : LLVMC->operands())
406 getOrCreateValueInternal(LLVMV: COp, U: LLVMC);
407 return NewC;
408 }
409 // Argument
410 if (auto *LLVMArg = dyn_cast<llvm::Argument>(Val: LLVMV)) {
411 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
412 return It->second.get();
413 }
414 // BasicBlock
415 if (auto *LLVMBB = dyn_cast<llvm::BasicBlock>(Val: LLVMV)) {
416 assert(isa<llvm::BlockAddress>(U) &&
417 "This won't create a SBBB, don't call this function directly!");
418 if (auto *SBBB = getValue(V: LLVMBB))
419 return SBBB;
420 return nullptr;
421 }
422 // Metadata
423 if (auto *LLVMMD = dyn_cast<llvm::MetadataAsValue>(Val: LLVMV)) {
424 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMMD, *this));
425 return It->second.get();
426 }
427 // InlineAsm
428 if (auto *LLVMAsm = dyn_cast<llvm::InlineAsm>(Val: LLVMV)) {
429 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMAsm, *this));
430 return It->second.get();
431 }
432 llvm_unreachable("Unhandled LLVMV type!");
433}
434
435Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
436 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMArg);
437 auto It = Pair.first;
438 if (Pair.second) {
439 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
440 return cast<Argument>(Val: It->second.get());
441 }
442 return cast<Argument>(Val: It->second.get());
443}
444
445Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
446 return cast<Constant>(Val: getOrCreateValueInternal(LLVMV: LLVMC, U: nullptr));
447}
448
449BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
450 assert(getValue(LLVMBB) == nullptr && "Already exists!");
451 auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
452 auto *BB = cast<BasicBlock>(Val: registerValue(VPtr: std::move(NewBBPtr)));
453 // Create SandboxIR for BB's body.
454 BB->buildBasicBlockFromLLVMIR(LLVMBB);
455 return BB;
456}
457
458VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
459 auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
460 return cast<VAArgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
461}
462
463FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
464 auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
465 return cast<FreezeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
466}
467
468FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
469 auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
470 return cast<FenceInst>(Val: registerValue(VPtr: std::move(NewPtr)));
471}
472
473SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
474 auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
475 return cast<SelectInst>(Val: registerValue(VPtr: std::move(NewPtr)));
476}
477
478ExtractElementInst *
479Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
480 auto NewPtr =
481 std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
482 return cast<ExtractElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
483}
484
485InsertElementInst *
486Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
487 auto NewPtr =
488 std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
489 return cast<InsertElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
490}
491
492ShuffleVectorInst *
493Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
494 auto NewPtr =
495 std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
496 return cast<ShuffleVectorInst>(Val: registerValue(VPtr: std::move(NewPtr)));
497}
498
499ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
500 auto NewPtr =
501 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
502 return cast<ExtractValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
503}
504
505InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
506 auto NewPtr =
507 std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
508 return cast<InsertValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
509}
510
511BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
512 auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
513 return cast<BranchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
514}
515
516LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
517 auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
518 return cast<LoadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
519}
520
521StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
522 auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
523 return cast<StoreInst>(Val: registerValue(VPtr: std::move(NewPtr)));
524}
525
526ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
527 auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
528 return cast<ReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
529}
530
531CallInst *Context::createCallInst(llvm::CallInst *I) {
532 auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
533 return cast<CallInst>(Val: registerValue(VPtr: std::move(NewPtr)));
534}
535
536InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
537 auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
538 return cast<InvokeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
539}
540
541CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
542 auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
543 return cast<CallBrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
544}
545
546UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
547 auto NewPtr =
548 std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
549 return cast<UnreachableInst>(Val: registerValue(VPtr: std::move(NewPtr)));
550}
551LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
552 auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
553 return cast<LandingPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
554}
555CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
556 auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
557 return cast<CatchPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
558}
559CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
560 auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
561 return cast<CleanupPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
562}
563CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
564 auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
565 return cast<CatchReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
566}
567CleanupReturnInst *
568Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
569 auto NewPtr =
570 std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
571 return cast<CleanupReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
572}
573GetElementPtrInst *
574Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
575 auto NewPtr =
576 std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
577 return cast<GetElementPtrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
578}
579CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
580 auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
581 return cast<CatchSwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
582}
583ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
584 auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
585 return cast<ResumeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
586}
587SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
588 auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
589 return cast<SwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
590}
591UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
592 auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
593 return cast<UnaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
594}
595BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
596 auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
597 return cast<BinaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
598}
599AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
600 auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
601 return cast<AtomicRMWInst>(Val: registerValue(VPtr: std::move(NewPtr)));
602}
603AtomicCmpXchgInst *
604Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
605 auto NewPtr =
606 std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
607 return cast<AtomicCmpXchgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
608}
609AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
610 auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
611 return cast<AllocaInst>(Val: registerValue(VPtr: std::move(NewPtr)));
612}
613CastInst *Context::createCastInst(llvm::CastInst *I) {
614 auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
615 return cast<CastInst>(Val: registerValue(VPtr: std::move(NewPtr)));
616}
617PHINode *Context::createPHINode(llvm::PHINode *I) {
618 auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
619 return cast<PHINode>(Val: registerValue(VPtr: std::move(NewPtr)));
620}
621ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
622 auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
623 return cast<ICmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
624}
625FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
626 auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
627 return cast<FCmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
628}
629Value *Context::getValue(llvm::Value *V) const {
630 auto It = LLVMValueToValueMap.find(Val: V);
631 if (It != LLVMValueToValueMap.end())
632 return It->second.get();
633 return nullptr;
634}
635
636Context::Context(LLVMContext &LLVMCtx)
637 : LLVMCtx(LLVMCtx), IRTracker(*this),
638 LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
639
640Context::~Context() = default;
641
642void Context::clear() {
643 // TODO: Ideally we should clear only function-scope objects, and keep global
644 // objects, like Constants to avoid recreating them.
645 LLVMValueToValueMap.clear();
646}
647
648Module *Context::getModule(llvm::Module *LLVMM) const {
649 auto It = LLVMModuleToModuleMap.find(Val: LLVMM);
650 if (It != LLVMModuleToModuleMap.end())
651 return It->second.get();
652 return nullptr;
653}
654
655Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
656 auto Pair = LLVMModuleToModuleMap.try_emplace(Key: LLVMM);
657 auto It = Pair.first;
658 if (!Pair.second)
659 return It->second.get();
660 It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
661 return It->second.get();
662}
663
664Function *Context::createFunction(llvm::Function *F) {
665 // Create the module if needed before we create the new sandboxir::Function.
666 // Note: this won't fully populate the module. The only globals that will be
667 // available will be the ones being used within the function.
668 getOrCreateModule(LLVMM: F->getParent());
669
670 // There may be a function declaration already defined. Regardless destroy it.
671 if (Function *ExistingF = cast_or_null<Function>(Val: getValue(V: F)))
672 detach(V: ExistingF);
673
674 auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
675 auto *SBF = cast<Function>(Val: registerValue(VPtr: std::move(NewFPtr)));
676 // Create arguments.
677 for (auto &Arg : F->args())
678 getOrCreateArgument(LLVMArg: &Arg);
679 // Create BBs.
680 for (auto &BB : *F)
681 createBasicBlock(LLVMBB: &BB);
682 return SBF;
683}
684
685Module *Context::createModule(llvm::Module *LLVMM) {
686 auto *M = getOrCreateModule(LLVMM);
687 // Create the functions.
688 for (auto &LLVMF : *LLVMM)
689 createFunction(F: &LLVMF);
690 // Create globals.
691 for (auto &Global : LLVMM->globals())
692 getOrCreateValue(LLVMV: &Global);
693 // Create aliases.
694 for (auto &Alias : LLVMM->aliases())
695 getOrCreateValue(LLVMV: &Alias);
696 // Create ifuncs.
697 for (auto &IFunc : LLVMM->ifuncs())
698 getOrCreateValue(LLVMV: &IFunc);
699
700 return M;
701}
702
703void Context::runEraseInstrCallbacks(Instruction *I) {
704 for (const auto &CBEntry : EraseInstrCallbacks)
705 CBEntry.second(I);
706}
707
708void Context::runCreateInstrCallbacks(Instruction *I) {
709 for (auto &CBEntry : CreateInstrCallbacks)
710 CBEntry.second(I);
711}
712
713void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
714 for (auto &CBEntry : MoveInstrCallbacks)
715 CBEntry.second(I, WhereIt);
716}
717
718void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
719 for (auto &CBEntry : SetUseCallbacks)
720 CBEntry.second(U, NewSrc);
721}
722
723// An arbitrary limit, to check for accidental misuse. We expect a small number
724// of callbacks to be registered at a time, but we can increase this number if
725// we discover we needed more.
726[[maybe_unused]] static constexpr int MaxRegisteredCallbacks = 16;
727
728Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
729 assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
730 "EraseInstrCallbacks size limit exceeded");
731 CallbackID ID{NextCallbackID++};
732 EraseInstrCallbacks[ID] = std::move(CB);
733 return ID;
734}
735void Context::unregisterEraseInstrCallback(CallbackID ID) {
736 [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(Key: ID);
737 assert(Erased &&
738 "Callback ID not found in EraseInstrCallbacks during deregistration");
739}
740
741Context::CallbackID
742Context::registerCreateInstrCallback(CreateInstrCallback CB) {
743 assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
744 "CreateInstrCallbacks size limit exceeded");
745 CallbackID ID{NextCallbackID++};
746 CreateInstrCallbacks[ID] = std::move(CB);
747 return ID;
748}
749void Context::unregisterCreateInstrCallback(CallbackID ID) {
750 [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(Key: ID);
751 assert(Erased &&
752 "Callback ID not found in CreateInstrCallbacks during deregistration");
753}
754
755Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
756 assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
757 "MoveInstrCallbacks size limit exceeded");
758 CallbackID ID{NextCallbackID++};
759 MoveInstrCallbacks[ID] = std::move(CB);
760 return ID;
761}
762void Context::unregisterMoveInstrCallback(CallbackID ID) {
763 [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(Key: ID);
764 assert(Erased &&
765 "Callback ID not found in MoveInstrCallbacks during deregistration");
766}
767
768Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
769 assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
770 "SetUseCallbacks size limit exceeded");
771 CallbackID ID{NextCallbackID++};
772 SetUseCallbacks[ID] = std::move(CB);
773 return ID;
774}
775void Context::unregisterSetUseCallback(CallbackID ID) {
776 [[maybe_unused]] bool Erased = SetUseCallbacks.erase(Key: ID);
777 assert(Erased &&
778 "Callback ID not found in SetUseCallbacks during deregistration");
779}
780
781} // namespace llvm::sandboxir
782