1//===- Context.cpp - State Tracking for llubi -----------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file tracks the global states (e.g., memory) of the interpreter.
10//
11//===----------------------------------------------------------------------===//
12
13#include "Context.h"
14#include "llvm/Support/MathExtras.h"
15
16namespace llvm::ubi {
17
18Context::Context(Module &M)
19 : Ctx(M.getContext()), M(M), DL(M.getDataLayout()),
20 TLIImpl(M.getTargetTriple()) {}
21
22Context::~Context() = default;
23
24bool Context::initGlobalValues() {
25 // Register all function and block targets that may be used by indirect calls
26 // and branches.
27 for (Function &F : M) {
28 if (F.hasAddressTaken()) {
29 // TODO: Use precise alignment for function pointers if it is necessary.
30 auto FuncObj = allocate(Size: 0, Align: F.getPointerAlignment(DL).value(), Name: F.getName(),
31 AS: DL.getProgramAddressSpace(), InitKind: MemInitKind::Zeroed);
32 if (!FuncObj)
33 return false;
34 ValidFuncTargets.try_emplace(Key: FuncObj->getAddress(),
35 Args: std::make_pair(x: &F, y&: FuncObj));
36 FuncAddrMap.try_emplace(Key: &F, Args: deriveFromMemoryObject(Obj: FuncObj));
37 }
38
39 for (BasicBlock &BB : F) {
40 if (!BB.hasAddressTaken())
41 continue;
42 auto BlockObj = allocate(Size: 0, Align: 1, Name: BB.getName(), AS: DL.getProgramAddressSpace(),
43 InitKind: MemInitKind::Zeroed);
44 if (!BlockObj)
45 return false;
46 ValidBlockTargets.try_emplace(Key: BlockObj->getAddress(),
47 Args: std::make_pair(x: &BB, y&: BlockObj));
48 BlockAddrMap.try_emplace(Key: &BB, Args: deriveFromMemoryObject(Obj: BlockObj));
49 }
50 }
51 // TODO: initialize global variables.
52 return true;
53}
54
55AnyValue Context::getConstantValueImpl(Constant *C) {
56 if (isa<PoisonValue>(Val: C))
57 return AnyValue::getPoisonValue(Ctx&: *this, Ty: C->getType());
58
59 if (isa<ConstantAggregateZero>(Val: C))
60 return AnyValue::getNullValue(Ctx&: *this, Ty: C->getType());
61
62 if (isa<ConstantPointerNull>(Val: C))
63 return Pointer::null(
64 BitWidth: DL.getPointerSizeInBits(AS: C->getType()->getPointerAddressSpace()));
65
66 if (auto *CI = dyn_cast<ConstantInt>(Val: C)) {
67 if (auto *VecTy = dyn_cast<VectorType>(Val: CI->getType()))
68 return std::vector<AnyValue>(getEVL(EC: VecTy->getElementCount()),
69 AnyValue(CI->getValue()));
70 return CI->getValue();
71 }
72
73 if (auto *CDS = dyn_cast<ConstantDataSequential>(Val: C)) {
74 std::vector<AnyValue> Elts;
75 Elts.reserve(n: CDS->getNumElements());
76 for (uint32_t I = 0, E = CDS->getNumElements(); I != E; ++I)
77 Elts.push_back(x: getConstantValue(C: CDS->getElementAsConstant(i: I)));
78 return std::move(Elts);
79 }
80
81 if (auto *CA = dyn_cast<ConstantAggregate>(Val: C)) {
82 std::vector<AnyValue> Elts;
83 Elts.reserve(n: CA->getNumOperands());
84 for (uint32_t I = 0, E = CA->getNumOperands(); I != E; ++I)
85 Elts.push_back(x: getConstantValue(C: CA->getOperand(i_nocapture: I)));
86 return std::move(Elts);
87 }
88
89 if (auto *BA = dyn_cast<BlockAddress>(Val: C))
90 return BlockAddrMap.at(Val: BA->getBasicBlock());
91
92 if (auto *F = dyn_cast<Function>(Val: C))
93 return FuncAddrMap.at(Val: F);
94
95 llvm_unreachable("Unrecognized constant");
96}
97
98const AnyValue &Context::getConstantValue(Constant *C) {
99 auto It = ConstCache.find(x: C);
100 if (It != ConstCache.end())
101 return It->second;
102
103 return ConstCache.emplace(args&: C, args: getConstantValueImpl(C)).first->second;
104}
105
106MemoryObject::~MemoryObject() = default;
107MemoryObject::MemoryObject(uint64_t Addr, uint64_t Size, StringRef Name,
108 unsigned AS, MemInitKind InitKind)
109 : Address(Addr), Size(Size), Name(Name), AS(AS),
110 State(InitKind != MemInitKind::Poisoned ? MemoryObjectState::Alive
111 : MemoryObjectState::Dead) {
112 switch (InitKind) {
113 case MemInitKind::Zeroed:
114 Bytes.resize(N: Size, NV: Byte{.Value: 0, .Kind: ByteKind::Concrete});
115 break;
116 case MemInitKind::Uninitialized:
117 Bytes.resize(N: Size, NV: Byte{.Value: 0, .Kind: ByteKind::Undef});
118 break;
119 case MemInitKind::Poisoned:
120 Bytes.resize(N: Size, NV: Byte{.Value: 0, .Kind: ByteKind::Poison});
121 break;
122 }
123}
124
125IntrusiveRefCntPtr<MemoryObject> Context::allocate(uint64_t Size,
126 uint64_t Align,
127 StringRef Name, unsigned AS,
128 MemInitKind InitKind) {
129 // Even if the memory object is zero-sized, it still occupies a byte to obtain
130 // a unique address.
131 uint64_t AllocateSize = std::max(a: Size, b: (uint64_t)1);
132 if (MaxMem != 0 && SaturatingAdd(X: UsedMem, Y: AllocateSize) >= MaxMem)
133 return nullptr;
134 uint64_t AlignedAddr = alignTo(Value: AllocationBase, Align);
135 auto MemObj =
136 makeIntrusiveRefCnt<MemoryObject>(A&: AlignedAddr, A&: Size, A&: Name, A&: AS, A&: InitKind);
137 MemoryObjects[AlignedAddr] = MemObj;
138 AllocationBase = AlignedAddr + AllocateSize;
139 UsedMem += AllocateSize;
140 return MemObj;
141}
142
143bool Context::free(uint64_t Address) {
144 auto It = MemoryObjects.find(x: Address);
145 if (It == MemoryObjects.end())
146 return false;
147 UsedMem -= std::max(a: It->second->getSize(), b: (uint64_t)1);
148 It->second->markAsFreed();
149 MemoryObjects.erase(position: It);
150 return true;
151}
152
153Pointer Context::deriveFromMemoryObject(IntrusiveRefCntPtr<MemoryObject> Obj) {
154 assert(Obj && "Cannot determine the address space of a null memory object");
155 return Pointer(Obj, APInt(DL.getPointerSizeInBits(AS: Obj->getAddressSpace()),
156 Obj->getAddress()));
157}
158
159Function *Context::getTargetFunction(const Pointer &Ptr) {
160 if (Ptr.address().getActiveBits() > 64)
161 return nullptr;
162 auto It = ValidFuncTargets.find(Val: Ptr.address().getZExtValue());
163 if (It == ValidFuncTargets.end())
164 return nullptr;
165 // TODO: check the provenance of pointer.
166 return It->second.first;
167}
168BasicBlock *Context::getTargetBlock(const Pointer &Ptr) {
169 if (Ptr.address().getActiveBits() > 64)
170 return nullptr;
171 auto It = ValidBlockTargets.find(Val: Ptr.address().getZExtValue());
172 if (It == ValidBlockTargets.end())
173 return nullptr;
174 // TODO: check the provenance of pointer.
175 return It->second.first;
176}
177
178void MemoryObject::markAsFreed() {
179 State = MemoryObjectState::Freed;
180 Bytes.clear();
181}
182
183void MemoryObject::writeRawBytes(uint64_t Offset, const void *Data,
184 uint64_t Length) {
185 assert(SaturatingAdd(Offset, Length) <= Size && "Write out of bounds");
186 const uint8_t *ByteData = static_cast<const uint8_t *>(Data);
187 for (uint64_t I = 0; I < Length; ++I)
188 Bytes[Offset + I].set(ByteData[I]);
189}
190
191void MemoryObject::writeInteger(uint64_t Offset, const APInt &Int,
192 const DataLayout &DL) {
193 uint64_t BitWidth = Int.getBitWidth();
194 uint64_t IntSize = divideCeil(Numerator: BitWidth, Denominator: 8);
195 assert(SaturatingAdd(Offset, IntSize) <= Size && "Write out of bounds");
196 for (uint64_t I = 0; I < IntSize; ++I) {
197 uint64_t ByteIndex = DL.isLittleEndian() ? I : (IntSize - 1 - I);
198 uint64_t Bits = std::min(a: BitWidth - ByteIndex * 8, b: uint64_t(8));
199 Bytes[Offset + I].set(Int.extractBitsAsZExtValue(numBits: Bits, bitPosition: ByteIndex * 8));
200 }
201}
202void MemoryObject::writeFloat(uint64_t Offset, const APFloat &Float,
203 const DataLayout &DL) {
204 writeInteger(Offset, Int: Float.bitcastToAPInt(), DL);
205}
206void MemoryObject::writePointer(uint64_t Offset, const Pointer &Ptr,
207 const DataLayout &DL) {
208 writeInteger(Offset, Int: Ptr.address(), DL);
209 // TODO: provenance
210}
211
212} // namespace llvm::ubi
213