1//===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/SandboxIR/Context.h"
10#include "llvm/IR/InlineAsm.h"
11#include "llvm/SandboxIR/Function.h"
12#include "llvm/SandboxIR/Instruction.h"
13#include "llvm/SandboxIR/Module.h"
14
15namespace llvm::sandboxir {
16
17std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
18 std::unique_ptr<Value> Erased;
19 auto It = LLVMValueToValueMap.find(Val: V);
20 if (It != LLVMValueToValueMap.end()) {
21 auto *Val = It->second.release();
22 Erased = std::unique_ptr<Value>(Val);
23 LLVMValueToValueMap.erase(I: It);
24 }
25 return Erased;
26}
27
28std::unique_ptr<Value> Context::detach(Value *V) {
29 assert(V->getSubclassID() != Value::ClassID::Constant &&
30 "Can't detach a constant!");
31 assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
32 return detachLLVMValue(V: V->Val);
33}
34
35Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
36 assert(VPtr->getSubclassID() != Value::ClassID::User &&
37 "Can't register a user!");
38
39 Value *V = VPtr.get();
40 [[maybe_unused]] auto Pair =
41 LLVMValueToValueMap.insert(KV: {VPtr->Val, std::move(VPtr)});
42 assert(Pair.second && "Already exists!");
43
44 // Track creation of instructions.
45 // Please note that we don't allow the creation of detached instructions,
46 // meaning that the instructions need to be inserted into a block upon
47 // creation. This is why the tracker class combines creation and insertion.
48 if (auto *I = dyn_cast<Instruction>(Val: V)) {
49 getTracker().emplaceIfTracking<CreateAndInsertInst>(Args: I);
50 runCreateInstrCallbacks(I);
51 }
52
53 return V;
54}
55
56Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
57 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMV);
58 auto It = Pair.first;
59 if (!Pair.second)
60 return It->second.get();
61
62 // Instruction
63 if (auto *LLVMI = dyn_cast<llvm::Instruction>(Val: LLVMV)) {
64 switch (LLVMI->getOpcode()) {
65 case llvm::Instruction::VAArg: {
66 auto *LLVMVAArg = cast<llvm::VAArgInst>(Val: LLVMV);
67 It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
68 return It->second.get();
69 }
70 case llvm::Instruction::Freeze: {
71 auto *LLVMFreeze = cast<llvm::FreezeInst>(Val: LLVMV);
72 It->second =
73 std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
74 return It->second.get();
75 }
76 case llvm::Instruction::Fence: {
77 auto *LLVMFence = cast<llvm::FenceInst>(Val: LLVMV);
78 It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
79 return It->second.get();
80 }
81 case llvm::Instruction::Select: {
82 auto *LLVMSel = cast<llvm::SelectInst>(Val: LLVMV);
83 It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
84 return It->second.get();
85 }
86 case llvm::Instruction::ExtractElement: {
87 auto *LLVMIns = cast<llvm::ExtractElementInst>(Val: LLVMV);
88 It->second = std::unique_ptr<ExtractElementInst>(
89 new ExtractElementInst(LLVMIns, *this));
90 return It->second.get();
91 }
92 case llvm::Instruction::InsertElement: {
93 auto *LLVMIns = cast<llvm::InsertElementInst>(Val: LLVMV);
94 It->second = std::unique_ptr<InsertElementInst>(
95 new InsertElementInst(LLVMIns, *this));
96 return It->second.get();
97 }
98 case llvm::Instruction::ShuffleVector: {
99 auto *LLVMIns = cast<llvm::ShuffleVectorInst>(Val: LLVMV);
100 It->second = std::unique_ptr<ShuffleVectorInst>(
101 new ShuffleVectorInst(LLVMIns, *this));
102 return It->second.get();
103 }
104 case llvm::Instruction::ExtractValue: {
105 auto *LLVMIns = cast<llvm::ExtractValueInst>(Val: LLVMV);
106 It->second = std::unique_ptr<ExtractValueInst>(
107 new ExtractValueInst(LLVMIns, *this));
108 return It->second.get();
109 }
110 case llvm::Instruction::InsertValue: {
111 auto *LLVMIns = cast<llvm::InsertValueInst>(Val: LLVMV);
112 It->second =
113 std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
114 return It->second.get();
115 }
116 case llvm::Instruction::UncondBr: {
117 auto *LLVMBr = cast<llvm::UncondBrInst>(Val: LLVMV);
118 It->second =
119 std::unique_ptr<UncondBrInst>(new UncondBrInst(LLVMBr, *this));
120 return It->second.get();
121 }
122 case llvm::Instruction::CondBr: {
123 auto *LLVMBr = cast<llvm::CondBrInst>(Val: LLVMV);
124 It->second = std::unique_ptr<CondBrInst>(new CondBrInst(LLVMBr, *this));
125 return It->second.get();
126 }
127 case llvm::Instruction::Load: {
128 auto *LLVMLd = cast<llvm::LoadInst>(Val: LLVMV);
129 It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
130 return It->second.get();
131 }
132 case llvm::Instruction::Store: {
133 auto *LLVMSt = cast<llvm::StoreInst>(Val: LLVMV);
134 It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
135 return It->second.get();
136 }
137 case llvm::Instruction::Ret: {
138 auto *LLVMRet = cast<llvm::ReturnInst>(Val: LLVMV);
139 It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
140 return It->second.get();
141 }
142 case llvm::Instruction::Call: {
143 auto *LLVMCall = cast<llvm::CallInst>(Val: LLVMV);
144 It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
145 return It->second.get();
146 }
147 case llvm::Instruction::Invoke: {
148 auto *LLVMInvoke = cast<llvm::InvokeInst>(Val: LLVMV);
149 It->second =
150 std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
151 return It->second.get();
152 }
153 case llvm::Instruction::CallBr: {
154 auto *LLVMCallBr = cast<llvm::CallBrInst>(Val: LLVMV);
155 It->second =
156 std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
157 return It->second.get();
158 }
159 case llvm::Instruction::LandingPad: {
160 auto *LLVMLPad = cast<llvm::LandingPadInst>(Val: LLVMV);
161 It->second =
162 std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
163 return It->second.get();
164 }
165 case llvm::Instruction::CatchPad: {
166 auto *LLVMCPI = cast<llvm::CatchPadInst>(Val: LLVMV);
167 It->second =
168 std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
169 return It->second.get();
170 }
171 case llvm::Instruction::CleanupPad: {
172 auto *LLVMCPI = cast<llvm::CleanupPadInst>(Val: LLVMV);
173 It->second =
174 std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
175 return It->second.get();
176 }
177 case llvm::Instruction::CatchRet: {
178 auto *LLVMCRI = cast<llvm::CatchReturnInst>(Val: LLVMV);
179 It->second =
180 std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
181 return It->second.get();
182 }
183 case llvm::Instruction::CleanupRet: {
184 auto *LLVMCRI = cast<llvm::CleanupReturnInst>(Val: LLVMV);
185 It->second = std::unique_ptr<CleanupReturnInst>(
186 new CleanupReturnInst(LLVMCRI, *this));
187 return It->second.get();
188 }
189 case llvm::Instruction::GetElementPtr: {
190 auto *LLVMGEP = cast<llvm::GetElementPtrInst>(Val: LLVMV);
191 It->second = std::unique_ptr<GetElementPtrInst>(
192 new GetElementPtrInst(LLVMGEP, *this));
193 return It->second.get();
194 }
195 case llvm::Instruction::CatchSwitch: {
196 auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(Val: LLVMV);
197 It->second = std::unique_ptr<CatchSwitchInst>(
198 new CatchSwitchInst(LLVMCatchSwitchInst, *this));
199 return It->second.get();
200 }
201 case llvm::Instruction::Resume: {
202 auto *LLVMResumeInst = cast<llvm::ResumeInst>(Val: LLVMV);
203 It->second =
204 std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
205 return It->second.get();
206 }
207 case llvm::Instruction::Switch: {
208 auto *LLVMSwitchInst = cast<llvm::SwitchInst>(Val: LLVMV);
209 It->second =
210 std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
211 return It->second.get();
212 }
213 case llvm::Instruction::FNeg: {
214 auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(Val: LLVMV);
215 It->second = std::unique_ptr<UnaryOperator>(
216 new UnaryOperator(LLVMUnaryOperator, *this));
217 return It->second.get();
218 }
219 case llvm::Instruction::Add:
220 case llvm::Instruction::FAdd:
221 case llvm::Instruction::Sub:
222 case llvm::Instruction::FSub:
223 case llvm::Instruction::Mul:
224 case llvm::Instruction::FMul:
225 case llvm::Instruction::UDiv:
226 case llvm::Instruction::SDiv:
227 case llvm::Instruction::FDiv:
228 case llvm::Instruction::URem:
229 case llvm::Instruction::SRem:
230 case llvm::Instruction::FRem:
231 case llvm::Instruction::Shl:
232 case llvm::Instruction::LShr:
233 case llvm::Instruction::AShr:
234 case llvm::Instruction::And:
235 case llvm::Instruction::Or:
236 case llvm::Instruction::Xor: {
237 auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(Val: LLVMV);
238 It->second = std::unique_ptr<BinaryOperator>(
239 new BinaryOperator(LLVMBinaryOperator, *this));
240 return It->second.get();
241 }
242 case llvm::Instruction::AtomicRMW: {
243 auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(Val: LLVMV);
244 It->second = std::unique_ptr<AtomicRMWInst>(
245 new AtomicRMWInst(LLVMAtomicRMW, *this));
246 return It->second.get();
247 }
248 case llvm::Instruction::AtomicCmpXchg: {
249 auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(Val: LLVMV);
250 It->second = std::unique_ptr<AtomicCmpXchgInst>(
251 new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
252 return It->second.get();
253 }
254 case llvm::Instruction::Alloca: {
255 auto *LLVMAlloca = cast<llvm::AllocaInst>(Val: LLVMV);
256 It->second =
257 std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
258 return It->second.get();
259 }
260 case llvm::Instruction::ZExt:
261 case llvm::Instruction::SExt:
262 case llvm::Instruction::FPToUI:
263 case llvm::Instruction::FPToSI:
264 case llvm::Instruction::FPExt:
265 case llvm::Instruction::PtrToAddr:
266 case llvm::Instruction::PtrToInt:
267 case llvm::Instruction::IntToPtr:
268 case llvm::Instruction::SIToFP:
269 case llvm::Instruction::UIToFP:
270 case llvm::Instruction::Trunc:
271 case llvm::Instruction::FPTrunc:
272 case llvm::Instruction::BitCast:
273 case llvm::Instruction::AddrSpaceCast: {
274 auto *LLVMCast = cast<llvm::CastInst>(Val: LLVMV);
275 It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
276 return It->second.get();
277 }
278 case llvm::Instruction::PHI: {
279 auto *LLVMPhi = cast<llvm::PHINode>(Val: LLVMV);
280 It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
281 return It->second.get();
282 }
283 case llvm::Instruction::ICmp: {
284 auto *LLVMICmp = cast<llvm::ICmpInst>(Val: LLVMV);
285 It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
286 return It->second.get();
287 }
288 case llvm::Instruction::FCmp: {
289 auto *LLVMFCmp = cast<llvm::FCmpInst>(Val: LLVMV);
290 It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
291 return It->second.get();
292 }
293 case llvm::Instruction::Unreachable: {
294 auto *LLVMUnreachable = cast<llvm::UnreachableInst>(Val: LLVMV);
295 It->second = std::unique_ptr<UnreachableInst>(
296 new UnreachableInst(LLVMUnreachable, *this));
297 return It->second.get();
298 }
299 default:
300 break;
301 }
302 It->second = std::unique_ptr<OpaqueInst>(
303 new OpaqueInst(cast<llvm::Instruction>(Val: LLVMV), *this));
304 return It->second.get();
305 }
306 // Constant
307 if (auto *LLVMC = dyn_cast<llvm::Constant>(Val: LLVMV)) {
308 switch (LLVMC->getValueID()) {
309 case llvm::Value::ConstantIntVal:
310 It->second = std::unique_ptr<ConstantInt>(
311 new ConstantInt(cast<llvm::ConstantInt>(Val: LLVMC), *this));
312 return It->second.get();
313 case llvm::Value::ConstantFPVal:
314 It->second = std::unique_ptr<ConstantFP>(
315 new ConstantFP(cast<llvm::ConstantFP>(Val: LLVMC), *this));
316 return It->second.get();
317 case llvm::Value::BlockAddressVal:
318 It->second = std::unique_ptr<BlockAddress>(
319 new BlockAddress(cast<llvm::BlockAddress>(Val: LLVMC), *this));
320 return It->second.get();
321 case llvm::Value::ConstantTokenNoneVal:
322 It->second = std::unique_ptr<ConstantTokenNone>(
323 new ConstantTokenNone(cast<llvm::ConstantTokenNone>(Val: LLVMC), *this));
324 return It->second.get();
325 case llvm::Value::ConstantAggregateZeroVal: {
326 auto *CAZ = cast<llvm::ConstantAggregateZero>(Val: LLVMC);
327 It->second = std::unique_ptr<ConstantAggregateZero>(
328 new ConstantAggregateZero(CAZ, *this));
329 auto *Ret = It->second.get();
330 // Must create sandboxir for elements.
331 auto EC = CAZ->getElementCount();
332 if (EC.isFixed()) {
333 for (auto ElmIdx : seq<unsigned>(Begin: 0, End: EC.getFixedValue()))
334 getOrCreateValueInternal(LLVMV: CAZ->getElementValue(Idx: ElmIdx), U: CAZ);
335 }
336 return Ret;
337 }
338 case llvm::Value::ConstantPointerNullVal:
339 It->second = std::unique_ptr<ConstantPointerNull>(new ConstantPointerNull(
340 cast<llvm::ConstantPointerNull>(Val: LLVMC), *this));
341 return It->second.get();
342 case llvm::Value::PoisonValueVal:
343 It->second = std::unique_ptr<PoisonValue>(
344 new PoisonValue(cast<llvm::PoisonValue>(Val: LLVMC), *this));
345 return It->second.get();
346 case llvm::Value::UndefValueVal:
347 It->second = std::unique_ptr<UndefValue>(
348 new UndefValue(cast<llvm::UndefValue>(Val: LLVMC), *this));
349 return It->second.get();
350 case llvm::Value::DSOLocalEquivalentVal: {
351 auto *DSOLE = cast<llvm::DSOLocalEquivalent>(Val: LLVMC);
352 It->second = std::unique_ptr<DSOLocalEquivalent>(
353 new DSOLocalEquivalent(DSOLE, *this));
354 auto *Ret = It->second.get();
355 getOrCreateValueInternal(LLVMV: DSOLE->getGlobalValue(), U: DSOLE);
356 return Ret;
357 }
358 case llvm::Value::ConstantArrayVal:
359 It->second = std::unique_ptr<ConstantArray>(
360 new ConstantArray(cast<llvm::ConstantArray>(Val: LLVMC), *this));
361 break;
362 case llvm::Value::ConstantStructVal:
363 It->second = std::unique_ptr<ConstantStruct>(
364 new ConstantStruct(cast<llvm::ConstantStruct>(Val: LLVMC), *this));
365 break;
366 case llvm::Value::ConstantVectorVal:
367 It->second = std::unique_ptr<ConstantVector>(
368 new ConstantVector(cast<llvm::ConstantVector>(Val: LLVMC), *this));
369 break;
370 case llvm::Value::ConstantDataArrayVal:
371 It->second = std::unique_ptr<ConstantDataArray>(
372 new ConstantDataArray(cast<llvm::ConstantDataArray>(Val: LLVMC), *this));
373 break;
374 case llvm::Value::ConstantDataVectorVal:
375 It->second = std::unique_ptr<ConstantDataVector>(
376 new ConstantDataVector(cast<llvm::ConstantDataVector>(Val: LLVMC), *this));
377 break;
378 case llvm::Value::FunctionVal:
379 It->second = std::unique_ptr<Function>(
380 new Function(cast<llvm::Function>(Val: LLVMC), *this));
381 break;
382 case llvm::Value::GlobalIFuncVal:
383 It->second = std::unique_ptr<GlobalIFunc>(
384 new GlobalIFunc(cast<llvm::GlobalIFunc>(Val: LLVMC), *this));
385 break;
386 case llvm::Value::GlobalVariableVal:
387 It->second = std::unique_ptr<GlobalVariable>(
388 new GlobalVariable(cast<llvm::GlobalVariable>(Val: LLVMC), *this));
389 break;
390 case llvm::Value::GlobalAliasVal:
391 It->second = std::unique_ptr<GlobalAlias>(
392 new GlobalAlias(cast<llvm::GlobalAlias>(Val: LLVMC), *this));
393 break;
394 case llvm::Value::NoCFIValueVal:
395 It->second = std::unique_ptr<NoCFIValue>(
396 new NoCFIValue(cast<llvm::NoCFIValue>(Val: LLVMC), *this));
397 break;
398 case llvm::Value::ConstantPtrAuthVal:
399 It->second = std::unique_ptr<ConstantPtrAuth>(
400 new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(Val: LLVMC), *this));
401 break;
402 case llvm::Value::ConstantExprVal:
403 It->second = std::unique_ptr<ConstantExpr>(
404 new ConstantExpr(cast<llvm::ConstantExpr>(Val: LLVMC), *this));
405 break;
406 default:
407 It->second = std::unique_ptr<Constant>(new Constant(LLVMC, *this));
408 break;
409 }
410 auto *NewC = It->second.get();
411 for (llvm::Value *COp : LLVMC->operands())
412 getOrCreateValueInternal(LLVMV: COp, U: LLVMC);
413 return NewC;
414 }
415 // Argument
416 if (auto *LLVMArg = dyn_cast<llvm::Argument>(Val: LLVMV)) {
417 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
418 return It->second.get();
419 }
420 // BasicBlock
421 if (auto *LLVMBB = dyn_cast<llvm::BasicBlock>(Val: LLVMV)) {
422 assert(isa<llvm::BlockAddress>(U) &&
423 "This won't create a SBBB, don't call this function directly!");
424 if (auto *SBBB = getValue(V: LLVMBB))
425 return SBBB;
426 return nullptr;
427 }
428 // Metadata
429 if (auto *LLVMMD = dyn_cast<llvm::MetadataAsValue>(Val: LLVMV)) {
430 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMMD, *this));
431 return It->second.get();
432 }
433 // InlineAsm
434 if (auto *LLVMAsm = dyn_cast<llvm::InlineAsm>(Val: LLVMV)) {
435 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMAsm, *this));
436 return It->second.get();
437 }
438 llvm_unreachable("Unhandled LLVMV type!");
439}
440
441Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
442 auto Pair = LLVMValueToValueMap.try_emplace(Key: LLVMArg);
443 auto It = Pair.first;
444 if (Pair.second) {
445 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
446 return cast<Argument>(Val: It->second.get());
447 }
448 return cast<Argument>(Val: It->second.get());
449}
450
451Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
452 return cast<Constant>(Val: getOrCreateValueInternal(LLVMV: LLVMC, U: nullptr));
453}
454
455BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
456 assert(getValue(LLVMBB) == nullptr && "Already exists!");
457 auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
458 auto *BB = cast<BasicBlock>(Val: registerValue(VPtr: std::move(NewBBPtr)));
459 // Create SandboxIR for BB's body.
460 BB->buildBasicBlockFromLLVMIR(LLVMBB);
461 return BB;
462}
463
464VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
465 auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
466 return cast<VAArgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
467}
468
469FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
470 auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
471 return cast<FreezeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
472}
473
474FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
475 auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
476 return cast<FenceInst>(Val: registerValue(VPtr: std::move(NewPtr)));
477}
478
479SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
480 auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
481 return cast<SelectInst>(Val: registerValue(VPtr: std::move(NewPtr)));
482}
483
484ExtractElementInst *
485Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
486 auto NewPtr =
487 std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
488 return cast<ExtractElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
489}
490
491InsertElementInst *
492Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
493 auto NewPtr =
494 std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
495 return cast<InsertElementInst>(Val: registerValue(VPtr: std::move(NewPtr)));
496}
497
498ShuffleVectorInst *
499Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
500 auto NewPtr =
501 std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
502 return cast<ShuffleVectorInst>(Val: registerValue(VPtr: std::move(NewPtr)));
503}
504
505ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
506 auto NewPtr =
507 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
508 return cast<ExtractValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
509}
510
511InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
512 auto NewPtr =
513 std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
514 return cast<InsertValueInst>(Val: registerValue(VPtr: std::move(NewPtr)));
515}
516
517UncondBrInst *Context::createUncondBrInst(llvm::UncondBrInst *UBI) {
518 auto NewPtr = std::unique_ptr<UncondBrInst>(new UncondBrInst(UBI, *this));
519 return cast<UncondBrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
520}
521
522CondBrInst *Context::createCondBrInst(llvm::CondBrInst *CBI) {
523 auto NewPtr = std::unique_ptr<CondBrInst>(new CondBrInst(CBI, *this));
524 return cast<CondBrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
525}
526
527LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
528 auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
529 return cast<LoadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
530}
531
532StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
533 auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
534 return cast<StoreInst>(Val: registerValue(VPtr: std::move(NewPtr)));
535}
536
537ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
538 auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
539 return cast<ReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
540}
541
542CallInst *Context::createCallInst(llvm::CallInst *I) {
543 auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
544 return cast<CallInst>(Val: registerValue(VPtr: std::move(NewPtr)));
545}
546
547InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
548 auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
549 return cast<InvokeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
550}
551
552CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
553 auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
554 return cast<CallBrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
555}
556
557UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
558 auto NewPtr =
559 std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
560 return cast<UnreachableInst>(Val: registerValue(VPtr: std::move(NewPtr)));
561}
562LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
563 auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
564 return cast<LandingPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
565}
566CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
567 auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
568 return cast<CatchPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
569}
570CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
571 auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
572 return cast<CleanupPadInst>(Val: registerValue(VPtr: std::move(NewPtr)));
573}
574CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
575 auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
576 return cast<CatchReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
577}
578CleanupReturnInst *
579Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
580 auto NewPtr =
581 std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
582 return cast<CleanupReturnInst>(Val: registerValue(VPtr: std::move(NewPtr)));
583}
584GetElementPtrInst *
585Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
586 auto NewPtr =
587 std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
588 return cast<GetElementPtrInst>(Val: registerValue(VPtr: std::move(NewPtr)));
589}
590CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
591 auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
592 return cast<CatchSwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
593}
594ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
595 auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
596 return cast<ResumeInst>(Val: registerValue(VPtr: std::move(NewPtr)));
597}
598SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
599 auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
600 return cast<SwitchInst>(Val: registerValue(VPtr: std::move(NewPtr)));
601}
602UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
603 auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
604 return cast<UnaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
605}
606BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
607 auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
608 return cast<BinaryOperator>(Val: registerValue(VPtr: std::move(NewPtr)));
609}
610AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
611 auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
612 return cast<AtomicRMWInst>(Val: registerValue(VPtr: std::move(NewPtr)));
613}
614AtomicCmpXchgInst *
615Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
616 auto NewPtr =
617 std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
618 return cast<AtomicCmpXchgInst>(Val: registerValue(VPtr: std::move(NewPtr)));
619}
620AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
621 auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
622 return cast<AllocaInst>(Val: registerValue(VPtr: std::move(NewPtr)));
623}
624CastInst *Context::createCastInst(llvm::CastInst *I) {
625 auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
626 return cast<CastInst>(Val: registerValue(VPtr: std::move(NewPtr)));
627}
628PHINode *Context::createPHINode(llvm::PHINode *I) {
629 auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
630 return cast<PHINode>(Val: registerValue(VPtr: std::move(NewPtr)));
631}
632ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
633 auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
634 return cast<ICmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
635}
636FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
637 auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
638 return cast<FCmpInst>(Val: registerValue(VPtr: std::move(NewPtr)));
639}
640Value *Context::getValue(llvm::Value *V) const {
641 auto It = LLVMValueToValueMap.find(Val: V);
642 if (It != LLVMValueToValueMap.end())
643 return It->second.get();
644 return nullptr;
645}
646
/// Creates a SandboxIR context on top of \p LLVMCtx. The IR change tracker is
/// constructed with a reference to this context, and the LLVM IRBuilder is
/// constructed with \p LLVMCtx and a ConstantFolder.
Context::Context(LLVMContext &LLVMCtx)
    : LLVMCtx(LLVMCtx), IRTracker(*this),
      LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
650
// Defaulted out of line. Presumably this is so that members whose types are
// only forward-declared in the header are destroyed where those types are
// complete; TODO confirm against Context.h.
Context::~Context() = default;
652
/// Drops every LLVM-to-SandboxIR value mapping, which destroys the SandboxIR
/// values owned by the map. The module map (LLVMModuleToModuleMap) is not
/// touched.
void Context::clear() {
  // TODO: Ideally we should clear only function-scope objects, and keep global
  // objects, like Constants to avoid recreating them.
  LLVMValueToValueMap.clear();
}
658
659Module *Context::getModule(llvm::Module *LLVMM) const {
660 auto It = LLVMModuleToModuleMap.find(Val: LLVMM);
661 if (It != LLVMModuleToModuleMap.end())
662 return It->second.get();
663 return nullptr;
664}
665
666Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
667 auto Pair = LLVMModuleToModuleMap.try_emplace(Key: LLVMM);
668 auto It = Pair.first;
669 if (!Pair.second)
670 return It->second.get();
671 It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
672 return It->second.get();
673}
674
675Function *Context::createFunction(llvm::Function *F) {
676 // Create the module if needed before we create the new sandboxir::Function.
677 // Note: this won't fully populate the module. The only globals that will be
678 // available will be the ones being used within the function.
679 getOrCreateModule(LLVMM: F->getParent());
680
681 // There may be a function declaration already defined. Regardless destroy it.
682 if (Function *ExistingF = cast_or_null<Function>(Val: getValue(V: F)))
683 detach(V: ExistingF);
684
685 auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
686 auto *SBF = cast<Function>(Val: registerValue(VPtr: std::move(NewFPtr)));
687 // Create arguments.
688 for (auto &Arg : F->args())
689 getOrCreateArgument(LLVMArg: &Arg);
690 // Create BBs.
691 for (auto &BB : *F)
692 createBasicBlock(LLVMBB: &BB);
693 return SBF;
694}
695
696Module *Context::createModule(llvm::Module *LLVMM) {
697 auto *M = getOrCreateModule(LLVMM);
698 // Create the functions.
699 for (auto &LLVMF : *LLVMM)
700 createFunction(F: &LLVMF);
701 // Create globals.
702 for (auto &Global : LLVMM->globals())
703 getOrCreateValue(LLVMV: &Global);
704 // Create aliases.
705 for (auto &Alias : LLVMM->aliases())
706 getOrCreateValue(LLVMV: &Alias);
707 // Create ifuncs.
708 for (auto &IFunc : LLVMM->ifuncs())
709 getOrCreateValue(LLVMV: &IFunc);
710
711 return M;
712}
713
714void Context::runEraseInstrCallbacks(Instruction *I) {
715 for (const auto &CBEntry : EraseInstrCallbacks)
716 CBEntry.second(I);
717}
718
719void Context::runCreateInstrCallbacks(Instruction *I) {
720 for (auto &CBEntry : CreateInstrCallbacks)
721 CBEntry.second(I);
722}
723
724void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
725 for (auto &CBEntry : MoveInstrCallbacks)
726 CBEntry.second(I, WhereIt);
727}
728
729void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
730 for (auto &CBEntry : SetUseCallbacks)
731 CBEntry.second(U, NewSrc);
732}
733
// An arbitrary upper bound on the number of callbacks of each kind that may be
// registered at the same time. It only exists to catch accidental misuse: we
// expect a small number of callbacks to be registered at a time, but this can
// be increased if we discover we need more. It is only referenced from
// assertions, hence [[maybe_unused]] for NDEBUG builds.
[[maybe_unused]] static constexpr int MaxRegisteredCallbacks = 16;
738
739Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
740 assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
741 "EraseInstrCallbacks size limit exceeded");
742 CallbackID ID{NextCallbackID++};
743 EraseInstrCallbacks[ID] = std::move(CB);
744 return ID;
745}
746void Context::unregisterEraseInstrCallback(CallbackID ID) {
747 [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(Key: ID);
748 assert(Erased &&
749 "Callback ID not found in EraseInstrCallbacks during deregistration");
750}
751
752Context::CallbackID
753Context::registerCreateInstrCallback(CreateInstrCallback CB) {
754 assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
755 "CreateInstrCallbacks size limit exceeded");
756 CallbackID ID{NextCallbackID++};
757 CreateInstrCallbacks[ID] = std::move(CB);
758 return ID;
759}
760void Context::unregisterCreateInstrCallback(CallbackID ID) {
761 [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(Key: ID);
762 assert(Erased &&
763 "Callback ID not found in CreateInstrCallbacks during deregistration");
764}
765
766Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
767 assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
768 "MoveInstrCallbacks size limit exceeded");
769 CallbackID ID{NextCallbackID++};
770 MoveInstrCallbacks[ID] = std::move(CB);
771 return ID;
772}
773void Context::unregisterMoveInstrCallback(CallbackID ID) {
774 [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(Key: ID);
775 assert(Erased &&
776 "Callback ID not found in MoveInstrCallbacks during deregistration");
777}
778
779Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
780 assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
781 "SetUseCallbacks size limit exceeded");
782 CallbackID ID{NextCallbackID++};
783 SetUseCallbacks[ID] = std::move(CB);
784 return ID;
785}
786void Context::unregisterSetUseCallback(CallbackID ID) {
787 [[maybe_unused]] bool Erased = SetUseCallbacks.erase(Key: ID);
788 assert(Erased &&
789 "Callback ID not found in SetUseCallbacks during deregistration");
790}
791
792} // namespace llvm::sandboxir
793