1#include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
2#include "llvm/ExecutionEngine/Orc/Mangling.h"
3
4using namespace llvm;
5using namespace orc;
6
7bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() {
8 std::unique_lock<std::mutex> Lock(Mutex);
9 if (Reoptimizing)
10 return false;
11
12 Reoptimizing = true;
13 return true;
14}
15
16void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() {
17 std::unique_lock<std::mutex> Lock(Mutex);
18 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
19 Reoptimizing = false;
20 CurVersion++;
21}
22
23void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() {
24 std::unique_lock<std::mutex> Lock(Mutex);
25 assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
26 Reoptimizing = false;
27}
28
29Error ReOptimizeLayer::reigsterRuntimeFunctions(JITDylib &PlatformJD) {
30 ExecutionSession::JITDispatchHandlerAssociationMap WFs;
31 using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t);
32 WFs[Mangle("__orc_rt_reoptimize_tag")] =
33 ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(Instance: this,
34 Method: &ReOptimizeLayer::rt_reoptimize);
35 return ES.registerJITDispatchHandlers(JD&: PlatformJD, WFs: std::move(WFs));
36}
37
38void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
39 ThreadSafeModule TSM) {
40 auto &JD = R->getTargetJITDylib();
41
42 bool HasNonCallable = false;
43 for (auto &KV : R->getSymbols()) {
44 auto &Flags = KV.second;
45 if (!Flags.isCallable())
46 HasNonCallable = true;
47 }
48
49 if (HasNonCallable) {
50 BaseLayer.emit(R: std::move(R), TSM: std::move(TSM));
51 return;
52 }
53
54 auto &MUState = createMaterializationUnitState(TSM);
55
56 if (auto Err = R->withResourceKeyDo(F: [&](ResourceKey Key) {
57 registerMaterializationUnitResource(Key, State&: MUState);
58 })) {
59 ES.reportError(Err: std::move(Err));
60 R->failMaterialization();
61 return;
62 }
63
64 if (auto Err =
65 ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) {
66 ES.reportError(Err: std::move(Err));
67 R->failMaterialization();
68 return;
69 }
70
71 auto InitialDests =
72 emitMUImplSymbols(MUState, Version: MUState.getCurVersion(), JD, TSM: std::move(TSM));
73 if (!InitialDests) {
74 ES.reportError(Err: InitialDests.takeError());
75 R->failMaterialization();
76 return;
77 }
78
79 RSManager.emitRedirectableSymbols(MR: std::move(R), InitialDests: std::move(*InitialDests));
80}
81
82Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
83 ReOptMaterializationUnitID MUID,
84 unsigned CurVersion,
85 ThreadSafeModule &TSM) {
86 return TSM.withModuleDo(F: [&](Module &M) -> Error {
87 Type *I64Ty = Type::getInt64Ty(C&: M.getContext());
88 GlobalVariable *Counter = new GlobalVariable(
89 M, I64Ty, false, GlobalValue::InternalLinkage,
90 Constant::getNullValue(Ty: I64Ty), "__orc_reopt_counter");
91 auto ArgBufferConst = createReoptimizeArgBuffer(M, MUID, CurVersion);
92 if (auto Err = ArgBufferConst.takeError())
93 return Err;
94 GlobalVariable *ArgBuffer =
95 new GlobalVariable(M, (*ArgBufferConst)->getType(), true,
96 GlobalValue::InternalLinkage, (*ArgBufferConst));
97 for (auto &F : M) {
98 if (F.isDeclaration())
99 continue;
100 auto &BB = F.getEntryBlock();
101 auto *IP = &*BB.getFirstInsertionPt();
102 IRBuilder<> IRB(IP);
103 Value *Threshold = ConstantInt::get(Ty: I64Ty, V: CallCountThreshold, IsSigned: true);
104 Value *Cnt = IRB.CreateLoad(Ty: I64Ty, Ptr: Counter);
105 // Use EQ to prevent further reoptimize calls.
106 Value *Cmp = IRB.CreateICmpEQ(LHS: Cnt, RHS: Threshold);
107 Value *Added = IRB.CreateAdd(LHS: Cnt, RHS: ConstantInt::get(Ty: I64Ty, V: 1));
108 (void)IRB.CreateStore(Val: Added, Ptr: Counter);
109 Instruction *SplitTerminator = SplitBlockAndInsertIfThen(Cond: Cmp, SplitBefore: IP, Unreachable: false);
110 createReoptimizeCall(M, IP&: *SplitTerminator, ArgBuffer);
111 }
112 return Error::success();
113 });
114}
115
116Expected<SymbolMap>
117ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
118 uint32_t Version, JITDylib &JD,
119 ThreadSafeModule TSM) {
120 DenseMap<SymbolStringPtr, SymbolStringPtr> RenamedMap;
121 cantFail(Err: TSM.withModuleDo(F: [&](Module &M) -> Error {
122 MangleAndInterner Mangle(ES, M.getDataLayout());
123 for (auto &F : M)
124 if (!F.isDeclaration()) {
125 std::string NewName =
126 (F.getName() + ".__def__." + Twine(Version)).str();
127 RenamedMap[Mangle(F.getName())] = Mangle(NewName);
128 F.setName(NewName);
129 }
130 return Error::success();
131 }));
132
133 auto RT = JD.createResourceTracker();
134 if (auto Err =
135 JD.define(MU: std::make_unique<BasicIRLayerMaterializationUnit>(
136 args&: BaseLayer, args: *getManglingOptions(), args: std::move(TSM)),
137 RT))
138 return Err;
139 MUState.setResourceTracker(RT);
140
141 SymbolLookupSet LookupSymbols;
142 for (auto [K, V] : RenamedMap)
143 LookupSymbols.add(Name: V);
144
145 auto ImplSymbols =
146 ES.lookup(SearchOrder: {{&JD, JITDylibLookupFlags::MatchAllSymbols}}, Symbols: LookupSymbols,
147 K: LookupKind::Static, RequiredState: SymbolState::Resolved);
148 if (auto Err = ImplSymbols.takeError())
149 return Err;
150
151 SymbolMap Result;
152 for (auto [K, V] : RenamedMap)
153 Result[K] = (*ImplSymbols)[V];
154
155 return Result;
156}
157
158void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult,
159 ReOptMaterializationUnitID MUID,
160 uint32_t CurVersion) {
161 auto &MUState = getMaterializationUnitState(MUID);
162 if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) {
163 SendResult(Error::success());
164 return;
165 }
166
167 ThreadSafeModule TSM = cloneToNewContext(TSMW: MUState.getThreadSafeModule());
168 auto OldRT = MUState.getResourceTracker();
169 auto &JD = OldRT->getJITDylib();
170
171 if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) {
172 ES.reportError(Err: std::move(Err));
173 MUState.reoptimizeFailed();
174 SendResult(Error::success());
175 return;
176 }
177
178 auto SymbolDests =
179 emitMUImplSymbols(MUState, Version: CurVersion + 1, JD, TSM: std::move(TSM));
180 if (!SymbolDests) {
181 ES.reportError(Err: SymbolDests.takeError());
182 MUState.reoptimizeFailed();
183 SendResult(Error::success());
184 return;
185 }
186
187 if (auto Err = RSManager.redirect(JD, NewDests: std::move(*SymbolDests))) {
188 ES.reportError(Err: std::move(Err));
189 MUState.reoptimizeFailed();
190 SendResult(Error::success());
191 return;
192 }
193
194 MUState.reoptimizeSucceeded();
195 SendResult(Error::success());
196}
197
198Expected<Constant *> ReOptimizeLayer::createReoptimizeArgBuffer(
199 Module &M, ReOptMaterializationUnitID MUID, uint32_t CurVersion) {
200 size_t ArgBufferSize = SPSReoptimizeArgList::size(Arg: MUID, Args: CurVersion);
201 std::vector<char> ArgBuffer(ArgBufferSize);
202 shared::SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size());
203 if (!SPSReoptimizeArgList::serialize(OB, Arg: MUID, Args: CurVersion))
204 return make_error<StringError>(Args: "Could not serealize args list",
205 Args: inconvertibleErrorCode());
206 return ConstantDataArray::get(Context&: M.getContext(), Elts: ArrayRef(ArgBuffer));
207}
208
209void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP,
210 GlobalVariable *ArgBuffer) {
211 GlobalVariable *DispatchCtx =
212 M.getGlobalVariable(Name: "__orc_rt_jit_dispatch_ctx");
213 if (!DispatchCtx)
214 DispatchCtx = new GlobalVariable(M, PointerType::get(C&: M.getContext(), AddressSpace: 0),
215 false, GlobalValue::ExternalLinkage,
216 nullptr, "__orc_rt_jit_dispatch_ctx");
217 GlobalVariable *ReoptimizeTag =
218 M.getGlobalVariable(Name: "__orc_rt_reoptimize_tag");
219 if (!ReoptimizeTag)
220 ReoptimizeTag = new GlobalVariable(M, PointerType::get(C&: M.getContext(), AddressSpace: 0),
221 false, GlobalValue::ExternalLinkage,
222 nullptr, "__orc_rt_reoptimize_tag");
223 Function *DispatchFunc = M.getFunction(Name: "__orc_rt_jit_dispatch");
224 if (!DispatchFunc) {
225 std::vector<Type *> Args = {PointerType::get(C&: M.getContext(), AddressSpace: 0),
226 PointerType::get(C&: M.getContext(), AddressSpace: 0),
227 PointerType::get(C&: M.getContext(), AddressSpace: 0),
228 IntegerType::get(C&: M.getContext(), NumBits: 64)};
229 FunctionType *FuncTy =
230 FunctionType::get(Result: Type::getVoidTy(C&: M.getContext()), Params: Args, isVarArg: false);
231 DispatchFunc = Function::Create(Ty: FuncTy, Linkage: GlobalValue::ExternalLinkage,
232 N: "__orc_rt_jit_dispatch", M: &M);
233 }
234 size_t ArgBufferSizeConst =
235 SPSReoptimizeArgList::size(Arg: ReOptMaterializationUnitID{}, Args: uint32_t{});
236 Constant *ArgBufferSize = ConstantInt::get(
237 Ty: IntegerType::get(C&: M.getContext(), NumBits: 64), V: ArgBufferSizeConst, IsSigned: false);
238 IRBuilder<> IRB(&IP);
239 (void)IRB.CreateCall(Callee: DispatchFunc,
240 Args: {DispatchCtx, ReoptimizeTag, ArgBuffer, ArgBufferSize});
241}
242
243ReOptimizeLayer::ReOptMaterializationUnitState &
244ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) {
245 std::unique_lock<std::mutex> Lock(Mutex);
246 ReOptMaterializationUnitID MUID = NextID;
247 MUStates.emplace(args&: MUID,
248 args: ReOptMaterializationUnitState(MUID, cloneToNewContext(TSMW: TSM)));
249 ++NextID;
250 return MUStates.at(k: MUID);
251}
252
253ReOptimizeLayer::ReOptMaterializationUnitState &
254ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) {
255 std::unique_lock<std::mutex> Lock(Mutex);
256 return MUStates.at(k: MUID);
257}
258
259void ReOptimizeLayer::registerMaterializationUnitResource(
260 ResourceKey Key, ReOptMaterializationUnitState &State) {
261 std::unique_lock<std::mutex> Lock(Mutex);
262 MUResources[Key].insert(V: State.getID());
263}
264
265Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) {
266 std::unique_lock<std::mutex> Lock(Mutex);
267 for (auto MUID : MUResources[K])
268 MUStates.erase(x: MUID);
269
270 MUResources.erase(Val: K);
271 return Error::success();
272}
273
274void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK,
275 ResourceKey SrcK) {
276 std::unique_lock<std::mutex> Lock(Mutex);
277 MUResources[DstK].insert_range(R&: MUResources[SrcK]);
278 MUResources.erase(Val: SrcK);
279}
280